add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation add spatial compute helper minor refactors
This commit is contained in:
@@ -39,16 +39,16 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||
pm.addPass(createBufferizePimPass());
|
||||
pm.addPass(createPimBufferizationPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim bufferized"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createConstantFoldingPass());
|
||||
pm.addPass(createPimConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim constants folded"));
|
||||
pm.addPass(createMaterializeConstantsPass());
|
||||
pm.addPass(createVerificationPass());
|
||||
pm.addPass(createPimMaterializeConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
|
||||
@@ -3,10 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_pim_library(OMONNXToSpatial
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
Patterns/NN/Pool.cpp
|
||||
Patterns/NN/Relu.cpp
|
||||
Patterns/Tensor/Concat.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
@@ -58,14 +56,6 @@ inline auto getFilterCount(const ShapedType& shapedType) {
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
enum class MapOperations {
|
||||
None,
|
||||
ONNXSoftmaxOp,
|
||||
ONNXReluOp,
|
||||
ONNXLeakyReluOp,
|
||||
ONNXExpOp
|
||||
};
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
@@ -114,6 +104,38 @@ inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <size_t NumInputs, typename BodyFn>
|
||||
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
size_t axis,
|
||||
int64_t sliceSize,
|
||||
|
||||
@@ -8,7 +8,7 @@ include "src/Dialect/ONNX/ONNX.td"
|
||||
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def onnxToArithConstantOp : Pat<
|
||||
def onnxToArithConstant : Pat<
|
||||
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
|
||||
(Arith_ConstantOp $value)
|
||||
>;
|
||||
@@ -19,7 +19,7 @@ def IsRank2Result: Constraint<
|
||||
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
||||
"Result is rank 2">;
|
||||
|
||||
def matMulAddToGemmPattern : Pat<
|
||||
def matMulAddToGemm : Pat<
|
||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||
(ONNXGemmOp $A, $B, $C,
|
||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
@@ -30,7 +30,7 @@ def matMulAddToGemmPattern : Pat<
|
||||
[(IsRank2Result $matmulres)]
|
||||
>;
|
||||
|
||||
def matMulToGemmPattern : Pat<
|
||||
def matMulToGemm : Pat<
|
||||
(ONNXMatMulOp:$matmulres $A, $B),
|
||||
(
|
||||
ONNXGemmOp $A, $B,
|
||||
@@ -45,14 +45,13 @@ def matMulToGemmPattern : Pat<
|
||||
|
||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||
|
||||
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
|
||||
// ONNXConvOp with a bias.
|
||||
def convAddToConvWithBiasPatternLeft : Pat<
|
||||
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
|
||||
def convAddToConvWithBiasLeft : Pat<
|
||||
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
|
||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||
>;
|
||||
|
||||
def convAddToConvWithBiasPatternRight : Pat<
|
||||
def convAddToConvWithBiasRight : Pat<
|
||||
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
|
||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||
>;
|
||||
@@ -61,7 +60,7 @@ def convAddToConvWithBiasPatternRight : Pat<
|
||||
|
||||
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
|
||||
|
||||
def removeLRNPattern : Pat<
|
||||
def removeLRN : Pat<
|
||||
(ONNXLRNOp $A, $_, $_, $_, $_),
|
||||
(replaceWithOperationOfValue $A)
|
||||
>;
|
||||
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
|
||||
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
|
||||
"Two tensors have the same static shape">;
|
||||
|
||||
def removeFlattenSameShapePattern : Pat<
|
||||
def removeFlattenSameShape : Pat<
|
||||
(ONNXFlattenOp:$flattenOp $A, $axis),
|
||||
(replaceWithOperationOfValue $A),
|
||||
[(HaveSameStaticShape $flattenOp, $A)]
|
||||
>; // Add closing parenthesis here
|
||||
>;
|
||||
|
||||
#endif // ONNX_TO_SPATIAL
|
||||
|
||||
@@ -50,12 +50,12 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
MLIRContext* ctx = &getContext();
|
||||
|
||||
RewritePatternSet mergeActivationPatterns(ctx);
|
||||
mergeActivationPatterns.add<onnxToArithConstantOp>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx);
|
||||
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
|
||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
||||
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
||||
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
||||
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||
@@ -74,23 +74,24 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
target.addIllegalOp<ONNXConvOp>();
|
||||
target.addIllegalOp<ONNXLRNOp>();
|
||||
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
||||
target.addIllegalOp<ONNXAveragePoolOp>();
|
||||
target.addIllegalOp<ONNXConcatOp>();
|
||||
target.addIllegalOp<ONNXReluOp>();
|
||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
target.addIllegalOp<ONNXConcatOp>();
|
||||
target.addIllegalOp<ONNXReshapeOp>();
|
||||
target.addIllegalOp<ONNXLRNOp>();
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRNPattern>(ctx);
|
||||
patterns.add<removeLRN>(ctx);
|
||||
|
||||
populateConvOpPatterns(patterns, ctx);
|
||||
populatePoolTilingPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
populateReshapeConversionPattern(patterns, ctx);
|
||||
|
||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
populatePoolPatterns(patterns, ctx);
|
||||
populateReluPatterns(patterns, ctx);
|
||||
populateConcatPatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
|
||||
@@ -5,16 +5,18 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -260,6 +260,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
||||
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -58,8 +58,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
};
|
||||
|
||||
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
||||
GemvToSpatialCompute(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 1) {}
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
@@ -352,7 +351,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<GemmToManyGemv>(ctx);
|
||||
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||
}
|
||||
|
||||
@@ -257,7 +257,7 @@ struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
||||
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
||||
}
|
||||
|
||||
33
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
Normal file
33
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
Normal file
@@ -0,0 +1,33 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||
Location loc = reluOp.getLoc();
|
||||
Type resultType = reluOp.getResult().getType();
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
|
||||
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
|
||||
});
|
||||
rewriter.replaceOp(reluOp, computeOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<ReluToSpatialCompute>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -9,8 +9,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
Concat(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
||||
ONNXConcatOpAdaptor adaptor,
|
||||
@@ -24,7 +23,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<Concat>(ctx);
|
||||
}
|
||||
|
||||
|
||||
@@ -114,6 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
|
||||
void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -9,41 +9,46 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def onnxToPimTransposeOp : Pat<
|
||||
def onnxToPimTranspose : Pat<
|
||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||
(PimTransposeOp $data, $perms,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVMMOp : Pat<
|
||||
def spatToPimVMM : Pat<
|
||||
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimVMMOp $weightIndex, $vector,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimMVMOp : Pat<
|
||||
def spatToPimMVM : Pat<
|
||||
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimMVMOp $weightIndex, $vector,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVAddOp : Pat<
|
||||
def spatToPimVVAdd : Pat<
|
||||
(SpatVAddOp:$srcOpRes $a, $b),
|
||||
(PimVVAddOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMulOp : Pat<
|
||||
def spatToPimVVMul : Pat<
|
||||
(SpatVMulOp:$srcOpRes $a, $b),
|
||||
(PimVVMulOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMaxOp : Pat<
|
||||
def spatToPimVVMax : Pat<
|
||||
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||
(PimVVMaxOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVRelu : Pat<
|
||||
(SpatReluOp:$srcOpRes $input),
|
||||
(PimVReluOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
@@ -363,29 +363,6 @@ def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
|
||||
let summary = "Reduce all elements to a single value";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
|
||||
let summary = "Average all elements into a single value";
|
||||
|
||||
|
||||
@@ -97,8 +97,7 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct TransposeOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
|
||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -123,7 +122,7 @@ struct TransposeOpBufferizeInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
||||
struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, PimVMMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -160,7 +159,7 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
||||
}
|
||||
};
|
||||
|
||||
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
|
||||
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -186,8 +185,7 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct BinaryDstOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
||||
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -225,17 +223,56 @@ struct BinaryDstOpBufferizeInterface
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
||||
const AnalysisState& state,
|
||||
ArrayRef<OpOperand*> opOperands) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto unaryOp = cast<OpTy>(op);
|
||||
|
||||
auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
|
||||
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
|
||||
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
|
||||
|
||||
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
|
||||
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
|
||||
|
||||
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
|
||||
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
|
||||
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -16,5 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
|
||||
@@ -105,6 +105,6 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -486,8 +486,6 @@ struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, Spa
|
||||
|
||||
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
||||
|
||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
||||
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
@@ -496,7 +494,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
|
||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||
|
||||
@@ -13,13 +13,13 @@ std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
|
||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createVerificationPass();
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||
|
||||
|
||||
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -126,6 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -189,6 +189,6 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
|
||||
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<VerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimAccelerator"
|
||||
|
||||
@@ -69,13 +70,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
||||
|
||||
void PimAccelerator::registerPasses(int optLevel) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
||||
TOTAL_COMPILE_PHASE = 8;
|
||||
registerPass(createONNXToSpatialPass);
|
||||
registerPass(createSpatialToGraphvizPass);
|
||||
registerPass(createSpatialToPimPass);
|
||||
registerPass(createBufferizePimPass);
|
||||
registerPass(createConstantFoldingPass);
|
||||
registerPass(createMaterializeConstantsPass);
|
||||
registerPass(createVerificationPass);
|
||||
registerPass(createPimBufferizationPass);
|
||||
registerPass(createPimConstantFoldingPass);
|
||||
registerPass(createPimMaterializeConstantsPass);
|
||||
registerPass(createPimVerificationPass);
|
||||
registerPass(createEmitPimJsonPass);
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,15 @@ python3 validation/operations/gen_tests.py
|
||||
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
|
||||
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
|
||||
|
||||
## Relu
|
||||
|
||||
| Test | Directory | Input | Output | Notes |
|
||||
|------|-----------|-------|--------|-------|
|
||||
| Basic | `relu/basic` | [4,8] | [4,8] | Standalone 2D Relu |
|
||||
| 4D | `relu/4d` | [2,3,4,4] | [2,3,4,4] | Standalone NCHW Relu |
|
||||
| After Conv | `relu/after_conv` | [1,3,5,5] | [1,2,3,3] | Conv 3x3 + bias, then Relu |
|
||||
| After Gemm | `relu/after_gemm` | [4,64] | [4,32] | Gemm + bias, then Relu |
|
||||
|
||||
## Gemm
|
||||
|
||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
|
||||
"""Generate ONNX test models for validating GEMM, Conv, Pooling, and Relu implementations."""
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
@@ -327,6 +327,60 @@ def maxpool_after_conv():
|
||||
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Relu tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def relu_basic():
|
||||
"""Standalone Relu on a simple 2D tensor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8])
|
||||
node = helper.make_node("Relu", ["X"], ["Y"])
|
||||
graph = helper.make_graph([node], "relu_basic", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "relu/basic", "relu_basic.onnx")
|
||||
|
||||
|
||||
def relu_4d():
|
||||
"""Standalone Relu on an NCHW tensor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4, 4])
|
||||
node = helper.make_node("Relu", ["X"], ["Y"])
|
||||
graph = helper.make_graph([node], "relu_4d", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "relu/4d", "relu_4d.onnx")
|
||||
|
||||
|
||||
def relu_after_conv():
|
||||
"""Conv followed by Relu."""
|
||||
rng = np.random.default_rng(60)
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
|
||||
B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
|
||||
conv = helper.make_node("Conv", ["X", "W", "B"], ["C"],
|
||||
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
relu = helper.make_node("Relu", ["C"], ["Y"])
|
||||
graph = helper.make_graph([conv, relu], "relu_after_conv", [X], [Y], initializer=[W, B])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "relu/after_conv", "relu_after_conv.onnx")
|
||||
|
||||
|
||||
def relu_after_gemm():
|
||||
"""Gemm followed by Relu."""
|
||||
B, K, N = 4, 64, 32
|
||||
rng = np.random.default_rng(61)
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
gemm = helper.make_node("Gemm", ["A", "W", "C"], ["G"])
|
||||
relu = helper.make_node("Relu", ["G"], ["Y"])
|
||||
graph = helper.make_graph([gemm, relu], "relu_after_gemm", [A], [Y], initializer=[W, C])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "relu/after_gemm", "relu_after_gemm.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -361,4 +415,10 @@ if __name__ == "__main__":
|
||||
avgpool_include_pad()
|
||||
maxpool_after_conv()
|
||||
|
||||
print("\nGenerating Relu tests:")
|
||||
relu_basic()
|
||||
relu_4d()
|
||||
relu_after_conv()
|
||||
relu_after_gemm()
|
||||
|
||||
print("\nDone.")
|
||||
|
||||
BIN
validation/operations/relu/4d/relu_4d.onnx
Normal file
BIN
validation/operations/relu/4d/relu_4d.onnx
Normal file
Binary file not shown.
BIN
validation/operations/relu/after_conv/relu_after_conv.onnx
Normal file
BIN
validation/operations/relu/after_conv/relu_after_conv.onnx
Normal file
Binary file not shown.
BIN
validation/operations/relu/after_gemm/relu_after_gemm.onnx
Normal file
BIN
validation/operations/relu/after_gemm/relu_after_gemm.onnx
Normal file
Binary file not shown.
BIN
validation/operations/relu/basic/relu_basic.onnx
Normal file
BIN
validation/operations/relu/basic/relu_basic.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user