add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation
add spatial compute helper
minor refactors
@@ -3,10 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
 add_public_tablegen_target(ONNXToSpatialIncGen)

 add_pim_library(OMONNXToSpatial
-  Patterns/Math/Gemm.cpp
   Patterns/Math/Conv.cpp
+  Patterns/Math/Gemm.cpp
   Patterns/Math/MatMul.cpp
   Patterns/NN/Pool.cpp
+  Patterns/NN/Relu.cpp
   Patterns/Tensor/Concat.cpp
   Patterns/Tensor/Reshape.cpp
   ONNXToSpatialPass.cpp
@@ -1,15 +1,13 @@
 #pragma once

 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"

 #include "llvm/Support/LogicalResult.h"

 #include <cassert>
 #include <optional>
 #include <type_traits>
 #include <utility>
@@ -58,14 +56,6 @@ inline auto getFilterCount(const ShapedType& shapedType) {
 using HSliceId = size_t;
 using CoreId = size_t;

-enum class MapOperations {
-    None,
-    ONNXSoftmaxOp,
-    ONNXReluOp,
-    ONNXLeakyReluOp,
-    ONNXExpOp
-};
-
 template <class A, class B, class C = std::common_type_t<A, B>>
 constexpr C ceilIntegerDivide(A a, B b) {
     static_assert(std::is_integral_v<A>, "A must be an integer type");
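For reference, ceilIntegerDivide rounds an integer quotient up instead of down; a quick check of the contract (assuming the usual (a + b - 1) / b formulation for positive operands; the function body is outside this hunk):

    static_assert(ceilIntegerDivide(7, 2) == 4);  // a remainder rounds the quotient up
    static_assert(ceilIntegerDivide(8, 2) == 4);  // exact quotients are unchanged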
@@ -114,6 +104,38 @@ inline auto getTensorShape(mlir::Value tensor) {
   return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
 }

+namespace detail {
+
+template <typename Fn, size_t... Is>
+void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
+    std::forward<Fn>(fn)(block->getArgument(Is)...);
+}
+
+} // namespace detail
+
+template <size_t NumInputs, typename BodyFn>
+spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
+    mlir::Location loc,
+    mlir::TypeRange resultTypes,
+    mlir::ValueRange weights,
+    mlir::ValueRange inputs,
+    BodyFn&& body) {
+    assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
+    auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
+
+    auto* block = new mlir::Block();
+    for (mlir::Value input : inputs)
+        block->addArgument(input.getType(), loc);
+
+    computeOp.getBody().push_back(block);
+    rewriter.setInsertionPointToStart(block);
+
+    detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
+
+    rewriter.setInsertionPointAfter(computeOp);
+    return computeOp;
+}
+
 llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
     size_t axis,
     int64_t sliceSize,
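The detail::invokeWithBlockArgs helper is the standard index_sequence pack-expansion trick: it turns what would be a runtime loop over block arguments into a single call expression, so the body lambda receives exactly one typed argument per input. A minimal self-contained sketch of the same mechanism (hypothetical names, with indices standing in for mlir::Value):

    #include <cstddef>
    #include <utility>

    template <typename Fn, std::size_t... Is>
    void invokeWithIndices(Fn&& fn, std::index_sequence<Is...>) {
        // Expands to fn(0, 1, ..., N-1), one argument per index, just as the
        // real helper expands to fn(block->getArgument(Is)...).
        std::forward<Fn>(fn)(Is...);
    }

    int main() {
        // With make_index_sequence<2> the lambda must take exactly two arguments,
        // which is why createSpatCompute asserts inputs.size() == NumInputs.
        invokeWithIndices([](std::size_t a, std::size_t b) { return a + b; },
                          std::make_index_sequence<2>{});
    }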
@@ -8,7 +8,7 @@ include "src/Dialect/ONNX/ONNX.td"
 include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
 #endif // OP_BASE

-def onnxToArithConstantOp : Pat<
+def onnxToArithConstant : Pat<
   (ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
   (Arith_ConstantOp $value)
 >;
@@ -19,7 +19,7 @@ def IsRank2Result: Constraint<
   CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
   "Result is rank 2">;

-def matMulAddToGemmPattern : Pat<
+def matMulAddToGemm : Pat<
   (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
   (ONNXGemmOp $A, $B, $C,
     /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
@@ -30,7 +30,7 @@ def matMulAddToGemmPattern : Pat<
   [(IsRank2Result $matmulres)]
 >;

-def matMulToGemmPattern : Pat<
+def matMulToGemm : Pat<
   (ONNXMatMulOp:$matmulres $A, $B),
   (
     ONNXGemmOp $A, $B,
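For context: ONNX Gemm computes Y = alpha * (A . B) + beta * C and is defined only for rank-2 operands, which is what the IsRank2Result guard enforces. Batched (rank > 2) MatMuls have no direct Gemm equivalent, so the pass later marks them dynamically legal and leaves them to the dedicated MatMul patterns.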
@@ -45,14 +45,13 @@ def matMulToGemmPattern : Pat<

 // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern

-// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
-// ONNXConvOp with a bias.
-def convAddToConvWithBiasPatternLeft : Pat<
+// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
+def convAddToConvWithBiasLeft : Pat<
   (ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
   (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
 >;

-def convAddToConvWithBiasPatternRight : Pat<
+def convAddToConvWithBiasRight : Pat<
   (ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
   (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
 >;
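In other words, y = Add(Conv(x, w), b) becomes y = Conv(x, w, bias = b); the two defs cover both operand orders of the commutative Add. Note that the matched $bias is discarded by the rewrite, so the fusion is only behaviour-preserving when the original convolution carries no bias of its own.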
@@ -61,7 +60,7 @@ def convAddToConvWithBiasPatternRight : Pat<

 def replaceWithOperationOfValue : NativeCodeCall<"$0">;

-def removeLRNPattern : Pat<
+def removeLRN : Pat<
   (ONNXLRNOp $A, $_, $_, $_, $_),
   (replaceWithOperationOfValue $A)
 >;
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
   CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
   "Two tensors have the same static shape">;

-def removeFlattenSameShapePattern : Pat<
+def removeFlattenSameShape : Pat<
   (ONNXFlattenOp:$flattenOp $A, $axis),
   (replaceWithOperationOfValue $A),
   [(HaveSameStaticShape $flattenOp, $A)]
->; // Add closing parenthesis here
+>;

 #endif // ONNX_TO_SPATIAL
@@ -50,12 +50,12 @@ void ONNXToSpatialPass::runOnOperation() {
   MLIRContext* ctx = &getContext();

   RewritePatternSet mergeActivationPatterns(ctx);
-  mergeActivationPatterns.add<onnxToArithConstantOp>(ctx);
-  mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx);
-  mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx);
-  mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
-  mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
-  mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
+  mergeActivationPatterns.add<onnxToArithConstant>(ctx);
+  mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
+  mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
+  mergeActivationPatterns.add<matMulAddToGemm>(ctx);
+  mergeActivationPatterns.add<matMulToGemm>(ctx);
+  mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
   populateMatMulRewritePatterns(mergeActivationPatterns, ctx);

   if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
@@ -74,23 +74,24 @@ void ONNXToSpatialPass::runOnOperation() {
       [](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
   target.addIllegalOp<ONNXGemmOp>();
   target.addIllegalOp<ONNXConvOp>();
+  target.addIllegalOp<ONNXLRNOp>();
   target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
   target.addIllegalOp<ONNXAveragePoolOp>();
+  target.addIllegalOp<ONNXConcatOp>();
+  target.addIllegalOp<ONNXReluOp>();
   target.addIllegalOp<ONNXSoftmaxOp>();
+  target.addIllegalOp<ONNXReduceMeanV13Op>();
-  target.addIllegalOp<ONNXConcatOp>();
   target.addIllegalOp<ONNXReshapeOp>();
-  target.addIllegalOp<ONNXLRNOp>();
-  target.addIllegalOp<ONNXReduceMeanV13Op>();

   RewritePatternSet patterns(ctx);
-  patterns.add<removeLRNPattern>(ctx);
+  patterns.add<removeLRN>(ctx);

-  populateConvOpPatterns(patterns, ctx);
-  populatePoolTilingPattern(patterns, ctx);
-  populateOnnxGemmOpPatterns(patterns, ctx);
-  populateReshapeConversionPattern(patterns, ctx);
-
-  populateONNXConcatToTensorConcatPattern(patterns, ctx);
+  populateGemmPatterns(patterns, ctx);
+  populateConvPatterns(patterns, ctx);
+  populatePoolPatterns(patterns, ctx);
+  populateReluPatterns(patterns, ctx);
+  populateConcatPatterns(patterns, ctx);
+  populateReshapePatterns(patterns, ctx);

   if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
     signalPassFailure();
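Structurally the pass runs in two phases: applyPatternsGreedily first normalizes the ONNX graph to a fixed point (constant materialization, MatMul-to-Gemm, Conv-bias fusion, redundant-Flatten removal), then applyPartialConversion lowers to the Spatial dialect and fails the pass if any op marked illegal above survives conversion.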
@@ -5,16 +5,18 @@

 namespace onnx_mlir {

-void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+
+void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

 void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

-void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

-void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

-void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

-void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

 } // namespace onnx_mlir
@@ -260,6 +260,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
   return success();
 }

-void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
+void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }

 } // namespace onnx_mlir
@@ -58,8 +58,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
 };

 struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
-  GemvToSpatialCompute(MLIRContext* ctx)
-      : OpConversionPattern(ctx, 1) {}
+  using OpConversionPattern::OpConversionPattern;

   LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
       ONNXGemmOpAdaptor gemmOpAdaptor,
@@ -352,7 +351,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
   return success();
 }

-void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
+void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
   patterns.insert<GemmToManyGemv>(ctx);
   patterns.insert<GemvToSpatialCompute>(ctx);
 }
@@ -257,7 +257,7 @@ struct PoolToSpatialCompute<ONNXAveragePoolOp>

 } // namespace

-void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
+void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
   patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
   patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
 }
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp (new file, 33 lines)
@@ -0,0 +1,33 @@
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace {
+
+struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
+    using OpConversionPattern::OpConversionPattern;
+
+    LogicalResult
+    matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
+        Location loc = reluOp.getLoc();
+        Type resultType = reluOp.getResult().getType();
+        constexpr size_t numInputs = 1;
+        auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
+            auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
+            spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
+        });
+        rewriter.replaceOp(reluOp, computeOp);
+        return success();
+    }
+};
+
+} // namespace
+
+void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<ReluToSpatialCompute>(ctx); }
+
+} // namespace onnx_mlir
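Relu is the simplest client of the new createSpatCompute helper: one input, one op in the region, one yield. For comparison, a two-input op would be wired up along these lines (a hedged sketch, not code from this commit; SpatVAddOp's exact builder signature is assumed):

    // Hypothetical binary lowering: the lambda arity must equal NumInputs (2),
    // and the region body must terminate with a SpatYieldOp.
    constexpr size_t numInputs = 2;
    auto computeOp = createSpatCompute<numInputs>(
        rewriter, loc, resultType, /*weights=*/{},
        ValueRange{adaptor.getA(), adaptor.getB()},
        [&](Value lhs, Value rhs) {
            auto add = spatial::SpatVAddOp::create(rewriter, loc, resultType, lhs, rhs);
            spatial::SpatYieldOp::create(rewriter, loc, add.getResult());
        });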
@@ -9,8 +9,7 @@ using namespace mlir;
 namespace onnx_mlir {

 struct Concat : public OpConversionPattern<ONNXConcatOp> {
-  Concat(MLIRContext* ctx)
-      : OpConversionPattern(ctx) {}
+  using OpConversionPattern::OpConversionPattern;

   LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
       ONNXConcatOpAdaptor adaptor,
@@ -24,7 +23,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
   }
 };

-void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
+void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
   patterns.insert<Concat>(ctx);
 }
@@ -114,6 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {

 } // namespace

-void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
+void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }

 } // namespace onnx_mlir
@@ -9,41 +9,46 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
 include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
 #endif // OP_BASE

-def onnxToPimTransposeOp : Pat<
+def onnxToPimTranspose : Pat<
   (ONNXTransposeOp:$srcOpRes $data, $perms),
   (PimTransposeOp $data, $perms,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;

-def spatToPimVMMOp : Pat<
+def spatToPimVMM : Pat<
   (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
   (PimVMMOp $weightIndex, $vector,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;

-def spatToPimMVMOp : Pat<
+def spatToPimMVM : Pat<
   (SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
   (PimMVMOp $weightIndex, $vector,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;

-def spatToPimVVAddOp : Pat<
+def spatToPimVVAdd : Pat<
   (SpatVAddOp:$srcOpRes $a, $b),
   (PimVVAddOp $a, $b,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;

-def spatToPimVVMulOp : Pat<
+def spatToPimVVMul : Pat<
   (SpatVMulOp:$srcOpRes $a, $b),
   (PimVVMulOp $a, $b,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;

-def spatToPimVVMaxOp : Pat<
+def spatToPimVVMax : Pat<
   (SpatVMaxOp:$srcOpRes $a, $b),
   (PimVVMaxOp $a, $b,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
 >;

+def spatToPimVRelu : Pat<
+  (SpatReluOp:$srcOpRes $input),
+  (PimVReluOp $input,
+   (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
+>;
+
 #endif // SPATIAL_TO_PIM
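All of these Spatial-to-Pim lowerings share the getBestOutputTensorFromOperandsOrAllocate NativeCodeCall, which (as the name suggests) supplies each Pim op with an explicit destination tensor, reusing an operand's buffer when possible and allocating a fresh one otherwise. The new spatToPimVRelu simply slots into that existing destination-passing scheme.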