add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s

add relu validation
add spatial compute helper
minor refactors
This commit is contained in:
NiccoloN
2026-03-25 11:03:03 +01:00
parent 4e19650b80
commit 742df111e3
29 changed files with 258 additions and 116 deletions

View File

@@ -39,16 +39,16 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPimBufferized) { if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createBufferizePimPass()); pm.addPass(createPimBufferizationPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized")); pm.addPass(createMessagePass("Pim bufferized"));
} }
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createConstantFoldingPass()); pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded")); pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createMaterializeConstantsPass()); pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createVerificationPass()); pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified")); pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass()); pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());

View File

@@ -3,10 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen) add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial add_pim_library(OMONNXToSpatial
Patterns/Math/Gemm.cpp
Patterns/Math/Conv.cpp Patterns/Math/Conv.cpp
Patterns/Math/Gemm.cpp
Patterns/Math/MatMul.cpp Patterns/Math/MatMul.cpp
Patterns/NN/Pool.cpp Patterns/NN/Pool.cpp
Patterns/NN/Relu.cpp
Patterns/Tensor/Concat.cpp Patterns/Tensor/Concat.cpp
Patterns/Tensor/Reshape.cpp Patterns/Tensor/Reshape.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp

View File

@@ -1,15 +1,13 @@
#pragma once #pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/LogicalResult.h"
#include <cassert> #include <cassert>
#include <optional>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
@@ -58,14 +56,6 @@ inline auto getFilterCount(const ShapedType& shapedType) {
using HSliceId = size_t; using HSliceId = size_t;
using CoreId = size_t; using CoreId = size_t;
enum class MapOperations {
None,
ONNXSoftmaxOp,
ONNXReluOp,
ONNXLeakyReluOp,
ONNXExpOp
};
template <class A, class B, class C = std::common_type_t<A, B>> template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) { constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type"); static_assert(std::is_integral_v<A>, "A must be an integer type");
@@ -114,6 +104,38 @@ inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape(); return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
} }
namespace detail {
// Invokes `fn` with the first sizeof...(Is) arguments of `block`, expanded as
// individual parameters. Used by createSpatCompute so the caller's body lambda
// can take N typed mlir::Value parameters instead of indexing into the block.
template <typename Fn, size_t... Is>
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
std::forward<Fn>(fn)(block->getArgument(Is)...);
}
} // namespace detail
// Creates a spatial::SpatWeightedCompute op with a single-block body region.
//
// The body block receives one argument per value in `inputs` (same types, in
// order); `weights` are forwarded as op operands but do NOT become block
// arguments. `body` is invoked with the NumInputs block arguments expanded as
// individual mlir::Value parameters, with the rewriter's insertion point set
// to the start of the block, so it can build the region contents (it is
// expected to terminate the block itself, e.g. with a SpatYieldOp).
//
// NumInputs must equal inputs.size() (asserted). On return, the rewriter's
// insertion point is left immediately after the created compute op.
template <size_t NumInputs, typename BodyFn>
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
// Region takes ownership of the block via push_back below.
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
// Restore a sane insertion point for the caller: right after the new op.
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice, llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis, size_t axis,
int64_t sliceSize, int64_t sliceSize,

View File

