add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation, add spatial compute helper, minor refactors
This commit is contained in:
@@ -39,16 +39,16 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||||
pm.addPass(createBufferizePimPass());
|
pm.addPass(createPimBufferizationPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Pim bufferized"));
|
pm.addPass(createMessagePass("Pim bufferized"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
pm.addPass(createConstantFoldingPass());
|
pm.addPass(createPimConstantFoldingPass());
|
||||||
pm.addPass(createMessagePass("Pim constants folded"));
|
pm.addPass(createMessagePass("Pim constants folded"));
|
||||||
pm.addPass(createMaterializeConstantsPass());
|
pm.addPass(createPimMaterializeConstantsPass());
|
||||||
pm.addPass(createVerificationPass());
|
pm.addPass(createPimVerificationPass());
|
||||||
pm.addPass(createMessagePass("Pim verified"));
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimJsonPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|||||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||||
|
|
||||||
add_pim_library(OMONNXToSpatial
|
add_pim_library(OMONNXToSpatial
|
||||||
Patterns/Math/Gemm.cpp
|
|
||||||
Patterns/Math/Conv.cpp
|
Patterns/Math/Conv.cpp
|
||||||
|
Patterns/Math/Gemm.cpp
|
||||||
Patterns/Math/MatMul.cpp
|
Patterns/Math/MatMul.cpp
|
||||||
Patterns/NN/Pool.cpp
|
Patterns/NN/Pool.cpp
|
||||||
|
Patterns/NN/Relu.cpp
|
||||||
Patterns/Tensor/Concat.cpp
|
Patterns/Tensor/Concat.cpp
|
||||||
Patterns/Tensor/Reshape.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/ValueRange.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <optional>
|
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
@@ -58,14 +56,6 @@ inline auto getFilterCount(const ShapedType& shapedType) {
|
|||||||
using HSliceId = size_t;
|
using HSliceId = size_t;
|
||||||
using CoreId = size_t;
|
using CoreId = size_t;
|
||||||
|
|
||||||
enum class MapOperations {
|
|
||||||
None,
|
|
||||||
ONNXSoftmaxOp,
|
|
||||||
ONNXReluOp,
|
|
||||||
ONNXLeakyReluOp,
|
|
||||||
ONNXExpOp
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
constexpr C ceilIntegerDivide(A a, B b) {
|
constexpr C ceilIntegerDivide(A a, B b) {
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
@@ -114,6 +104,38 @@ inline auto getTensorShape(mlir::Value tensor) {
|
|||||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||||
|
std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <size_t NumInputs, typename BodyFn>
|
||||||
|
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
|
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
size_t axis,
|
size_t axis,
|
||||||
int64_t sliceSize,
|
int64_t sliceSize,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ include "src/Dialect/ONNX/ONNX.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
def onnxToArithConstantOp : Pat<
|
def onnxToArithConstant : Pat<
|
||||||
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
|
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
|
||||||
(Arith_ConstantOp $value)
|
(Arith_ConstantOp $value)
|
||||||
>;
|
>;
|
||||||
@@ -19,7 +19,7 @@ def IsRank2Result: Constraint<
|
|||||||
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
||||||
"Result is rank 2">;
|
"Result is rank 2">;
|
||||||
|
|
||||||
def matMulAddToGemmPattern : Pat<
|
def matMulAddToGemm : Pat<
|
||||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||||
(ONNXGemmOp $A, $B, $C,
|
(ONNXGemmOp $A, $B, $C,
|
||||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||||
@@ -30,7 +30,7 @@ def matMulAddToGemmPattern : Pat<
|
|||||||
[(IsRank2Result $matmulres)]
|
[(IsRank2Result $matmulres)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def matMulToGemmPattern : Pat<
|
def matMulToGemm : Pat<
|
||||||
(ONNXMatMulOp:$matmulres $A, $B),
|
(ONNXMatMulOp:$matmulres $A, $B),
|
||||||
(
|
(
|
||||||
ONNXGemmOp $A, $B,
|
ONNXGemmOp $A, $B,
|
||||||
@@ -45,14 +45,13 @@ def matMulToGemmPattern : Pat<
|
|||||||
|
|
||||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||||
|
|
||||||
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
|
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
|
||||||
// ONNXConvOp with a bias.
|
def convAddToConvWithBiasLeft : Pat<
|
||||||
def convAddToConvWithBiasPatternLeft : Pat<
|
|
||||||
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
|
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
|
||||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def convAddToConvWithBiasPatternRight : Pat<
|
def convAddToConvWithBiasRight : Pat<
|
||||||
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
|
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
|
||||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||||
>;
|
>;
|
||||||
@@ -61,7 +60,7 @@ def convAddToConvWithBiasPatternRight : Pat<
|
|||||||
|
|
||||||
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
|
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
|
||||||
|
|
||||||
def removeLRNPattern : Pat<
|
def removeLRN : Pat<
|
||||||
(ONNXLRNOp $A, $_, $_, $_, $_),
|
(ONNXLRNOp $A, $_, $_, $_, $_),
|
||||||
(replaceWithOperationOfValue $A)
|
(replaceWithOperationOfValue $A)
|
||||||
>;
|
>;
|
||||||
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
|
|||||||
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
|
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
|
||||||
"Two tensors have the same static shape">;
|
"Two tensors have the same static shape">;
|
||||||
|
|
||||||
def removeFlattenSameShapePattern : Pat<
|
def removeFlattenSameShape : Pat<
|
||||||
(ONNXFlattenOp:$flattenOp $A, $axis),
|
(ONNXFlattenOp:$flattenOp $A, $axis),
|
||||||
(replaceWithOperationOfValue $A),
|
(replaceWithOperationOfValue $A),
|
||||||
[(HaveSameStaticShape $flattenOp, $A)]
|
[(HaveSameStaticShape $flattenOp, $A)]
|
||||||
>; // Add closing parenthesis here
|
>;
|
||||||
|
|
||||||
#endif // ONNX_TO_SPATIAL
|
#endif // ONNX_TO_SPATIAL
|
||||||
|
|||||||
@@ -50,12 +50,12 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
|
|
||||||
RewritePatternSet mergeActivationPatterns(ctx);
|
RewritePatternSet mergeActivationPatterns(ctx);
|
||||||
mergeActivationPatterns.add<onnxToArithConstantOp>(ctx);
|
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx);
|
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx);
|
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
||||||
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
|
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
||||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
||||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
||||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||||
|
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||||
@@ -74,23 +74,24 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||||
target.addIllegalOp<ONNXGemmOp>();
|
target.addIllegalOp<ONNXGemmOp>();
|
||||||
target.addIllegalOp<ONNXConvOp>();
|
target.addIllegalOp<ONNXConvOp>();
|
||||||
target.addIllegalOp<ONNXLRNOp>();
|
|
||||||
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
||||||
target.addIllegalOp<ONNXAveragePoolOp>();
|
target.addIllegalOp<ONNXAveragePoolOp>();
|
||||||
target.addIllegalOp<ONNXConcatOp>();
|
target.addIllegalOp<ONNXReluOp>();
|
||||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
target.addIllegalOp<ONNXConcatOp>();
|
||||||
target.addIllegalOp<ONNXReshapeOp>();
|
target.addIllegalOp<ONNXReshapeOp>();
|
||||||
|
target.addIllegalOp<ONNXLRNOp>();
|
||||||
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
patterns.add<removeLRNPattern>(ctx);
|
patterns.add<removeLRN>(ctx);
|
||||||
|
|
||||||
populateConvOpPatterns(patterns, ctx);
|
populateGemmPatterns(patterns, ctx);
|
||||||
populatePoolTilingPattern(patterns, ctx);
|
populateConvPatterns(patterns, ctx);
|
||||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
populatePoolPatterns(patterns, ctx);
|
||||||
populateReshapeConversionPattern(patterns, ctx);
|
populateReluPatterns(patterns, ctx);
|
||||||
|
populateConcatPatterns(patterns, ctx);
|
||||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
populateReshapePatterns(patterns, ctx);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|||||||
@@ -5,16 +5,18 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -260,6 +260,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -58,8 +58,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
||||||
GemvToSpatialCompute(MLIRContext* ctx)
|
using OpConversionPattern::OpConversionPattern;
|
||||||
: OpConversionPattern(ctx, 1) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||||
@@ -352,7 +351,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<GemmToManyGemv>(ctx);
|
patterns.insert<GemmToManyGemv>(ctx);
|
||||||
patterns.insert<GemvToSpatialCompute>(ctx);
|
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -257,7 +257,7 @@ struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
||||||
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
||||||
}
|
}
|
||||||
|
|||||||
33
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
Normal file
33
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Conversion pattern that replaces an `ONNXReluOp` with a spatial compute op
/// whose body applies `spatial::SpatReluOp` to its single block argument and
/// yields the result.
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
    Location opLoc = reluOp.getLoc();
    // The compute op and the inner relu both reuse the original result type.
    Type outType = reluOp.getResult().getType();

    // Body of the compute region: relu on the block argument, then yield it.
    auto buildBody = [&](Value blockInput) {
      auto innerRelu = spatial::SpatReluOp::create(rewriter, opLoc, outType, blockInput);
      spatial::SpatYieldOp::create(rewriter, opLoc, innerRelu.getResult());
    };

    // No weights; the adaptor's X operand is the only input.
    auto compute = createSpatCompute<1>(rewriter, opLoc, outType, {}, adaptor.getX(), buildBody);

    rewriter.replaceOp(reluOp, compute);
    return success();
  }
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Adds the ONNX Relu lowering pattern (`ReluToSpatialCompute`) to `patterns`.
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReluToSpatialCompute>(ctx);
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -9,8 +9,7 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||||
Concat(MLIRContext* ctx)
|
using OpConversionPattern::OpConversionPattern;
|
||||||
: OpConversionPattern(ctx) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
||||||
ONNXConcatOpAdaptor adaptor,
|
ONNXConcatOpAdaptor adaptor,
|
||||||
@@ -24,7 +23,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<Concat>(ctx);
|
patterns.insert<Concat>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
|
void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -9,41 +9,46 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
def onnxToPimTransposeOp : Pat<
|
def onnxToPimTranspose : Pat<
|
||||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||||
(PimTransposeOp $data, $perms,
|
(PimTransposeOp $data, $perms,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVMMOp : Pat<
|
def spatToPimVMM : Pat<
|
||||||
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
||||||
(PimVMMOp $weightIndex, $vector,
|
(PimVMMOp $weightIndex, $vector,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimMVMOp : Pat<
|
def spatToPimMVM : Pat<
|
||||||
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
||||||
(PimMVMOp $weightIndex, $vector,
|
(PimMVMOp $weightIndex, $vector,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVVAddOp : Pat<
|
def spatToPimVVAdd : Pat<
|
||||||
(SpatVAddOp:$srcOpRes $a, $b),
|
(SpatVAddOp:$srcOpRes $a, $b),
|
||||||
(PimVVAddOp $a, $b,
|
(PimVVAddOp $a, $b,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVVMulOp : Pat<
|
def spatToPimVVMul : Pat<
|
||||||
(SpatVMulOp:$srcOpRes $a, $b),
|
(SpatVMulOp:$srcOpRes $a, $b),
|
||||||
(PimVVMulOp $a, $b,
|
(PimVVMulOp $a, $b,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVVMaxOp : Pat<
|
def spatToPimVVMax : Pat<
|
||||||
(SpatVMaxOp:$srcOpRes $a, $b),
|
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||||
(PimVVMaxOp $a, $b,
|
(PimVVMaxOp $a, $b,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
// Rewrites SpatReluOp into PimVReluOp: $input is forwarded unchanged and the
// second operand is produced by getBestOutputTensorFromOperandsOrAllocate,
// called with the builder and the op that defines the matched result.
def spatToPimVRelu : Pat<
|
||||||
|
(SpatReluOp:$srcOpRes $input),
|
||||||
|
(PimVReluOp $input,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
#endif // SPATIAL_TO_PIM
|
#endif // SPATIAL_TO_PIM
|
||||||
|
|||||||
@@ -363,29 +363,6 @@ def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Reduce all elements to a single value";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$input,
|
|
||||||
PimTensor:$outputBuffer
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutputBufferMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
|
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Average all elements into a single value";
|
let summary = "Average all elements into a single value";
|
||||||
|
|
||||||
|
|||||||
@@ -97,8 +97,7 @@ struct MemCopyDevToHostOpInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TransposeOpBufferizeInterface
|
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||||
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
}
|
}
|
||||||
@@ -123,7 +122,7 @@ struct TransposeOpBufferizeInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, PimVMMOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
}
|
}
|
||||||
@@ -160,7 +159,7 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
|
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
}
|
}
|
||||||
@@ -186,8 +185,7 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
struct BinaryDstOpBufferizeInterface
|
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
||||||
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
}
|
}
|
||||||
@@ -225,17 +223,56 @@ struct BinaryDstOpBufferizeInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
||||||
|
const AnalysisState& state,
|
||||||
|
ArrayRef<OpOperand*> opOperands) const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto unaryOp = cast<OpTy>(op);
|
||||||
|
|
||||||
|
auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
|
||||||
|
if (failed(inputOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
|
||||||
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
|
|
||||||
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
|
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
||||||
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
|
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||||
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
|
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
|
||||||
|
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
|
||||||
|
|
||||||
|
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
|
||||||
|
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||||
|
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
|
||||||
|
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,5 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
|
|||||||
(returnType $dst))
|
(returnType $dst))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
|
||||||
#endif // PIM_BUFFERIZATION
|
#endif // PIM_BUFFERIZATION
|
||||||
|
|||||||
@@ -105,6 +105,6 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
|
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -486,8 +486,6 @@ struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, Spa
|
|||||||
|
|
||||||
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
||||||
|
|
||||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
|
||||||
|
|
||||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||||
@@ -496,7 +494,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
|||||||
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
|
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
|
||||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
|
||||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
|
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
|
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createVerificationPass();
|
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -126,6 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -189,6 +189,6 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
|
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<VerificationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
||||||
|
#include "src/Compiler/CompilerUtils.hpp"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimAccelerator"
|
#define DEBUG_TYPE "PimAccelerator"
|
||||||
|
|
||||||
@@ -69,13 +70,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
|||||||
|
|
||||||
void PimAccelerator::registerPasses(int optLevel) const {
|
void PimAccelerator::registerPasses(int optLevel) const {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
||||||
|
TOTAL_COMPILE_PHASE = 8;
|
||||||
registerPass(createONNXToSpatialPass);
|
registerPass(createONNXToSpatialPass);
|
||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createBufferizePimPass);
|
registerPass(createPimBufferizationPass);
|
||||||
registerPass(createConstantFoldingPass);
|
registerPass(createPimConstantFoldingPass);
|
||||||
registerPass(createMaterializeConstantsPass);
|
registerPass(createPimMaterializeConstantsPass);
|
||||||
registerPass(createVerificationPass);
|
registerPass(createPimVerificationPass);
|
||||||
registerPass(createEmitPimJsonPass);
|
registerPass(createEmitPimJsonPass);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,15 @@ python3 validation/operations/gen_tests.py
|
|||||||
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
|
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
|
||||||
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
|
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
|
||||||
|
|
||||||
|
## Relu
|
||||||
|
|
||||||
|
| Test | Directory | Input | Output | Notes |
|
||||||
|
|------|-----------|-------|--------|-------|
|
||||||
|
| Basic | `relu/basic` | [4,8] | [4,8] | Standalone 2D Relu |
|
||||||
|
| 4D | `relu/4d` | [2,3,4,4] | [2,3,4,4] | Standalone NCHW Relu |
|
||||||
|
| After Conv | `relu/after_conv` | [1,3,5,5] | [1,2,3,3] | Conv 3x3 + bias, then Relu |
|
||||||
|
| After Gemm | `relu/after_gemm` | [4,64] | [4,32] | Gemm + bias, then Relu |
|
||||||
|
|
||||||
## Gemm
|
## Gemm
|
||||||
|
|
||||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
|
"""Generate ONNX test models for validating GEMM, Conv, Pooling, and Relu implementations."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnx
|
import onnx
|
||||||
@@ -327,6 +327,60 @@ def maxpool_after_conv():
|
|||||||
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
|
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Relu tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def relu_basic():
|
||||||
|
"""Standalone Relu on a simple 2D tensor."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 8])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8])
|
||||||
|
node = helper.make_node("Relu", ["X"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "relu_basic", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "relu/basic", "relu_basic.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def relu_4d():
|
||||||
|
"""Standalone Relu on an NCHW tensor."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4, 4])
|
||||||
|
node = helper.make_node("Relu", ["X"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "relu_4d", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "relu/4d", "relu_4d.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def relu_after_conv():
|
||||||
|
"""Conv followed by Relu."""
|
||||||
|
rng = np.random.default_rng(60)
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
|
||||||
|
B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
|
||||||
|
conv = helper.make_node("Conv", ["X", "W", "B"], ["C"],
|
||||||
|
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||||
|
relu = helper.make_node("Relu", ["C"], ["Y"])
|
||||||
|
graph = helper.make_graph([conv, relu], "relu_after_conv", [X], [Y], initializer=[W, B])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "relu/after_conv", "relu_after_conv.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def relu_after_gemm():
|
||||||
|
"""Gemm followed by Relu."""
|
||||||
|
B, K, N = 4, 64, 32
|
||||||
|
rng = np.random.default_rng(61)
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||||
|
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
gemm = helper.make_node("Gemm", ["A", "W", "C"], ["G"])
|
||||||
|
relu = helper.make_node("Relu", ["G"], ["Y"])
|
||||||
|
graph = helper.make_graph([gemm, relu], "relu_after_gemm", [A], [Y], initializer=[W, C])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "relu/after_gemm", "relu_after_gemm.onnx")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Main
|
# Main
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -361,4 +415,10 @@ if __name__ == "__main__":
|
|||||||
avgpool_include_pad()
|
avgpool_include_pad()
|
||||||
maxpool_after_conv()
|
maxpool_after_conv()
|
||||||
|
|
||||||
|
print("\nGenerating Relu tests:")
|
||||||
|
relu_basic()
|
||||||
|
relu_4d()
|
||||||
|
relu_after_conv()
|
||||||
|
relu_after_gemm()
|
||||||
|
|
||||||
print("\nDone.")
|
print("\nDone.")
|
||||||
|
|||||||
BIN
validation/operations/relu/4d/relu_4d.onnx
Normal file
BIN
validation/operations/relu/4d/relu_4d.onnx
Normal file
Binary file not shown.
BIN
validation/operations/relu/after_conv/relu_after_conv.onnx
Normal file
BIN
validation/operations/relu/after_conv/relu_after_conv.onnx
Normal file
Binary file not shown.
BIN
validation/operations/relu/after_gemm/relu_after_gemm.onnx
Normal file
BIN
validation/operations/relu/after_gemm/relu_after_gemm.onnx
Normal file
Binary file not shown.
BIN
validation/operations/relu/basic/relu_basic.onnx
Normal file
BIN
validation/operations/relu/basic/relu_basic.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user