This commit is contained in:
@@ -77,7 +77,7 @@ static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute)
|
||||
needsRewrite = true;
|
||||
continue;
|
||||
|
||||
keep_input:
|
||||
keep_input:
|
||||
promoted.newInputs.push_back(input);
|
||||
promoted.newInputTypes.push_back(input.getType());
|
||||
promoted.newInputLocs.push_back(input.getLoc());
|
||||
@@ -127,8 +127,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
for (Value weight : promoted->newWeights) {
|
||||
@@ -155,7 +155,12 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
if (failed(mapPromotedInputArguments(
|
||||
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
||||
compute,
|
||||
*promoted,
|
||||
bodyRewriter,
|
||||
mapper,
|
||||
[&](size_t index) { return newCompute.getInputArgument(index); },
|
||||
rewriter)))
|
||||
return failure();
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
@@ -199,7 +204,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size()
|
||||
+ compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(laneArg->getType());
|
||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||
@@ -239,7 +245,12 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
if (failed(mapPromotedInputArguments(
|
||||
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
||||
compute,
|
||||
*promoted,
|
||||
bodyRewriter,
|
||||
mapper,
|
||||
[&](size_t index) { return newCompute.getInputArgument(index); },
|
||||
rewriter)))
|
||||
return failure();
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
|
||||
Reference in New Issue
Block a user