add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s

add relu validation
add spatial compute helper
minor refactors
This commit is contained in:
NiccoloN
2026-03-25 11:03:03 +01:00
parent 4e19650b80
commit 742df111e3
29 changed files with 258 additions and 116 deletions

View File

@@ -39,16 +39,16 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createBufferizePimPass());
pm.addPass(createPimBufferizationPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createConstantFoldingPass());
pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createMaterializeConstantsPass());
pm.addPass(createVerificationPass());
pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass());

View File

@@ -3,10 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
Patterns/Math/Gemm.cpp
Patterns/Math/Conv.cpp
Patterns/Math/Gemm.cpp
Patterns/Math/MatMul.cpp
Patterns/NN/Pool.cpp
Patterns/NN/Relu.cpp
Patterns/Tensor/Concat.cpp
Patterns/Tensor/Reshape.cpp
ONNXToSpatialPass.cpp

View File

@@ -1,15 +1,13 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <optional>
#include <type_traits>
#include <utility>
@@ -58,14 +56,6 @@ inline auto getFilterCount(const ShapedType& shapedType) {
using HSliceId = size_t;
using CoreId = size_t;
enum class MapOperations {
None,
ONNXSoftmaxOp,
ONNXReluOp,
ONNXLeakyReluOp,
ONNXExpOp
};
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
@@ -114,6 +104,38 @@ inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
namespace detail {
/// Invokes `callable` with the first sizeof...(Idx) block arguments of `block`,
/// expanded positionally (argument 0, 1, ... in order).
template <typename Callable, size_t... Idx>
void invokeWithBlockArgs(Callable&& callable, mlir::Block* block, std::index_sequence<Idx...>) {
  std::forward<Callable>(callable)(block->getArgument(Idx)...);
}
} // namespace detail
/// Creates a spatial::SpatWeightedCompute op whose body region is populated by
/// `body`. The body block gets one argument per entry of `inputs` (same types),
/// and `body` is called with those block arguments expanded positionally, so it
/// must be callable with exactly `NumInputs` mlir::Value parameters.
///
/// On return the rewriter's insertion point is placed *after* the created op,
/// so callers can keep emitting ops at the original nesting level.
///
/// \param rewriter     conversion rewriter used to create all ops
/// \param loc          location attached to the compute op and block arguments
/// \param resultTypes  result types of the compute op
/// \param weights      weight operands forwarded to SpatWeightedCompute
/// \param inputs       input operands; must have exactly NumInputs entries
/// \param body         callback that emits the block body (including the
///                     terminator — presumably a SpatYieldOp; confirm with callers)
template <size_t NumInputs, typename BodyFn>
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
// Build the body block by hand: one block argument per input, then hand
// ownership of the block to the op's region.
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
// Emit the user-provided body inside the new block.
rewriter.setInsertionPointToStart(block);
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
// Restore a sane insertion point for the caller (after the compute op).
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,

View File

@@ -8,7 +8,7 @@ include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
#endif // OP_BASE
def onnxToArithConstantOp : Pat<
def onnxToArithConstant : Pat<
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
(Arith_ConstantOp $value)
>;
@@ -19,7 +19,7 @@ def IsRank2Result: Constraint<
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
"Result is rank 2">;
def matMulAddToGemmPattern : Pat<
def matMulAddToGemm : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
(ONNXGemmOp $A, $B, $C,
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
@@ -30,7 +30,7 @@ def matMulAddToGemmPattern : Pat<
[(IsRank2Result $matmulres)]
>;
def matMulToGemmPattern : Pat<
def matMulToGemm : Pat<
(ONNXMatMulOp:$matmulres $A, $B),
(
ONNXGemmOp $A, $B,
@@ -45,14 +45,13 @@ def matMulToGemmPattern : Pat<
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias.
def convAddToConvWithBiasPatternLeft : Pat<
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
def convAddToConvWithBiasLeft : Pat<
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
def convAddToConvWithBiasPatternRight : Pat<
def convAddToConvWithBiasRight : Pat<
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
@@ -61,7 +60,7 @@ def convAddToConvWithBiasPatternRight : Pat<
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
def removeLRNPattern : Pat<
def removeLRN : Pat<
(ONNXLRNOp $A, $_, $_, $_, $_),
(replaceWithOperationOfValue $A)
>;
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
"Two tensors have the same static shape">;
def removeFlattenSameShapePattern : Pat<
def removeFlattenSameShape : Pat<
(ONNXFlattenOp:$flattenOp $A, $axis),
(replaceWithOperationOfValue $A),
[(HaveSameStaticShape $flattenOp, $A)]
>; // Add closing parenthesis here
>;
#endif // ONNX_TO_SPATIAL

View File

@@ -50,12 +50,12 @@ void ONNXToSpatialPass::runOnOperation() {
MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx);
mergeActivationPatterns.add<onnxToArithConstantOp>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
mergeActivationPatterns.add<matMulToGemm>(ctx);
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
@@ -74,23 +74,24 @@ void ONNXToSpatialPass::runOnOperation() {
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
target.addIllegalOp<ONNXAveragePoolOp>();
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXReluOp>();
target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx);
patterns.add<removeLRN>(ctx);
populateConvOpPatterns(patterns, ctx);
populatePoolTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx);
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();

View File

@@ -5,16 +5,18 @@
namespace onnx_mlir {
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -260,6 +260,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return success();
}
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
} // namespace onnx_mlir

View File

@@ -58,8 +58,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
@@ -352,7 +351,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
return success();
}
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx);
}

