41de3cb150
Validate Operations / validate-operations (push) Has been cancelled
better reports refactor for more code reuse and pattern usage fixes
185 lines
6.4 KiB
C++
185 lines
6.4 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#include "Common/Common.hpp"
|
|
#include "Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Compiler/CompilerOptions.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
/// Module-level pass that lowers ONNX operations of the PIM entry function to
/// the Spatial dialect (registered as `convert-onnx-to-spatial`).
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)

  ONNXToSpatialPass() = default;

  // The pass declares no options or other members, so a copy is simply a
  // freshly constructed pass; nothing needs to be taken from the source.
  ONNXToSpatialPass(const ONNXToSpatialPass& /*pass*/) : PassWrapper() {}

  StringRef getArgument() const override { return "convert-onnx-to-spatial"; }

  StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }

  void runOnOperation() override;
};
|
|
|
|
} // namespace
|
|
|
|
/// Wraps the body of a compute-less function into a single `spatial::SpatCompute`.
///
/// If `funcOp` already contains at least one `spatial::SpatCompute`, nothing is
/// done. Otherwise:
/// - A new compute op is created right before the `func.return`. It takes every
///   function argument as an operand, and its body block mirrors those arguments.
/// - Every other top-level op of the function is cloned into that block, with
///   function arguments remapped to the block arguments.
/// - The block yields the cloned counterparts of the returned values.
/// - The `func.return` is rerouted to the compute results, and the original ops
///   are erased.
///
/// NOTE(review): assumes a single-block function body whose terminator is a
/// `func.return` (the `cast` below asserts on anything else).
static void populateEmptyFunction(func::FuncOp funcOp) {
  // Leave functions that already have compute ops alone. Testing the range
  // directly avoids materializing every compute op into a vector.
  if (!funcOp.getOps<spatial::SpatCompute>().empty())
    return;

  Block& body = funcOp.getFunctionBody().front();
  auto returnOp = cast<func::ReturnOp>(body.getTerminator());

  IRRewriter rewriter(funcOp.getContext());
  rewriter.setInsertionPoint(returnOp);

  // The compute's body block gets one argument per function argument.
  SmallVector<Type> sourceTypes;
  SmallVector<Location> sourceLocs;
  sourceTypes.reserve(funcOp.getNumArguments());
  sourceLocs.reserve(funcOp.getNumArguments());
  for (Value source : funcOp.getArguments()) {
    sourceTypes.push_back(source.getType());
    sourceLocs.push_back(source.getLoc());
  }

  auto newCompute = spatial::SpatCompute::create(
      rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
  Block* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
  // First operand group is empty; all function arguments go into the second group.
  newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});

  // Inside the compute, uses of the function arguments become uses of the block arguments.
  IRMapping mapper;
  for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
    mapper.map(computeArg, blockArg);

  // Clone every host op (everything except the new compute and the return) into the compute.
  rewriter.setInsertionPointToEnd(newBlock);
  for (Operation& op : body)
    if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
      rewriter.clone(op, mapper);

  // Yield the cloned counterparts of the returned values. Returned function
  // arguments map to the block arguments.
  SmallVector<Value> yieldedValues;
  yieldedValues.reserve(returnOp.getNumOperands());
  for (Value returned : returnOp.getOperands())
    yieldedValues.push_back(mapper.lookupOrDefault(returned));
  spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), yieldedValues);

  // Reroute the function results through the compute *before* erasing anything,
  // so the return op never holds a dropped (null) operand.
  for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
    returnOp.setOperand(index, computeResult);

  // Erase the original host ops back to front, so users are removed before their producers.
  for (Operation& op : llvm::make_early_inc_range(llvm::reverse(body)))
    if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
      op.dropAllUses(); // defensive; the op should already be use-free here
      rewriter.eraseOp(&op);
    }
}
|
|
|
|
void ONNXToSpatialPass::runOnOperation() {
|
|
ModuleOp moduleOp = getOperation();
|
|
MLIRContext* ctx = &getContext();
|
|
|
|
RewritePatternSet prePatterns(ctx);
|
|
populatePrePatterns(prePatterns, ctx);
|
|
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
|
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
|
|
|
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
|
if (failed(entryFunc)) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
ConversionTarget target(*ctx);
|
|
target.addLegalDialect<spatial::SpatialDialect,
|
|
ONNXDialect,
|
|
tensor::TensorDialect,
|
|
arith::ArithDialect,
|
|
scf::SCFDialect>();
|
|
target.addIllegalOp<ONNXMatMulOp>();
|
|
target.addIllegalOp<ONNXAddOp>();
|
|
target.addIllegalOp<ONNXDivOp>();
|
|
target.addIllegalOp<ONNXMulOp>();
|
|
target.addIllegalOp<ONNXGemmOp>();
|
|
target.addIllegalOp<ONNXConvOp>();
|
|
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
|
target.addIllegalOp<ONNXAveragePoolOp>();
|
|
target.addIllegalOp<ONNXReluOp>();
|
|
target.addIllegalOp<ONNXSigmoidOp>();
|
|
target.addIllegalOp<ONNXSoftmaxOp>();
|
|
target.addIllegalOp<ONNXConcatOp>();
|
|
target.addIllegalOp<ONNXGatherOp>();
|
|
target.addIllegalOp<ONNXReshapeOp>();
|
|
target.addIllegalOp<ONNXResizeOp>();
|
|
target.addIllegalOp<ONNXLRNOp>();
|
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
|
target.addIllegalOp<ONNXSplitOp>();
|
|
|
|
RewritePatternSet conversionPatterns(ctx);
|
|
populateConversionPatterns(conversionPatterns, ctx);
|
|
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
RewritePatternSet earlyPostPatterns(ctx);
|
|
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
|
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
if (coresCount != -1) {
|
|
int computeOpsCount = 0;
|
|
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
|
if (isa<spatial::SpatCompute>(op))
|
|
computeOpsCount++;
|
|
|
|
if (computeOpsCount > coresCount) {
|
|
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
|
|
<< coresCount << ")";
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
PassManager cleanupPM(ctx);
|
|
cleanupPM.addPass(createCanonicalizerPass());
|
|
if (failed(cleanupPM.run(moduleOp)))
|
|
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
|
|
|
|
annotateWeightsConstants(*entryFunc);
|
|
|
|
RewritePatternSet postPatterns(ctx);
|
|
populatePostPatterns(postPatterns, ctx);
|
|
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
populateEmptyFunction(*entryFunc);
|
|
|
|
dumpModule(moduleOp, "spatial0");
|
|
}
|
|
|
|
/// Factory for the ONNX-to-Spatial lowering pass (`convert-onnx-to-spatial`).
std::unique_ptr<Pass> createONNXToSpatialPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ONNXToSpatialPass>();
  return pass;
}
|
|
|
|
} // namespace onnx_mlir
|