This commit is contained in:
@@ -36,6 +36,14 @@ static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
struct PromotedOperands {
|
||||
SmallVector<bool> promoteInput;
|
||||
SmallVector<Value> newWeights;
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
};
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
@@ -48,60 +56,91 @@ static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
|
||||
PromotedOperands promoted;
|
||||
promoted.promoteInput.assign(compute.getInputs().size(), false);
|
||||
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
|
||||
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
promoted.newInputs.reserve(compute.getInputs().size());
|
||||
promoted.newInputTypes.reserve(compute.getInputs().size());
|
||||
promoted.newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
bool needsRewrite = false;
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
goto keep_input;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
goto keep_input;
|
||||
promoted.promoteInput[inputIdx] = true;
|
||||
promoted.newWeights.push_back(input);
|
||||
needsRewrite = true;
|
||||
continue;
|
||||
|
||||
keep_input:
|
||||
promoted.newInputs.push_back(input);
|
||||
promoted.newInputTypes.push_back(input.getType());
|
||||
promoted.newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
if (!needsRewrite)
|
||||
return failure();
|
||||
return promoted;
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
|
||||
const PromotedOperands& promoted,
|
||||
IRRewriter& bodyRewriter,
|
||||
IRMapping& mapper,
|
||||
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
|
||||
PatternRewriter& rewriter) {
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
|
||||
if (!promoted.promoteInput[oldInputIdx]) {
|
||||
auto newInputArg = getNewInputArg(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
auto promoted = computePromotedOperands(compute);
|
||||
if (failed(promoted))
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
for (Value weight : newWeights) {
|
||||
for (Value weight : promoted->newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
@@ -115,24 +154,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
if (failed(mapPromotedInputArguments(
|
||||
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
||||
return failure();
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
@@ -156,63 +180,35 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
auto promoted = computePromotedOperands(compute);
|
||||
if (failed(promoted))
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatComputeBatch::create(rewriter,
|
||||
compute.getLoc(),
|
||||
compute.getResultTypes(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
promoted->newWeights,
|
||||
promoted->newInputs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(laneArg->getType());
|
||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||
for (Value weight : newWeights) {
|
||||
for (Value weight : promoted->newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
@@ -224,7 +220,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
@@ -242,29 +238,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
if (failed(mapPromotedInputArguments(
|
||||
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
||||
return failure();
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
mapper.map(*outputArg,
|
||||
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock)
|
||||
|
||||
Reference in New Issue
Block a user