View File

@@ -257,7 +257,7 @@ struct PoolToSpatialCompute<ONNXAveragePoolOp>
} // namespace
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}

View File

@@ -0,0 +1,33 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Lowers ONNXReluOp into a spatial::SpatWeightedCompute whose body applies
/// spatial::SpatReluOp to the single input and yields the result.
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
Location loc = reluOp.getLoc();
Type resultType = reluOp.getResult().getType();
// ReLU is unary: one compute input, no weights (the `{}` weight range below).
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
// Body: relu(x) followed by the region terminator yielding its result.
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
});
rewriter.replaceOp(reluOp, computeOp);
return success();
}
};
} // namespace
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<ReluToSpatialCompute>(ctx); }
} // namespace onnx_mlir

View File

@@ -9,8 +9,7 @@ using namespace mlir;
namespace onnx_mlir {
struct Concat : public OpConversionPattern<ONNXConcatOp> {
Concat(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
ONNXConcatOpAdaptor adaptor,
@@ -24,7 +23,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
}
};
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<Concat>(ctx);
}

View File

@@ -114,6 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
} // namespace onnx_mlir

View File

@@ -9,41 +9,46 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def onnxToPimTransposeOp : Pat<
def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVMMOp : Pat<
def spatToPimVMM : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimMVMOp : Pat<
def spatToPimMVM : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVAddOp : Pat<
def spatToPimVVAdd : Pat<
(SpatVAddOp:$srcOpRes $a, $b),
(PimVVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMulOp : Pat<
def spatToPimVVMul : Pat<
(SpatVMulOp:$srcOpRes $a, $b),
(PimVVMulOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMaxOp : Pat<
def spatToPimVVMax : Pat<
(SpatVMaxOp:$srcOpRes $a, $b),
(PimVVMaxOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
// Lower a spatial ReLU to the Pim vector-relu op. The destination buffer is
// chosen by getBestOutputTensorFromOperandsOrAllocate — by its name it reuses
// an operand's tensor when possible and allocates otherwise (same scheme as
// the other spatToPim* patterns above).
def spatToPimVRelu : Pat<
(SpatReluOp:$srcOpRes $input),
(PimVReluOp $input,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
#endif // SPATIAL_TO_PIM

View File

@@ -363,29 +363,6 @@ def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
}];
}
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
let summary = "Reduce all elements to a single value";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
let summary = "Average all elements into a single value";

View File

@@ -97,8 +97,7 @@ struct MemCopyDevToHostOpInterface
}
};
struct TransposeOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
@@ -123,7 +122,7 @@ struct TransposeOpBufferizeInterface
}
};
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, PimVMMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
@@ -160,7 +159,7 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
}
};
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
@@ -186,8 +185,7 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
};
template <typename OpTy>
struct BinaryDstOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
@@ -225,17 +223,56 @@ struct BinaryDstOpBufferizeInterface
}
};
/// Bufferization external model shared by the unary destination-style Pim ops
/// (vavg / vrelu / vtanh / vsigm): one tensor input plus one destination
/// buffer operand, mirroring BinaryDstOpInterface above.
template <typename OpTy>
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
  /// Every operand except the DPS init is read; the init operand is the
  /// destination and is only written.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  /// Declared elementwise so the analysis may bufferize input and destination
  /// in place. (Parameter renamed from the generated `tablegen_opaque_val`
  /// leftover to match the sibling methods.)
  bool bufferizesToElementwiseAccess(Operation* op,
                                     const AnalysisState& state,
                                     ArrayRef<OpOperand*> opOperands) const {
    return true;
  }
  /// Replaces the tensor op with its memref form: resolve the buffers backing
  /// both operands, make the input contiguous, and rebuild the op on memrefs.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto unaryOp = cast<OpTy>(op);
    auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
    if (failed(inputOpt))
      return failure();
    auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
    if (failed(outputBufferOpt))
      return failure();
    // Presumably the Pim ops require a contiguous layout; this copies/casts the
    // input memref when needed (same helper as the other interfaces here).
    Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
    replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
    return success();
  }
};
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
});
}

View File

@@ -16,5 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION

View File

@@ -105,6 +105,6 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
});
}
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -486,8 +486,6 @@ struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, Spa
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
@@ -496,7 +494,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);

View File

@@ -13,13 +13,13 @@ std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createVerificationPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();

View File

@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
} // namespace
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -126,6 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace
std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
} // namespace onnx_mlir

View File

@@ -189,6 +189,6 @@ private:
} // namespace
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<VerificationPass>(); }
} // namespace onnx_mlir

View File

@@ -19,6 +19,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/PimAccelerator.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#define DEBUG_TYPE "PimAccelerator"
@@ -69,13 +70,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
TOTAL_COMPILE_PHASE = 8;
registerPass(createONNXToSpatialPass);
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass);
registerPass(createConstantFoldingPass);
registerPass(createMaterializeConstantsPass);
registerPass(createVerificationPass);
registerPass(createPimBufferizationPass);
registerPass(createPimConstantFoldingPass);
registerPass(createPimMaterializeConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass);
}