@@ -8,7 +8,7 @@ include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
#endif // OP_BASE #endif // OP_BASE
def onnxToArithConstantOp : Pat< def onnxToArithConstant : Pat<
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings), (ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
(Arith_ConstantOp $value) (Arith_ConstantOp $value)
>; >;
@@ -19,7 +19,7 @@ def IsRank2Result: Constraint<
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">, CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
"Result is rank 2">; "Result is rank 2">;
def matMulAddToGemmPattern : Pat< def matMulAddToGemm : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C), (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
(ONNXGemmOp $A, $B, $C, (ONNXGemmOp $A, $B, $C,
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
@@ -30,7 +30,7 @@ def matMulAddToGemmPattern : Pat<
[(IsRank2Result $matmulres)] [(IsRank2Result $matmulres)]
>; >;
def matMulToGemmPattern : Pat< def matMulToGemm : Pat<
(ONNXMatMulOp:$matmulres $A, $B), (ONNXMatMulOp:$matmulres $A, $B),
( (
ONNXGemmOp $A, $B, ONNXGemmOp $A, $B,
@@ -45,14 +45,13 @@ def matMulToGemmPattern : Pat<
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single // This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
// ONNXConvOp with a bias. def convAddToConvWithBiasLeft : Pat<
def convAddToConvWithBiasPatternLeft : Pat<
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)), (ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides) (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>; >;
def convAddToConvWithBiasPatternRight : Pat< def convAddToConvWithBiasRight : Pat<
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand), (ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides) (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>; >;
@@ -61,7 +60,7 @@ def convAddToConvWithBiasPatternRight : Pat<
def replaceWithOperationOfValue : NativeCodeCall<"$0">; def replaceWithOperationOfValue : NativeCodeCall<"$0">;
def removeLRNPattern : Pat< def removeLRN : Pat<
(ONNXLRNOp $A, $_, $_, $_, $_), (ONNXLRNOp $A, $_, $_, $_, $_),
(replaceWithOperationOfValue $A) (replaceWithOperationOfValue $A)
>; >;
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">, CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
"Two tensors have the same static shape">; "Two tensors have the same static shape">;
def removeFlattenSameShapePattern : Pat< def removeFlattenSameShape : Pat<
(ONNXFlattenOp:$flattenOp $A, $axis), (ONNXFlattenOp:$flattenOp $A, $axis),
(replaceWithOperationOfValue $A), (replaceWithOperationOfValue $A),
[(HaveSameStaticShape $flattenOp, $A)] [(HaveSameStaticShape $flattenOp, $A)]
>; // Add closing parenthesis here >;
#endif // ONNX_TO_SPATIAL #endif // ONNX_TO_SPATIAL

View File

@@ -50,12 +50,12 @@ void ONNXToSpatialPass::runOnOperation() {
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx); RewritePatternSet mergeActivationPatterns(ctx);
mergeActivationPatterns.add<onnxToArithConstantOp>(ctx); mergeActivationPatterns.add<onnxToArithConstant>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx); mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx); mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx); mergeActivationPatterns.add<matMulAddToGemm>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx); mergeActivationPatterns.add<matMulToGemm>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx); mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx); populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns)))) if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
@@ -74,23 +74,24 @@ void ONNXToSpatialPass::runOnOperation() {
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; }); [](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXGemmOp>(); target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>(); target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXMaxPoolSingleOutOp>(); target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
target.addIllegalOp<ONNXAveragePoolOp>(); target.addIllegalOp<ONNXAveragePoolOp>();
target.addIllegalOp<ONNXConcatOp>(); target.addIllegalOp<ONNXReluOp>();
target.addIllegalOp<ONNXSoftmaxOp>(); target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>(); target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXReshapeOp>(); target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx); patterns.add<removeLRN>(ctx);
populateConvOpPatterns(patterns, ctx); populateGemmPatterns(patterns, ctx);
populatePoolTilingPattern(patterns, ctx); populateConvPatterns(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx); populatePoolPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx); populateReluPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateONNXConcatToTensorConcatPattern(patterns, ctx); populateReshapePatterns(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();

View File

@@ -5,16 +5,18 @@
namespace onnx_mlir { namespace onnx_mlir {
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -260,6 +260,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return success(); return success();
} }
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); } void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -58,8 +58,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
}; };
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> { struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx) using OpConversionPattern::OpConversionPattern;
: OpConversionPattern(ctx, 1) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor, ONNXGemmOpAdaptor gemmOpAdaptor,
@@ -352,7 +351,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
return success(); return success();
} }
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToManyGemv>(ctx); patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx); patterns.insert<GemvToSpatialCompute>(ctx);
} }

View File

@@ -257,7 +257,7 @@ struct PoolToSpatialCompute<ONNXAveragePoolOp>
} // namespace } // namespace
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx); patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx); patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
} }

View File

@@ -0,0 +1,33 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Lowers ONNXReluOp to a spatial::SpatWeightedCompute whose body applies
// SpatReluOp element-wise to the single input and yields the result.
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
Location loc = reluOp.getLoc();
// Relu is shape-preserving: the compute op and the inner SpatReluOp both
// reuse the original result type.
Type resultType = reluOp.getResult().getType();
constexpr size_t numInputs = 1;
// No weights ({}); the body receives the block argument mirroring getX().
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
});
rewriter.replaceOp(reluOp, computeOp);
return success();
}
};
} // namespace
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<ReluToSpatialCompute>(ctx); }
} // namespace onnx_mlir

