423 lines
16 KiB
C++
423 lines
16 KiB
C++
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_os_ostream.h"
|
|
|
|
#include <fstream>
|
|
#include <iterator>
|
|
#include <utility>
|
|
|
|
#include "Common.hpp"
|
|
#include "Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Compiler/CompilerOptions.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
bool haveSameStaticShape(Value lhs, Value rhs);
|
|
|
|
namespace {
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
|
|
|
/// Module-level pass, registered under the argument "convert-onnx-to-spatial",
/// that lowers ONNX operations to Spatial operations.
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)

  ONNXToSpatialPass() = default;

  // The struct declares no data members of its own, so the copy has nothing to
  // transfer from the source pass.
  ONNXToSpatialPass(const ONNXToSpatialPass&) {}

  /// Name used to refer to this pass on the command line.
  StringRef getArgument() const override { return "convert-onnx-to-spatial"; }

  /// Short human-readable summary of the pass.
  StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }

  /// Runs the lowering on the module the pass is scheduled on.
  void runOnOperation() override;

private:
  /// Marks `arith.constant` results in `funcOp` that are used only as
  /// Spatial MVM/VMM weights.
  void annotateWeightsConstants(func::FuncOp funcOp) const;

  /// Wraps function-level slice/reshape/transpose/concat operations that
  /// consume compute results into compute operations of their own.
  void encapsulateGlobalInstruction(func::FuncOp funcOp);

  /// Turns weight-like inputs of compute operations into weights.
  LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
};
|
|
|
|
} // namespace
|
|
|
|
/// Rewrites every `spatial::SpatComputeBatch` in `funcOp` whose lane count is
/// exactly 1 into a `spatial::SpatCompute` with the same result types,
/// weights and inputs. The batch body (first block, terminator included) is
/// cloned into the new compute, and the batch operation is then erased.
static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
  // Gather the candidates up front: the rewrite below erases operations, which
  // must not happen while the walk is still in progress.
  SmallVector<spatial::SpatComputeBatch> candidates;
  funcOp.walk([&candidates](spatial::SpatComputeBatch op) { candidates.push_back(op); });

  IRRewriter rewriter(funcOp.getContext());
  for (spatial::SpatComputeBatch batch : candidates) {
    if (batch.getLaneCount() != 1)
      continue;

    auto batchLoc = batch.getLoc();
    rewriter.setInsertionPoint(batch);
    auto compute =
        spatial::SpatCompute::create(rewriter, batchLoc, batch.getResultTypes(), batch.getWeights(), batch.getInputs());
    compute.getProperties().setOperandSegmentSizes(
        {static_cast<int>(batch.getWeights().size()), static_cast<int>(batch.getInputs().size())});

    // Mirror the signature of the batch body; every new argument is given the
    // location of the batch operation.
    Block& batchBody = batch.getBody().front();
    SmallVector<Type> argTypes;
    for (BlockArgument arg : batchBody.getArguments())
      argTypes.push_back(arg.getType());
    SmallVector<Location> argLocs(batchBody.getNumArguments(), batchLoc);

    auto* computeBody =
        rewriter.createBlock(&compute.getBody(), compute.getBody().end(), TypeRange(argTypes), argLocs);

    // Redirect uses of the old block arguments to the new ones while cloning.
    IRMapping mapper;
    mapper.map(batchBody.getArguments(), computeBody->getArguments());

    rewriter.setInsertionPointToEnd(computeBody);
    for (Operation& bodyOp : batchBody)
      rewriter.clone(bodyOp, mapper);

    batch.replaceAllUsesWith(compute.getResults());
    rewriter.eraseOp(batch);
  }
}
|
|
|
|
void ONNXToSpatialPass::runOnOperation() {
|
|
ModuleOp moduleOp = getOperation();
|
|
MLIRContext* ctx = &getContext();
|
|
|
|
RewritePatternSet mergeActivationPatterns(ctx);
|
|
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
|
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
|
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
|
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
|
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
|
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
|
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
|
|
|
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
|
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
|
|
|
IRRewriter rewriter(moduleOp);
|
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
|
if (failed(entryFunc)) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
ConversionTarget target(*ctx);
|
|
target.addLegalDialect<spatial::SpatialDialect,
|
|
ONNXDialect,
|
|
tensor::TensorDialect,
|
|
arith::ArithDialect,
|
|
scf::SCFDialect>();
|
|
target.addIllegalOp<ONNXMatMulOp>();
|
|
target.addIllegalOp<ONNXAddOp>();
|
|
target.addIllegalOp<ONNXDivOp>();
|
|
target.addIllegalOp<ONNXMulOp>();
|
|
target.addIllegalOp<ONNXGemmOp>();
|
|
target.addIllegalOp<ONNXConvOp>();
|
|
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
|
target.addIllegalOp<ONNXAveragePoolOp>();
|
|
target.addIllegalOp<ONNXReluOp>();
|
|
target.addIllegalOp<ONNXSigmoidOp>();
|
|
target.addIllegalOp<ONNXSoftmaxOp>();
|
|
target.addIllegalOp<ONNXConcatOp>();
|
|
target.addIllegalOp<ONNXGatherOp>();
|
|
target.addIllegalOp<ONNXReshapeOp>();
|
|
target.addIllegalOp<ONNXResizeOp>();
|
|
target.addIllegalOp<ONNXLRNOp>();
|
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
|
target.addIllegalOp<ONNXSplitOp>();
|
|
|
|
RewritePatternSet patterns(ctx);
|
|
patterns.add<removeLRN>(ctx);
|
|
|
|
populateElementwisePatterns(patterns, ctx);
|
|
populateGemmPatterns(patterns, ctx);
|
|
populateConvPatterns(patterns, ctx);
|
|
populatePoolPatterns(patterns, ctx);
|
|
populateReduceMeanPatterns(patterns, ctx);
|
|
populateReluPatterns(patterns, ctx);
|
|
populateSigmoidPatterns(patterns, ctx);
|
|
populateSoftmaxPatterns(patterns, ctx);
|
|
populateConcatPatterns(patterns, ctx);
|
|
populateGatherPatterns(patterns, ctx);
|
|
populateResizePatterns(patterns, ctx);
|
|
populateReshapePatterns(patterns, ctx);
|
|
populateSplitPatterns(patterns, ctx);
|
|
|
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
foldSingleLaneComputeBatches(*entryFunc);
|
|
|
|
// Count the number of compute ops and check they do not exceed the core count
|
|
if (coresCount != -1) {
|
|
int computeOpsCount = 0;
|
|
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
|
if (isa<spatial::SpatCompute>(op))
|
|
computeOpsCount++;
|
|
|
|
if (computeOpsCount > coresCount) {
|
|
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
PassManager cleanupPM(ctx);
|
|
cleanupPM.addPass(createCanonicalizerPass());
|
|
if (failed(cleanupPM.run(moduleOp)))
|
|
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
|
|
|
annotateWeightsConstants(*entryFunc);
|
|
encapsulateGlobalInstruction(*entryFunc);
|
|
|
|
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
// Dump to file for debug
|
|
dumpModule(moduleOp, "spatial0");
|
|
}
|
|
|
|
template <typename T>
|
|
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
|
|
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
|
Value source = funcSource(toRemoveOp);
|
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
|
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
|
rewriter.setInsertionPointToEnd(BB);
|
|
IRMapping mapper;
|
|
mapper.map(source, BB->getArgument(0));
|
|
auto newInst = rewriter.clone(*inst, mapper);
|
|
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
|
inst->replaceAllUsesWith(newCompute->getResults());
|
|
inst->erase();
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Wraps a `tensor::ConcatOp` into a new `spatial::SpatCompute` when at least
/// one of its inputs is produced by a `spatial::SpatCompute`.
///
/// The new compute has no weights and takes every concat input as an input.
/// Its body holds a `spatial::SpatConcatOp` over the block arguments, using
/// the same result type and concatenation dimension, followed by a yield.
///
/// \param rewriter Rewriter used to build the new operations. Its insertion
///                 point is changed as soon as `inst` is a concat, even when
///                 the function ends up returning false.
/// \param loc      Location given to the new operations and block arguments.
/// \param inst     Candidate operation. It is erased when the function returns
///                 true and must not be used afterwards.
/// \return true if `inst` was replaced by a new compute, false otherwise.
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
  auto concatOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst);
  if (!concatOp)
    return false;

  auto sources = concatOp.getInputs();
  rewriter.setInsertionPointAfter(concatOp);

  const bool fedByCompute = llvm::any_of(
      sources, [](Value source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); });
  if (!fedByCompute)
    return false;

  auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);

  // One block argument per concat input, all carrying `loc`.
  SmallVector<Type> sourceTypes;
  sourceTypes.reserve(sources.size());
  for (Value source : sources)
    sourceTypes.push_back(source.getType());
  SmallVector<Location> sourceLocs(sources.size(), loc);

  auto* body = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
  newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sources.size())});
  rewriter.setInsertionPointToEnd(body);

  // The concat is rebuilt directly on the block arguments, so no value
  // mapping is needed.
  auto newConcat = spatial::SpatConcatOp::create(rewriter,
                                                 loc,
                                                 concatOp.getType(),
                                                 rewriter.getI64IntegerAttr(concatOp.getDim()),
                                                 ValueRange(body->getArguments()));
  spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());

  inst->replaceAllUsesWith(newCompute->getResults());
  inst->erase();
  return true;
}
|
|
|
|
/// Produces, at the rewriter's current insertion point, a value that stands in
/// for the weight-like `value`, and records the substitution in `mapper`.
///
/// - A value already present in `mapper` is returned as mapped.
/// - A result of `arith::ConstantOp` / `ONNXConstantOp` with a static ranked
///   tensor type is referenced through a `tensor::ExtractSliceOp` covering the
///   whole tensor (zero offsets, full sizes, unit strides).
/// - A result of `tensor::ExtractSliceOp`, `tensor::ExpandShapeOp`,
///   `tensor::CollapseShapeOp` or `ONNXTransposeOp` is cloned; operands that
///   are themselves weight-like are materialized recursively first, all other
///   operands are used as they are.
///
/// \return the materialized value, or failure for block arguments, for
///         constants without a static ranked tensor type, and for any other
///         defining operation.
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
  if (Value alreadyMapped = mapper.lookupOrNull(value))
    return alreadyMapped;

  Operation* producer = value.getDefiningOp();
  if (!producer)
    return failure();

  if (isa<arith::ConstantOp, ONNXConstantOp>(producer)) {
    auto constantType = dyn_cast<RankedTensorType>(value.getType());
    if (!(constantType && constantType.hasStaticShape()))
      return failure();

    // Describe a slice that spans the entire constant.
    const int64_t rank = constantType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    for (int64_t extent : constantType.getShape())
      sizes.push_back(rewriter.getIndexAttr(extent));

    auto fullSlice =
        tensor::ExtractSliceOp::create(rewriter, value.getLoc(), constantType, value, offsets, sizes, strides);
    mapper.map(value, fullSlice.getResult());
    return fullSlice.getResult();
  }

  if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(producer))
    return failure();

  // Resolve each operand of the producer: reuse an existing mapping,
  // materialize weight-like operands, and keep everything else untouched.
  IRMapping localMapper;
  for (Value operand : producer->getOperands()) {
    Value replacement = mapper.lookupOrNull(operand);
    if (!replacement) {
      if (isWeightLikeComputeOperand(operand)) {
        auto materialized = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
        if (failed(materialized))
          return failure();
        replacement = *materialized;
      }
      else {
        replacement = operand;
      }
    }
    localMapper.map(operand, replacement);
  }

  Operation* producerClone = rewriter.clone(*producer, localMapper);
  for (auto [originalResult, clonedResult] : llvm::zip(producer->getResults(), producerClone->getResults()))
    mapper.map(originalResult, clonedResult);

  Value materializedValue = mapper.lookupOrNull(value);
  if (!materializedValue)
    return failure();
  return materializedValue;
}
|
|
|
|
// TODO what we want to keep in global?
|
|
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|
Location loc = funcOp.getLoc();
|
|
IRRewriter rewriter(&getContext());
|
|
bool keep = true;
|
|
while (keep) {
|
|
keep = false;
|
|
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
|
keep |= encapsulator<tensor::ExtractSliceOp>(
|
|
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
|
|
|
keep |= encapsulator<tensor::ExpandShapeOp>(
|
|
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
|
|
|
keep |= encapsulator<ONNXTransposeOp>(
|
|
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
|
|
|
|
keep |= encapsulator<tensor::CollapseShapeOp>(
|
|
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
|
|
|
|
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Visits every `arith::ConstantOp` nested in `funcOp` and calls
/// `markWeightAlways` on those whose result satisfies
/// `hasOnlySpatialMvmVmmWeightUses`.
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
  funcOp.walk([](arith::ConstantOp constant) {
    if (!hasOnlySpatialMvmVmmWeightUses(constant.getResult()))
      return;
    markWeightAlways(constant);
  });
}
|
|
|
|
/// Rebuilds every `spatial::SpatCompute` of `funcOp` that has at least one
/// weight-like input (as decided by `isWeightLikeComputeOperand`).
///
/// In the rebuilt compute the weight-like inputs are appended to the weight
/// list, the remaining inputs keep their relative order, and the body block
/// only has arguments for the remaining inputs. Uses of a removed block
/// argument are redirected to a value materialized at the start of the new
/// body. The old compute is replaced by the new one and erased.
///
/// \return failure, with an error emitted on the compute, if a promoted
///         operand cannot be materialized inside the body; success otherwise.
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
  IRRewriter rewriter(&getContext());

  // Snapshot the compute ops first: the loop below creates and erases them.
  SmallVector<spatial::SpatCompute> worklist(funcOp.getOps<spatial::SpatCompute>());

  for (auto oldCompute : worklist) {
    auto oldInputs = oldCompute.getInputs();
    auto oldWeights = oldCompute.getWeights();

    // One flag per input: true when the input must become a weight.
    SmallVector<bool> shouldPromote;
    shouldPromote.reserve(oldInputs.size());
    for (Value input : oldInputs)
      shouldPromote.push_back(isWeightLikeComputeOperand(input));
    if (!llvm::is_contained(shouldPromote, true))
      continue;

    rewriter.setInsertionPointAfter(oldCompute);

    // Split the operands: existing weights first, then the promoted inputs.
    SmallVector<Value> weights(oldWeights.begin(), oldWeights.end());
    SmallVector<Value> keptInputs;
    SmallVector<Type> keptInputTypes;
    SmallVector<Location> keptInputLocs;
    weights.reserve(oldWeights.size() + oldInputs.size());
    keptInputs.reserve(oldInputs.size());
    keptInputTypes.reserve(oldInputs.size());
    keptInputLocs.reserve(oldInputs.size());

    for (auto [idx, input] : llvm::enumerate(oldInputs)) {
      if (shouldPromote[idx]) {
        weights.push_back(input);
      }
      else {
        keptInputs.push_back(input);
        keptInputTypes.push_back(input.getType());
        keptInputLocs.push_back(input.getLoc());
      }
    }

    auto newCompute =
        spatial::SpatCompute::create(rewriter, oldCompute.getLoc(), oldCompute.getResultTypes(), weights, keptInputs);
    auto* newBody =
        rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), keptInputTypes, keptInputLocs);
    newCompute.getProperties().setOperandSegmentSizes(
        {static_cast<int>(weights.size()), static_cast<int>(keptInputs.size())});
    rewriter.setInsertionPointToStart(newBody);

    // Old block arguments map either to the next argument of the new block or,
    // for promoted inputs, to a value materialized inside the new body.
    IRMapping mapper;
    auto& oldBody = oldCompute.getBody().front();
    size_t nextKeptArg = 0;
    for (auto [argIdx, oldArg] : llvm::enumerate(oldBody.getArguments())) {
      if (shouldPromote[argIdx]) {
        auto materialized = materializeWeightLikeValueInBlock(oldInputs[argIdx], rewriter, mapper);
        if (failed(materialized))
          return oldCompute.emitError("failed to materialize promoted weight-like operand inside compute body");
        mapper.map(oldArg, *materialized);
      }
      else {
        mapper.map(oldArg, newBody->getArgument(nextKeptArg++));
      }
    }

    for (auto& bodyOp : oldBody.without_terminator())
      rewriter.clone(bodyOp, mapper);

    // Recreate the terminator; operands without a mapping are yielded as is.
    auto oldYield = cast<spatial::SpatYieldOp>(oldBody.getTerminator());
    SmallVector<Value> yieldedValues;
    yieldedValues.reserve(oldYield.getOutputs().size());
    for (Value output : oldYield.getOutputs())
      yieldedValues.push_back(mapper.lookupOrDefault(output));
    spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), yieldedValues);

    oldCompute.replaceAllUsesWith(newCompute);
    oldCompute.erase();
  }

  return success();
}
|
|
|
|
/// Factory returning a new, default-constructed instance of the
/// ONNX-to-Spatial lowering pass.
std::unique_ptr<Pass> createONNXToSpatialPass() {
  return std::make_unique<ONNXToSpatialPass>();
}
|
|
|
|
} // namespace onnx_mlir
|