diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp
index 3ecc059..dd6b418 100644
--- a/src/PIM/Compiler/PimCompilerUtils.cpp
+++ b/src/PIM/Compiler/PimCompilerUtils.cpp
@@ -39,16 +39,16 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
   }
 
   if (pimEmissionTarget >= EmitPimBufferized) {
-    pm.addPass(createBufferizePimPass());
+    pm.addPass(createPimBufferizationPass());
     // pm.addPass(createCountInstructionPass());
     pm.addPass(createMessagePass("Pim bufferized"));
   }
 
   if (pimEmissionTarget >= EmitPimCodegen) {
-    pm.addPass(createConstantFoldingPass());
+    pm.addPass(createPimConstantFoldingPass());
     pm.addPass(createMessagePass("Pim constants folded"));
-    pm.addPass(createMaterializeConstantsPass());
-    pm.addPass(createVerificationPass());
+    pm.addPass(createPimMaterializeConstantsPass());
+    pm.addPass(createPimVerificationPass());
     pm.addPass(createMessagePass("Pim verified"));
     pm.addPass(createEmitPimJsonPass());
     // pm.addPass(createCountInstructionPass());
diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
index 1e383c8..245ff6f 100644
--- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
+++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
@@ -3,10 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
 add_public_tablegen_target(ONNXToSpatialIncGen)
 
 add_pim_library(OMONNXToSpatial
-  Patterns/Math/Gemm.cpp
   Patterns/Math/Conv.cpp
+  Patterns/Math/Gemm.cpp
   Patterns/Math/MatMul.cpp
   Patterns/NN/Pool.cpp
+  Patterns/NN/Relu.cpp
   Patterns/Tensor/Concat.cpp
   Patterns/Tensor/Reshape.cpp
   ONNXToSpatialPass.cpp
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp
index 2e2b801..864148d 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Common.hpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Common.hpp
@@ -1,15 +1,13 @@
 #pragma once
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/LogicalResult.h"
-
 #include <cstddef>
-#include <cstdint>
 #include <type_traits>
 #include <utility>
@@ -58,14 +56,6 @@ inline auto getFilterCount(const ShapedType& shapedType) {
 using HSliceId = size_t;
 using CoreId = size_t;
 
-enum class MapOperations {
-  None,
-  ONNXSoftmaxOp,
-  ONNXReluOp,
-  ONNXLeakyReluOp,
-  ONNXExpOp
-};
-
 template <typename A, typename B, typename C = std::common_type_t<A, B>>
 constexpr C ceilIntegerDivide(A a, B b) {
   static_assert(std::is_integral_v<A>, "A must be an integer type");
@@ -114,6 +104,38 @@ inline auto getTensorShape(mlir::Value tensor) {
   return mlir::cast<mlir::ShapedType>(tensor.getType()).getShape();
 }
 
+namespace detail {
+
+template <typename Fn, size_t... Is>
+void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
+  std::forward<Fn>(fn)(block->getArgument(Is)...);
+}
+
+} // namespace detail
+
+template <size_t NumInputs, typename BodyFn>
+spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
+                                               mlir::Location loc,
+                                               mlir::TypeRange resultTypes,
+                                               mlir::ValueRange weights,
+                                               mlir::ValueRange inputs,
+                                               BodyFn&& body) {
+  assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
+
+  auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
+
+  auto* block = new mlir::Block();
+  for (mlir::Value input : inputs)
+    block->addArgument(input.getType(), loc);
+
+  computeOp.getBody().push_back(block);
+  rewriter.setInsertionPointToStart(block);
+
+  detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs>{});
+
+  rewriter.setInsertionPointAfter(computeOp);
+  return computeOp;
+}
+
 llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
                                            size_t axis,
                                            int64_t sliceSize,
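The createSpatCompute helper introduced above hinges on the std::index_sequence expansion in detail::invokeWithBlockArgs, which turns the NumInputs block arguments into individual lambda parameters so pattern bodies never index into the block by hand. Below is a minimal standalone sketch of the same technique, with a plain stand-in Block type in place of mlir::Block; the std::common_type_t default on ceilIntegerDivide mirrors the header shown above and is an assumption rather than something this patch pins down.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>
#include <vector>

// Stand-in for mlir::Block: arguments are plain ints here.
struct Block {
  std::vector<int> args;
  int getArgument(std::size_t i) const { return args.at(i); }
};

namespace detail {
// Expands the index pack Is... so fn is called with one argument per index:
// fn(block.getArgument(0), block.getArgument(1), ...).
template <typename Fn, std::size_t... Is>
void invokeWithBlockArgs(Fn&& fn, const Block& block, std::index_sequence<Is...>) {
  std::forward<Fn>(fn)(block.getArgument(Is)...);
}
} // namespace detail

// Calls body with exactly NumInputs block arguments; the lambda's arity is
// checked at compile time by the pack expansion itself.
template <std::size_t NumInputs, typename BodyFn>
void withBlockArgs(const Block& block, BodyFn&& body) {
  assert(block.args.size() == NumInputs && "NumInputs must match the block");
  detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block,
                              std::make_index_sequence<NumInputs>{});
}

// ceilIntegerDivide as in Common.hpp; the std::common_type_t default for C
// is an assumption, since the patch context does not show that template header.
template <typename A, typename B, typename C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
  static_assert(std::is_integral_v<A>, "A must be an integer type");
  static_assert(std::is_integral_v<B>, "B must be an integer type");
  return static_cast<C>((a + b - 1) / b);
}

int main() {
  Block block{{7, 9}};
  // The two block arguments arrive as two named lambda parameters.
  withBlockArgs<2>(block, [](int a, int b) { std::cout << a + b << "\n"; }); // 16
  std::cout << ceilIntegerDivide(7, 2) << "\n";                              // 4
}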
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td
index cb3401b..1789e45 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td
@@ -8,7 +8,7 @@ include "src/Dialect/ONNX/ONNX.td"
 include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
 #endif // OP_BASE
 
-def onnxToArithConstantOp : Pat<
+def onnxToArithConstant : Pat<
   (ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
   (Arith_ConstantOp $value)
 >;
@@ -19,7 +19,7 @@ def IsRank2Result: Constraint<
   CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
   "Result is rank 2">;
 
-def matMulAddToGemmPattern : Pat<
+def matMulAddToGemm : Pat<
   (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
   (ONNXGemmOp $A, $B, $C,
     /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
   [(IsRank2Result $matmulres)]
 >;
@@ -30,7 +30,7 @@
-def matMulToGemmPattern : Pat<
+def matMulToGemm : Pat<
   (ONNXMatMulOp:$matmulres $A, $B),
   ( ONNXGemmOp
     $A, $B,
   [(IsRank2Result $matmulres)]
 >;
@@ -45,14 +45,13 @@
 // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
-// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
-// ONNXConvOp with a bias.
-def convAddToConvWithBiasPatternLeft : Pat<
+// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
+def convAddToConvWithBiasLeft : Pat<
   (ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
   (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
 >;
 
-def convAddToConvWithBiasPatternRight : Pat<
+def convAddToConvWithBiasRight : Pat<
   (ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
   (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
 >;
@@ -61,7 +60,7 @@
 def replaceWithOperationOfValue : NativeCodeCall<"$0">;
 
-def removeLRNPattern : Pat<
+def removeLRN : Pat<
   (ONNXLRNOp $A, $_, $_, $_, $_),
   (replaceWithOperationOfValue $A)
 >;
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
   CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
   "Two tensors have the same static shape">;
 
-def removeFlattenSameShapePattern : Pat<
+def removeFlattenSameShape : Pat<
   (ONNXFlattenOp:$flattenOp $A, $axis),
   (replaceWithOperationOfValue $A),
   [(HaveSameStaticShape $flattenOp, $A)]
->; // Add closing parenthesis here
+>;
 
 #endif // ONNX_TO_SPATIAL
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
index 960939f..a6dc98f 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
@@ -50,12 +50,12 @@ void ONNXToSpatialPass::runOnOperation() {
   MLIRContext* ctx = &getContext();
 
   RewritePatternSet mergeActivationPatterns(ctx);
-  mergeActivationPatterns.add(ctx);
-  mergeActivationPatterns.add(ctx);
-  mergeActivationPatterns.add(ctx);
-  mergeActivationPatterns.add(ctx);
-  mergeActivationPatterns.add(ctx);
-  mergeActivationPatterns.add(ctx);
+  mergeActivationPatterns.add(ctx);
+  mergeActivationPatterns.add(ctx);
+  mergeActivationPatterns.add(ctx);
+  mergeActivationPatterns.add(ctx);
+  mergeActivationPatterns.add(ctx);
+  mergeActivationPatterns.add(ctx);
   populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
 
   if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
@@ -74,23 +74,24 @@
       [](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
   target.addIllegalOp();
   target.addIllegalOp();
-  target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
-  target.addIllegalOp();
+  target.addIllegalOp();
   target.addIllegalOp();
-  target.addIllegalOp();
+  target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp();
+  target.addIllegalOp();
 
   RewritePatternSet patterns(ctx);
-  patterns.add(ctx);
+  patterns.add(ctx);
 
-  populateConvOpPatterns(patterns, ctx);
-  populatePoolTilingPattern(patterns, ctx);
-  populateOnnxGemmOpPatterns(patterns, ctx);
-  populateReshapeConversionPattern(patterns, ctx);
-
-  populateONNXConcatToTensorConcatPattern(patterns, ctx);
+  populateGemmPatterns(patterns, ctx);
+  populateConvPatterns(patterns, ctx);
+  populatePoolPatterns(patterns, ctx);
+  populateReluPatterns(patterns, ctx);
+  populateConcatPatterns(patterns, ctx);
+  populateReshapePatterns(patterns, ctx);
 
   if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
     signalPassFailure();
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
index de4ae31..58f9a10 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
@@ -5,16 +5,18 @@
 namespace onnx_mlir {
 
-void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+
+void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 
 void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 
-void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 
-void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 
-void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 
-void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 
 } // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
index 981510b..09eb23e 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
@@ -260,6 +260,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
   return success();
 }
 
-void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
+void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
 
 } // namespace onnx_mlir
diff --git
a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
index d8c7cb0..7ac5bfb 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
@@ -58,8 +58,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
 };
 
 struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
-  GemvToSpatialCompute(MLIRContext* ctx)
-  : OpConversionPattern(ctx, 1) {}
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                 ONNXGemmOpAdaptor gemmOpAdaptor,
@@ -352,7 +351,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
   return success();
 }
 
-void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
+void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
   patterns.insert<GemmToManyGemv>(ctx);
   patterns.insert<GemvToSpatialCompute>(ctx);
 }
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
index f558c47..281eab0 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
@@ -257,7 +257,7 @@ struct PoolToSpatialCompute
 } // namespace
 
-void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
+void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
   patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
   patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
 }
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
new file mode 100644
index 0000000..b922581
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
@@ -0,0 +1,33 @@
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace {
+
+struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
+    Location loc = reluOp.getLoc();
+    Type resultType = reluOp.getResult().getType();
+    constexpr size_t numInputs = 1;
+    auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
+      auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
+      spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
+    });
+    rewriter.replaceOp(reluOp, computeOp);
+    return success();
+  }
+};
+
+} // namespace
+
+void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<ReluToSpatialCompute>(ctx); }
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp
index 36ffc95..84271a8 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp
@@ -9,8 +9,7 @@ using namespace mlir;
 namespace onnx_mlir {
 
 struct Concat : public OpConversionPattern<ONNXConcatOp> {
-  Concat(MLIRContext* ctx)
-  : OpConversionPattern(ctx) {}
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
                                 ONNXConcatOpAdaptor adaptor,
@@ -24,7 +23,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
   }
 };
 
-void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
+void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
   patterns.insert<Concat>(ctx);
 }
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp
index befa1c5..4499c7c 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp
@@ -114,6 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
 } // namespace
 
-void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
+void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
 
 } // namespace onnx_mlir
diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td
index 88eadbc..de37037 100644
--- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td
+++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td
@@ -9,41 +9,46 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
 include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
 #endif // OP_BASE
 
-def onnxToPimTransposeOp : Pat<
+def onnxToPimTranspose : Pat<
   (ONNXTransposeOp:$srcOpRes $data, $perms),
   (PimTransposeOp $data, $perms,
     (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;
 
-def spatToPimVMMOp : Pat<
+def spatToPimVMM : Pat<
   (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
   (PimVMMOp $weightIndex, $vector,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;
 
-def spatToPimMVMOp : Pat<
+def spatToPimMVM : Pat<
  (SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
  (PimMVMOp $weightIndex, $vector,
   (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;
 
-def spatToPimVVAddOp : Pat<
+def spatToPimVVAdd : Pat<
  (SpatVAddOp:$srcOpRes $a, $b),
  (PimVVAddOp $a, $b,
   (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;
 
-def spatToPimVVMulOp : Pat<
+def spatToPimVVMul : Pat<
  (SpatVMulOp:$srcOpRes $a, $b),
  (PimVVMulOp $a, $b,
   (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;
 
-def spatToPimVVMaxOp : Pat<
+def spatToPimVVMax : Pat<
  (SpatVMaxOp:$srcOpRes $a, $b),
  (PimVVMaxOp $a, $b,
   (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;
 
+def spatToPimVRelu : Pat<
+ (SpatReluOp:$srcOpRes $input),
+ (PimVReluOp $input,
+  (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
+>;
 
 #endif // SPATIAL_TO_PIM
diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td
index 616da86..300ddd1 100644
--- a/src/PIM/Dialect/Pim/Pim.td
+++ b/src/PIM/Dialect/Pim/Pim.td
@@ -363,29 +363,6 @@ def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
   }];
 }
 
-def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
-  let summary = "Reduce all elements to a single value";
-
-  let arguments = (ins
-    PimTensor:$input,
-    PimTensor:$outputBuffer
-  );
-
-  let results = (outs
-    PimTensor:$output
-  );
-
-  let extraClassDeclaration = [{
-    mlir::MutableOperandRange getDpsInitsMutable() {
-      return getOutputBufferMutable();
-    }
-  }];
-
-  let assemblyFormat = [{
-    `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
-  }];
-}
-
 def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
   let summary = "Average all elements into a single value";
diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
index 780d1e2..5d29208 100644
--- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
+++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
@@ -97,8 +97,7 @@ struct MemCopyDevToHostOpInterface
   }
 };
 
-struct TransposeOpBufferizeInterface
-: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
+struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
   bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
     return !cast<PimTransposeOp>(op).isDpsInit(&opOperand);
   }
@@ -123,7 +122,7 @@ struct TransposeOpBufferizeInterface
   }
 };
 
-struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
+struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, PimVMMOp> {
   bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
     return !cast<PimVMMOp>(op).isDpsInit(&opOperand);
   }
@@ -160,7 +159,7 @@ struct VMMOpBufferizeInterface
   }
 };
 
-struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
+struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
   bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
     return !cast<PimMVMOp>(op).isDpsInit(&opOperand);
   }
@@ -186,8 +185,7 @@ struct MVMOpBufferizeInterface
 template <typename OpTy>
-struct BinaryDstOpBufferizeInterface
-: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
+struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
   bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
     return !cast<OpTy>(op).isDpsInit(&opOperand);
   }
@@ -225,17 +223,56 @@ struct BinaryDstOpBufferizeInterface
   }
 };
 
+template <typename OpTy>
+struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
+  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
+    return !cast<OpTy>(op).isDpsInit(&opOperand);
+  }
+
+  bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
+                                     const AnalysisState& state,
+                                     ArrayRef<OpOperand*> opOperands) const {
+    return true;
+  }
+
+  LogicalResult bufferize(Operation* op,
+                          RewriterBase& rewriter,
+                          const BufferizationOptions& options,
+                          BufferizationState& state) const {
+    auto unaryOp = cast<OpTy>(op);
+
+    auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
+    if (failed(inputOpt))
+      return failure();
+
+    auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
+    if (failed(outputBufferOpt))
+      return failure();
+
+    Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
+
+    replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
+    return success();
+  }
+};
+
 void registerOpBufferizationInterfaces(DialectRegistry& registry) {
   registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
     PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
     PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
-    PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
-    PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
-    PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
-    PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
-    PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
-    PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
-    PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
+    PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
+    PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
+    PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
+
+    PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
+    PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
+    PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
+    PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
+
+    PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
+    PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
+    PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
+    PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
   });
 }
diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td
index f5040a4..bc920e3 100644
--- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td
+++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td
@@ -16,5 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
     (returnType $dst))
 >;
 
-
 #endif // PIM_BUFFERIZATION
diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp
index 10da3ab..a2b99a2 100644
--- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp
+++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp
@@ -105,6 +105,6 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
   });
 }
 
-std::unique_ptr<mlir::Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
+std::unique_ptr<mlir::Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
 
 } // namespace onnx_mlir
diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp
index a27cc89..5d0869f 100644
--- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp
+++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp
@@ -486,8 +486,6 @@ struct WVMMOpInterface : WeightedMultiplicationsOpInterface {};
 
-struct SumOpInterface : VariadicArgumentElementWiseOpInterface {};
-
 struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface {};
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
@@ -496,7 +494,6 @@
     SpatVAddOp::attachInterface(*ctx);
     SpatWeightedVMMOp::attachInterface(*ctx);
     SpatWeightedMVMOp::attachInterface(*ctx);
-    SpatSumOp::attachInterface(*ctx);
     SpatVMaxOp::attachInterface(*ctx);
     SpatChannelReceiveOp::attachInterface(*ctx);
     SpatChannelSendOp::attachInterface(*ctx);
diff --git a/src/PIM/Pass/PIMPasses.h b/src/PIM/Pass/PIMPasses.h
index 3285a48..ba0724c 100644
--- a/src/PIM/Pass/PIMPasses.h
+++ b/src/PIM/Pass/PIMPasses.h
@@ -13,13 +13,13 @@ std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
 
 std::unique_ptr<mlir::Pass> createSpatialToPimPass();
 
-std::unique_ptr<mlir::Pass> createBufferizePimPass();
+std::unique_ptr<mlir::Pass> createPimBufferizationPass();
 
-std::unique_ptr<mlir::Pass> createConstantFoldingPass();
+std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
 
-std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
+std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
 
-std::unique_ptr<mlir::Pass> createVerificationPass();
+std::unique_ptr<mlir::Pass> createPimVerificationPass();
 
 std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
diff --git a/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp b/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp
index e237de4..76bd097 100644
--- a/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp
+++ b/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper
 
-std::unique_ptr<mlir::Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
+std::unique_ptr<mlir::Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
 
 } // namespace onnx_mlir
diff --git a/src/PIM/Pass/Pim/MaterializeConstantsPass.cpp b/src/PIM/Pass/Pim/MaterializeConstantsPass.cpp
index bf9b084..fdb9904 100644
--- a/src/PIM/Pass/Pim/MaterializeConstantsPass.cpp
+++ b/src/PIM/Pass/Pim/MaterializeConstantsPass.cpp
@@ -126,6 +126,6 @@ struct MaterializeConstantsPass : PassWrapper
 
-std::unique_ptr<mlir::Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
+std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
 
 } // namespace onnx_mlir
diff --git a/src/PIM/Pass/Pim/VerificationPass.cpp b/src/PIM/Pass/Pim/VerificationPass.cpp
index 4661744..809685a 100644
--- a/src/PIM/Pass/Pim/VerificationPass.cpp
+++ b/src/PIM/Pass/Pim/VerificationPass.cpp
@@ -189,6 +189,6 @@ private:
 
 } // namespace
 
-std::unique_ptr<mlir::Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
+std::unique_ptr<mlir::Pass> createPimVerificationPass() { return std::make_unique<VerificationPass>(); }
 
 } // namespace onnx_mlir
diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp
index aaa82c5..4541898 100644
--- a/src/PIM/PimAccelerator.cpp
+++ b/src/PIM/PimAccelerator.cpp
@@ -19,6 +19,7 @@
 #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
 #include "src/Accelerators/PIM/Pass/PIMPasses.h"
 #include "src/Accelerators/PIM/PimAccelerator.hpp"
+#include "src/Compiler/CompilerUtils.hpp"
 
 #define DEBUG_TYPE "PimAccelerator"
 
@@ -69,13 +70,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
 
 void PimAccelerator::registerPasses(int optLevel) const {
   LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
+  TOTAL_COMPILE_PHASE = 8;
 
   registerPass(createONNXToSpatialPass);
   registerPass(createSpatialToGraphvizPass);
   registerPass(createSpatialToPimPass);
-  registerPass(createBufferizePimPass);
-  registerPass(createConstantFoldingPass);
-  registerPass(createMaterializeConstantsPass);
-  registerPass(createVerificationPass);
+  registerPass(createPimBufferizationPass);
+  registerPass(createPimConstantFoldingPass);
+  registerPass(createPimMaterializeConstantsPass);
+  registerPass(createPimVerificationPass);
   registerPass(createEmitPimJsonPass);
 }
diff --git a/validation/operations/README.md b/validation/operations/README.md
index cd6e88e..6594703 100644
--- a/validation/operations/README.md
+++ b/validation/operations/README.md
@@ -35,6 +35,15 @@ python3 validation/operations/gen_tests.py
 | Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
 | Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
 
+## Relu
+
+| Test | Directory | Input | Output | Notes |
+|------|-----------|-------|--------|-------|
+| Basic | `relu/basic` | [4,8] | [4,8] | Standalone 2D Relu |
+| 4D | `relu/4d` | [2,3,4,4] | [2,3,4,4] | Standalone NCHW Relu |
+| After Conv | `relu/after_conv` | [1,3,5,5] | [1,2,3,3] | Conv 3x3 + bias, then Relu |
+| After Gemm | `relu/after_gemm` | [4,64] | [4,32] | Gemm + bias, then Relu |
+
 ## Gemm
 
 | Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py
index 526c4c9..777725b 100644
--- a/validation/operations/gen_tests.py
+++ b/validation/operations/gen_tests.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
+"""Generate ONNX test models for validating GEMM, Conv, Pooling, and Relu implementations."""
 
 import numpy as np
 import onnx
@@ -327,6 +327,60 @@ def maxpool_after_conv():
     save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
 
 
+# ---------------------------------------------------------------------------
+# Relu tests
+# ---------------------------------------------------------------------------
+
+def relu_basic():
+    """Standalone Relu on a simple 2D tensor."""
+    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 8])
+    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8])
+    node = helper.make_node("Relu", ["X"], ["Y"])
+    graph = helper.make_graph([node], "relu_basic", [X], [Y])
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+    save_model(model, "relu/basic", "relu_basic.onnx")
+
+
+def relu_4d():
+    """Standalone Relu on an NCHW tensor."""
+    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4, 4])
+    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4, 4])
+    node = helper.make_node("Relu", ["X"], ["Y"])
+    graph = helper.make_graph([node], "relu_4d", [X], [Y])
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+    save_model(model, "relu/4d", "relu_4d.onnx")
+
+
+def relu_after_conv():
+    """Conv followed by Relu."""
+    rng = np.random.default_rng(60)
+    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
+    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
+    W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
+    B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
+    conv = helper.make_node("Conv", ["X", "W", "B"], ["C"],
+                            kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
+    relu = helper.make_node("Relu", ["C"], ["Y"])
+    graph = helper.make_graph([conv, relu], "relu_after_conv", [X], [Y], initializer=[W, B])
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+    save_model(model, "relu/after_conv", "relu_after_conv.onnx")
+
+
+def relu_after_gemm():
+    """Gemm followed by Relu."""
+    B, K, N = 4, 64, 32
+    rng = np.random.default_rng(61)
+    W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
+    C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
+    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
+    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
+    gemm = helper.make_node("Gemm", ["A", "W", "C"], ["G"])
+    relu = helper.make_node("Relu", ["G"], ["Y"])
+    graph = helper.make_graph([gemm, relu], "relu_after_gemm", [A], [Y], initializer=[W, C])
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+    save_model(model, "relu/after_gemm", "relu_after_gemm.onnx")
+
+
 # ---------------------------------------------------------------------------
 # Main
 # ---------------------------------------------------------------------------
@@ -361,4 +415,10 @@ if __name__ == "__main__":
     avgpool_include_pad()
     maxpool_after_conv()
 
+    print("\nGenerating Relu tests:")
+    relu_basic()
+    relu_4d()
+    relu_after_conv()
+    relu_after_gemm()
+
     print("\nDone.")
diff --git a/validation/operations/relu/4d/relu_4d.onnx b/validation/operations/relu/4d/relu_4d.onnx
new file mode 100644
index 0000000..3d62ef7
Binary files /dev/null and b/validation/operations/relu/4d/relu_4d.onnx differ
diff --git a/validation/operations/relu/after_conv/relu_after_conv.onnx b/validation/operations/relu/after_conv/relu_after_conv.onnx
new file mode 100644
index 0000000..d2cb581
Binary files /dev/null and b/validation/operations/relu/after_conv/relu_after_conv.onnx differ
diff --git a/validation/operations/relu/after_gemm/relu_after_gemm.onnx b/validation/operations/relu/after_gemm/relu_after_gemm.onnx
new file mode 100644
index 0000000..0c80915
Binary files /dev/null and b/validation/operations/relu/after_gemm/relu_after_gemm.onnx differ
diff --git a/validation/operations/relu/basic/relu_basic.onnx b/validation/operations/relu/basic/relu_basic.onnx
new file mode 100644
index 0000000..119f330
Binary files /dev/null and b/validation/operations/relu/basic/relu_basic.onnx differ
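When checking the generated relu/after_gemm model, note that with Gemm's default attributes (alpha = beta = 1, transA = transB = 0) the graph computes Y = relu(A * W + C). The following is only a sketch of that reference math in plain C++, independent of any ONNX runtime; the harness that would actually consume these models is not part of this patch.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Reference for relu/after_gemm: Y = relu(A * W + C).
// A is [B, K] row-major, W is [K, N], C is [N] broadcast over rows.
std::vector<float> reluAfterGemm(const std::vector<float>& A,
                                 const std::vector<float>& W,
                                 const std::vector<float>& C,
                                 std::size_t B, std::size_t K, std::size_t N) {
  std::vector<float> Y(B * N, 0.0f);
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t n = 0; n < N; ++n) {
      float acc = C[n]; // beta = 1: bias added once per output element
      for (std::size_t k = 0; k < K; ++k)
        acc += A[b * K + k] * W[k * N + n]; // alpha = 1, no transposes
      Y[b * N + n] = std::max(acc, 0.0f);   // the trailing Relu
    }
  return Y;
}

int main() {
  // Tiny smoke test with B=1, K=2, N=2: [1, -1] * I + [-0.5, -0.5], then Relu.
  std::vector<float> A{1.0f, -1.0f}, W{1.0f, 0.0f, 0.0f, 1.0f}, C{-0.5f, -0.5f};
  for (float y : reluAfterGemm(A, W, C, 1, 2, 2))
    std::cout << y << " "; // prints: 0.5 0
  std::cout << "\n";
}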