This commit is contained in:
@@ -124,6 +124,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
target.addIllegalOp<ONNXTransposeOp>();
|
||||
target.addIllegalOp<ONNXAddOp>();
|
||||
target.addIllegalOp<ONNXSubOp>();
|
||||
target.addIllegalOp<ONNXDivOp>();
|
||||
target.addIllegalOp<ONNXMulOp>();
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
|
||||
@@ -189,6 +189,7 @@ struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
|
||||
|
||||
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>>(ctx);
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXSubOp, spatial::SpatVSubOp>>(ctx);
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>>(ctx);
|
||||
patterns.add<DivToSpatialCompute>(ctx);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,12 @@ def spatToPimVVAdd : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVSub : Pat<
|
||||
(SpatVSubOp:$srcOpRes $a, $b),
|
||||
(PimVVSubOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMul : Pat<
|
||||
(SpatVMulOp:$srcOpRes $a, $b),
|
||||
(PimVVMulOp $a, $b,
|
||||
|
||||
@@ -257,6 +257,25 @@ def SpatVAddOp : SpatOp<"vadd", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVSubOp : SpatOp<"vsub", []> {
|
||||
let summary = "Element-wise subtraction between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVMulOp : SpatOp<"vmul", []> {
|
||||
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
|
||||
@@ -254,6 +254,12 @@ LogicalResult SpatVAddOp::verify() {
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatVSubOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatVMaxOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
|
||||
Reference in New Issue
Block a user