This commit is contained in:
@@ -76,8 +76,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
continue;
|
||||
@@ -89,14 +88,13 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
continue;
|
||||
@@ -108,7 +106,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
@@ -254,7 +252,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
}
|
||||
}
|
||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
Value hostConstant = constantOp.getResult();
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
@@ -264,40 +262,22 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
else {
|
||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
||||
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user