View File

@@ -9,8 +9,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
struct Concat : public OpConversionPattern<ONNXConcatOp> { struct Concat : public OpConversionPattern<ONNXConcatOp> {
Concat(MLIRContext* ctx) using OpConversionPattern::OpConversionPattern;
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp, LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
ONNXConcatOpAdaptor adaptor, ONNXConcatOpAdaptor adaptor,
@@ -24,7 +23,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
} }
}; };
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<Concat>(ctx); patterns.insert<Concat>(ctx);
} }

View File

@@ -114,6 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} // namespace } // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); } void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -9,41 +9,46 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def onnxToPimTransposeOp : Pat< def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms), (ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms, (PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimVMMOp : Pat< def spatToPimVMM : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector), (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector, (PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimMVMOp : Pat< def spatToPimMVM : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector), (SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector, (PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimVVAddOp : Pat< def spatToPimVVAdd : Pat<
(SpatVAddOp:$srcOpRes $a, $b), (SpatVAddOp:$srcOpRes $a, $b),
(PimVVAddOp $a, $b, (PimVVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimVVMulOp : Pat< def spatToPimVVMul : Pat<
(SpatVMulOp:$srcOpRes $a, $b), (SpatVMulOp:$srcOpRes $a, $b),
(PimVVMulOp $a, $b, (PimVVMulOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimVVMaxOp : Pat< def spatToPimVVMax : Pat<
(SpatVMaxOp:$srcOpRes $a, $b), (SpatVMaxOp:$srcOpRes $a, $b),
(PimVVMaxOp $a, $b, (PimVVMaxOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
// Lower a Spatial relu to the destination-style PIM vector relu, picking an
// output buffer from the source op's operands or allocating a new one.
def spatToPimVRelu : Pat<
(SpatReluOp:$srcOpRes $input),
(PimVReluOp $input,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
#endif // SPATIAL_TO_PIM #endif // SPATIAL_TO_PIM

View File

@@ -363,29 +363,6 @@ def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
let summary = "Reduce all elements to a single value";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> { def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
let summary = "Average all elements into a single value"; let summary = "Average all elements into a single value";

View File

@@ -97,8 +97,7 @@ struct MemCopyDevToHostOpInterface
} }
}; };
struct TransposeOpBufferizeInterface struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }
@@ -123,7 +122,7 @@ struct TransposeOpBufferizeInterface
} }
}; };
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> { struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, PimVMMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }
@@ -160,7 +159,7 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
} }
}; };
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> { struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }
@@ -186,8 +185,7 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
}; };
template <typename OpTy> template <typename OpTy>
struct BinaryDstOpBufferizeInterface struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }
@@ -225,17 +223,56 @@ struct BinaryDstOpBufferizeInterface
} }
}; };
// Bufferization model for destination-style unary PIM ops (vavg, vrelu,
// vtanh, vsigm): one read-only input, one in/out destination buffer.
template <typename OpTy>
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
// Every operand except the DPS init (the destination buffer) is read.
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
// Element i of the output depends only on element i of the input, so the
// access is element-wise (enables in-place bufferization analysis).
// NOTE(review): parameter renamed from the copy-pasted `tablegen_opaque_val`
// to `op` for consistency with the sibling interface methods.
bool bufferizesToElementwiseAccess(Operation* op,
const AnalysisState& state,
ArrayRef<OpOperand*> opOperands) const {
return true;
}
// Replaces the tensor op with its memref form: resolves the input and
// destination buffers, makes the input contiguous, and re-creates the op on
// buffers with the destination's memref type as result type.
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto unaryOp = cast<OpTy>(op);
auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
// PIM kernels require contiguous memory; copy if the layout is strided.
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
return success();
}
};
void registerOpBufferizationInterfaces(DialectRegistry& registry) { void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx); PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx); PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx); PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx); PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx); PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx); PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx); PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
}); });
} }

View File

@@ -16,5 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
(returnType $dst)) (returnType $dst))
>; >;
#endif // PIM_BUFFERIZATION #endif // PIM_BUFFERIZATION

View File

