289 lines
12 KiB
C++
289 lines
12 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
/// Returns true when `op` is one of the helper ops that may consume a promotable input block argument:
/// tensor slicing/reshaping ops or a linalg transpose.
static bool isWeightMaterializationHelperUser(Operation* op) {
  const bool isSliceOrReshape = isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
  if (isSliceOrReshape)
    return true;
  return isa<linalg::TransposeOp>(op);
}
|
|
|
|
/// Returns true when `arg` has at least one use and none of its users is anything other than a
/// weight-materialization helper op (see isWeightMaterializationHelperUser).
static bool canPromoteInputBlockArgument(BlockArgument arg) {
  if (arg.use_empty())
    return false;
  return llvm::none_of(arg.getUsers(), [](Operation* user) { return !isWeightMaterializationHelperUser(user); });
}
|
|
|
|
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
|
|
return arg && canPromoteInputBlockArgument(*arg);
|
|
}
|
|
|
|
/// Returns true when `value` is produced directly by an `arith.constant` or an ONNX constant op
/// (block arguments and other producers return false).
static bool isDirectConstantValue(Value value) {
  Operation* producer = value.getDefiningOp();
  if (producer == nullptr)
    return false;
  return isa<arith::ConstantOp, ONNXConstantOp>(producer);
}
|
|
|
|
struct PromotedOperands {
|
|
SmallVector<bool> promoteInput;
|
|
SmallVector<Value> newWeights;
|
|
SmallVector<Value> newInputs;
|
|
SmallVector<Type> newInputTypes;
|
|
SmallVector<Location> newInputLocs;
|
|
};
|
|
|
|
template <typename ComputeOpTy>
|
|
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
if (!isWeightLikeComputeOperand(input))
|
|
continue;
|
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
|
continue;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <typename ComputeOpTy>
|
|
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
|
|
PromotedOperands promoted;
|
|
promoted.promoteInput.assign(compute.getInputs().size(), false);
|
|
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
|
|
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
|
promoted.newInputs.reserve(compute.getInputs().size());
|
|
promoted.newInputTypes.reserve(compute.getInputs().size());
|
|
promoted.newInputLocs.reserve(compute.getInputs().size());
|
|
|
|
bool needsRewrite = false;
|
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
if (!isWeightLikeComputeOperand(input))
|
|
goto keep_input;
|
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
|
goto keep_input;
|
|
promoted.promoteInput[inputIdx] = true;
|
|
promoted.newWeights.push_back(input);
|
|
needsRewrite = true;
|
|
continue;
|
|
|
|
keep_input:
|
|
promoted.newInputs.push_back(input);
|
|
promoted.newInputTypes.push_back(input.getType());
|
|
promoted.newInputLocs.push_back(input.getLoc());
|
|
}
|
|
|
|
if (!needsRewrite)
|
|
return failure();
|
|
return promoted;
|
|
}
|
|
|
|
template <typename ComputeOpTy>
|
|
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
|
|
const PromotedOperands& promoted,
|
|
IRRewriter& bodyRewriter,
|
|
IRMapping& mapper,
|
|
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
|
|
PatternRewriter& rewriter) {
|
|
size_t newInputIdx = 0;
|
|
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
auto oldArg = compute.getInputArgument(oldInputIdx);
|
|
if (!oldArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
|
|
if (!promoted.promoteInput[oldInputIdx]) {
|
|
auto newInputArg = getNewInputArg(newInputIdx++);
|
|
if (!newInputArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
|
|
mapper.map(*oldArg, *newInputArg);
|
|
continue;
|
|
}
|
|
|
|
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
|
if (failed(clonedValue))
|
|
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
|
mapper.map(*oldArg, *clonedValue);
|
|
}
|
|
return success();
|
|
}
|
|
|
|
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
|
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
|
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
|
auto promoted = computePromotedOperands(compute);
|
|
if (failed(promoted))
|
|
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
|
Block& oldBlock = compute.getBody().front();
|
|
|
|
rewriter.setInsertionPointAfter(compute);
|
|
auto newCompute = spatial::SpatCompute::create(
|
|
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
|
SmallVector<Type> newBlockArgTypes;
|
|
SmallVector<Location> newBlockArgLocs;
|
|
for (Value weight : promoted->newWeights) {
|
|
newBlockArgTypes.push_back(weight.getType());
|
|
newBlockArgLocs.push_back(weight.getLoc());
|
|
}
|
|
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
|
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
|
auto* newBlock = rewriter.createBlock(
|
|
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
|
newCompute.getProperties().setOperandSegmentSizes(
|
|
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
|
rewriter.setInsertionPointToStart(newBlock);
|
|
|
|
IRRewriter bodyRewriter(rewriter.getContext());
|
|
bodyRewriter.setInsertionPointToStart(newBlock);
|
|
|
|
IRMapping mapper;
|
|
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
|
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
|
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
|
if (!oldWeightArg || !newWeightArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
|
mapper.map(*oldWeightArg, *newWeightArg);
|
|
}
|
|
if (failed(mapPromotedInputArguments(
|
|
compute,
|
|
*promoted,
|
|
bodyRewriter,
|
|
mapper,
|
|
[&](size_t index) { return newCompute.getInputArgument(index); },
|
|
rewriter)))
|
|
return failure();
|
|
|
|
for (Operation& op : oldBlock.without_terminator())
|
|
rewriter.clone(op, mapper);
|
|
|
|
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
|
SmallVector<Value> newYieldOperands;
|
|
newYieldOperands.reserve(oldYield.getOutputs().size());
|
|
for (Value operand : oldYield.getOutputs()) {
|
|
auto mapped = mapper.lookupOrNull(operand);
|
|
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
|
}
|
|
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
|
|
|
rewriter.replaceOp(compute, newCompute.getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
|
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
|
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
|
auto promoted = computePromotedOperands(compute);
|
|
if (failed(promoted))
|
|
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
|
Block& oldBlock = compute.getBody().front();
|
|
|
|
rewriter.setInsertionPointAfter(compute);
|
|
|
|
auto newCompute =
|
|
spatial::SpatComputeBatch::create(rewriter,
|
|
compute.getLoc(),
|
|
compute.getResultTypes(),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
|
promoted->newWeights,
|
|
promoted->newInputs);
|
|
auto laneArg = compute.getLaneArgument();
|
|
if (!laneArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
|
SmallVector<Type> newBlockArgTypes;
|
|
SmallVector<Location> newBlockArgLocs;
|
|
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size()
|
|
+ compute.getNumResults());
|
|
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
|
|
newBlockArgTypes.push_back(laneArg->getType());
|
|
newBlockArgLocs.push_back(laneArg->getLoc());
|
|
for (Value weight : promoted->newWeights) {
|
|
newBlockArgTypes.push_back(weight.getType());
|
|
newBlockArgLocs.push_back(weight.getLoc());
|
|
}
|
|
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
|
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
|
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
|
auto outputArg = compute.getOutputArgument(resultIndex);
|
|
if (!outputArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
|
|
newBlockArgTypes.push_back(resultType);
|
|
newBlockArgLocs.push_back(outputArg->getLoc());
|
|
}
|
|
|
|
auto* newBlock = rewriter.createBlock(
|
|
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
|
newCompute.getProperties().setOperandSegmentSizes(
|
|
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
|
rewriter.setInsertionPointToStart(newBlock);
|
|
|
|
IRRewriter bodyRewriter(rewriter.getContext());
|
|
bodyRewriter.setInsertionPointToStart(newBlock);
|
|
|
|
IRMapping mapper;
|
|
auto newLaneArg = newCompute.getLaneArgument();
|
|
if (!newLaneArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
|
|
mapper.map(*laneArg, *newLaneArg);
|
|
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
|
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
|
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
|
if (!oldWeightArg || !newWeightArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
|
mapper.map(*oldWeightArg, *newWeightArg);
|
|
}
|
|
if (failed(mapPromotedInputArguments(
|
|
compute,
|
|
*promoted,
|
|
bodyRewriter,
|
|
mapper,
|
|
[&](size_t index) { return newCompute.getInputArgument(index); },
|
|
rewriter)))
|
|
return failure();
|
|
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
|
auto outputArg = compute.getOutputArgument(resultIndex);
|
|
if (!outputArg)
|
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
|
mapper.map(*outputArg,
|
|
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
|
|
}
|
|
|
|
for (Operation& op : oldBlock)
|
|
rewriter.clone(op, mapper);
|
|
|
|
rewriter.replaceOp(compute, newCompute.getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the weight-promotion rewrite patterns for spat.compute and spat.compute_batch.
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<PromoteWeightLikeComputeInputsPattern>(ctx);
  patterns.add<PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
|
|
|
|
/// Calls markWeightAlways on every `arith.constant` nested in `funcOp` whose result satisfies
/// hasOnlySpatialMvmVmmWeightUses.
void annotateWeightsConstants(func::FuncOp funcOp) {
  funcOp.walk([](arith::ConstantOp constant) {
    if (!hasOnlySpatialMvmVmmWeightUses(constant.getResult()))
      return;
    markWeightAlways(constant);
  });
}
|
|
|
|
/// Returns true when `computeOp` has at least one input that the weight-promotion pattern would promote.
bool requiresPostRewrite(spatial::SpatCompute computeOp) {
  return hasPromotableWeightLikeInputs<spatial::SpatCompute>(computeOp);
}
|
|
|
|
/// Returns true when `computeOp` has at least one input that the batch weight-promotion pattern would
/// promote.
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) {
  return hasPromotableWeightLikeInputs<spatial::SpatComputeBatch>(computeOp);
}
|
|
|
|
} // namespace onnx_mlir
|