SpatialSubOp
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-06-05 17:12:16 +02:00
parent 08870de1a6
commit 90c4339808
12 changed files with 117 additions and 0 deletions
@@ -124,6 +124,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXSubOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
target.addIllegalOp<ONNXGemmOp>();
@@ -189,6 +189,7 @@ struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>>(ctx);
patterns.add<BinaryElementwiseToSpatialCompute<ONNXSubOp, spatial::SpatVSubOp>>(ctx);
patterns.add<BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>>(ctx);
patterns.add<DivToSpatialCompute>(ctx);
}
@@ -27,6 +27,12 @@ def spatToPimVVAdd : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVSub : Pat<
(SpatVSubOp:$srcOpRes $a, $b),
(PimVVSubOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMul : Pat<
(SpatVMulOp:$srcOpRes $a, $b),
(PimVVMulOp $a, $b,
+19
View File
@@ -257,6 +257,25 @@ def SpatVAddOp : SpatOp<"vadd", []> {
}];
}
def SpatVSubOp : SpatOp<"vsub", []> {
let summary = "Element-wise subtraction between two tensors; rhs must match lhs or be 1x1";
let arguments = (ins
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatVMulOp : SpatOp<"vmul", []> {
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
@@ -254,6 +254,12 @@ LogicalResult SpatVAddOp::verify() {
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatVSubOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatVMaxOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();