add PIM accelerator
This commit is contained in:
79
src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td
Normal file
79
src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td
Normal file
@@ -0,0 +1,79 @@
|
||||
#ifndef ONNX_TO_SPATIAL
|
||||
#define ONNX_TO_SPATIAL
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/Dialect/Tensor/IR/TensorOps.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithOps.td"
|
||||
include "src/Dialect/ONNX/ONNX.td"
|
||||
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def onnxToArithConstantOp : Pat<
|
||||
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
|
||||
(Arith_ConstantOp $value)
|
||||
>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXMatMulOp to ONNXGemmOp patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def matMulAddToGemmPattern : Pat<
|
||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||
(ONNXGemmOp $A, $B, $C,
|
||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
>;
|
||||
|
||||
def matMulToGemmPattern : Pat<
|
||||
(ONNXMatMulOp:$matmulres $A, $B),
|
||||
(
|
||||
ONNXGemmOp $A, $B,
|
||||
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
|
||||
// ONNXConvOp with a bias.
|
||||
def convAddToConvWithBiasPatternLeft : Pat<
|
||||
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
|
||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||
>;
|
||||
|
||||
def convAddToConvWithBiasPatternRight : Pat<
|
||||
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
|
||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||
>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation to ignore (i.e. remove)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
|
||||
|
||||
def removeLRNPattern : Pat<
|
||||
(ONNXLRNOp $A, $_, $_, $_, $_),
|
||||
(replaceWithOperationOfValue $A)
|
||||
>;
|
||||
|
||||
def HaveSameStaticShape: Constraint<
|
||||
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
|
||||
"Two tensors have the same static shape">;
|
||||
|
||||
def removeFlattenSameShapePattern : Pat<
|
||||
(ONNXFlattenOp:$flattenOp $A, $axis),
|
||||
(replaceWithOperationOfValue $A),
|
||||
[(HaveSameStaticShape $flattenOp, $A)]
|
||||
>; // Add closing parenthesis here
|
||||
|
||||
#endif // ONNX_TO_SPATIAL
|
||||
Reference in New Issue
Block a user