@@ -105,6 +105,6 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
}); });
} }
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); } std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -486,8 +486,6 @@ struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, Spa
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {}; struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {}; struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
@@ -496,7 +494,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx); SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx); SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx); SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx); SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx); SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx); SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);

View File

@@ -13,13 +13,13 @@ std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
std::unique_ptr<mlir::Pass> createSpatialToPimPass(); std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass(); std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createConstantFoldingPass(); std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass(); std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createVerificationPass(); std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass(); std::unique_ptr<mlir::Pass> createEmitPimJsonPass();

View File

@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
} // namespace } // namespace
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); } std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -126,6 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace } // namespace
std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); } std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -189,6 +189,6 @@ private:
} // namespace } // namespace
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); } std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<VerificationPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -19,6 +19,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/PimAccelerator.hpp" #include "src/Accelerators/PIM/PimAccelerator.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#define DEBUG_TYPE "PimAccelerator" #define DEBUG_TYPE "PimAccelerator"
@@ -69,13 +70,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
void PimAccelerator::registerPasses(int optLevel) const { void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
TOTAL_COMPILE_PHASE = 8;
registerPass(createONNXToSpatialPass); registerPass(createONNXToSpatialPass);
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass); registerPass(createPimBufferizationPass);
registerPass(createConstantFoldingPass); registerPass(createPimConstantFoldingPass);
registerPass(createMaterializeConstantsPass); registerPass(createPimMaterializeConstantsPass);
registerPass(createVerificationPass); registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass); registerPass(createEmitPimJsonPass);
} }

View File

@@ -35,6 +35,15 @@ python3 validation/operations/gen_tests.py
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` | | Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` | | Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
## Relu
| Test | Directory | Input | Output | Notes |
|------|-----------|-------|--------|-------|
| Basic | `relu/basic` | [4,8] | [4,8] | Standalone 2D Relu |
| 4D | `relu/4d` | [2,3,4,4] | [2,3,4,4] | Standalone NCHW Relu |
| After Conv | `relu/after_conv` | [1,3,5,5] | [1,2,3,3] | Conv 3x3 + bias, then Relu |
| After Gemm | `relu/after_gemm` | [4,64] | [4,32] | Gemm + bias, then Relu |
## Gemm ## Gemm
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes | | Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations.""" """Generate ONNX test models for validating GEMM, Conv, Pooling, and Relu implementations."""
import numpy as np import numpy as np
import onnx import onnx
@@ -327,6 +327,60 @@ def maxpool_after_conv():
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx") save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
# ---------------------------------------------------------------------------
# Relu tests
# ---------------------------------------------------------------------------
def relu_basic():
"""Standalone Relu on a simple 2D tensor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 8])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8])
node = helper.make_node("Relu", ["X"], ["Y"])
graph = helper.make_graph([node], "relu_basic", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "relu/basic", "relu_basic.onnx")
def relu_4d():
"""Standalone Relu on an NCHW tensor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4, 4])
node = helper.make_node("Relu", ["X"], ["Y"])
graph = helper.make_graph([node], "relu_4d", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "relu/4d", "relu_4d.onnx")
def relu_after_conv():
"""Conv followed by Relu."""
rng = np.random.default_rng(60)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
conv = helper.make_node("Conv", ["X", "W", "B"], ["C"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
relu = helper.make_node("Relu", ["C"], ["Y"])
graph = helper.make_graph([conv, relu], "relu_after_conv", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "relu/after_conv", "relu_after_conv.onnx")
def relu_after_gemm():
"""Gemm followed by Relu."""
B, K, N = 4, 64, 32
rng = np.random.default_rng(61)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
gemm = helper.make_node("Gemm", ["A", "W", "C"], ["G"])
relu = helper.make_node("Relu", ["G"], ["Y"])
graph = helper.make_graph([gemm, relu], "relu_after_gemm", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "relu/after_gemm", "relu_after_gemm.onnx")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Main # Main
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -361,4 +415,10 @@ if __name__ == "__main__":
avgpool_include_pad() avgpool_include_pad()
maxpool_after_conv() maxpool_after_conv()
print("\nGenerating Relu tests:")
relu_basic()
relu_4d()
relu_after_conv()
relu_after_gemm()
print("\nDone.") print("\nDone.")

Binary file not shown.

Binary file not shown.