This commit is contained in:
@@ -18,13 +18,17 @@ namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -93,14 +99,15 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
auto bodyResult = detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
@@ -123,6 +130,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -131,13 +140,13 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
|
||||
@@ -44,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
||||
if (!computes.empty() || !computeBatches.empty())
|
||||
return;
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||
@@ -190,16 +191,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
|
||||
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
|
||||
@@ -402,24 +402,37 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||
vmmOutputs.push_back(
|
||||
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
}
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
|
||||
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
|
||||
for (Value weight : weights) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
for (Value input : aHSlices[coreId]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
Block* body =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(
|
||||
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp->getResult(0));
|
||||
}
|
||||
@@ -530,37 +543,47 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
sharedBias = c;
|
||||
}
|
||||
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
||||
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
||||
|
||||
auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange(resultTypes),
|
||||
TypeRange {outType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||
ValueRange(weights),
|
||||
ValueRange(aSlices));
|
||||
ValueRange {b},
|
||||
ValueRange {a});
|
||||
|
||||
Block* body = rewriter.createBlock(
|
||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
|
||||
SmallVector<Location> blockArgLocs(4, loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||
Value lane = batchOp.getLaneArgument();
|
||||
Value weight = batchOp.getWeightArgument(0);
|
||||
Value packedInput = batchOp.getInputArgument(0);
|
||||
Value packedOutput = batchOp.getOutputArgument(0);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value row =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
.getResult();
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
|
||||
unitStrides);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
||||
});
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -35,58 +35,15 @@ template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= block.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(templateBlock.getNumArguments());
|
||||
blockArgLocs.reserve(templateBlock.getNumArguments());
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp->replaceAllUsesWith(computeOp->getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
@@ -96,11 +53,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
@@ -131,8 +86,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
@@ -141,14 +104,17 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
@@ -180,11 +146,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
@@ -220,8 +184,25 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(compute.getLaneArgument().getType());
|
||||
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc());
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
newBlockArgTypes.push_back(resultType);
|
||||
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc());
|
||||
}
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
@@ -230,31 +211,28 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument());
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
|
||||
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
@@ -262,10 +240,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
||||
}
|
||||
|
||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
@@ -277,8 +251,6 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
@@ -7,14 +7,10 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||
|
||||
Reference in New Issue
Block a user