Refactor ONNXToSpatial Common and diagnostics

This commit is contained in:
NiccoloN
2026-05-04 13:42:43 +02:00
parent b6ba1e4fea
commit f789954ad7
26 changed files with 686 additions and 486 deletions
@@ -7,11 +7,10 @@
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -370,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
assert("Only support 2D convolution" && xType.getRank() == 4);
// We need to understand what is group
assert("Only support group=1" && convOp.getGroup() == 1);
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
@@ -391,6 +413,19 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
@@ -431,6 +466,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
@@ -5,7 +5,8 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,13 +16,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
// Computes the row-major (innermost-dimension-fastest) strides, in elements,
// for a tensor of the given `shape`.
//
// The last dimension always has stride 1; every other dimension's stride is the
// stride of the dimension to its right multiplied by that dimension's extent.
// The result has one entry per dimension, so an empty shape yields an empty
// vector and a rank-1 shape yields {1}.
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  SmallVector<int64_t> result(rank, 1);

  // `inner` is the already-finalised dimension; fill in the one just outside
  // it. Nothing to do for rank 0 or 1, where the loop body never runs.
  for (size_t inner = rank; inner > 1; --inner)
    result[inner - 2] = result[inner - 1] * shape[inner - 1];

  return result;
}
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
@@ -8,10 +8,9 @@
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -136,13 +135,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
@@ -175,7 +184,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
@@ -199,8 +215,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
}
else
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
else if (!isVectorShape(getTensorShape(c))) {
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
return failure();
}
}
auto gemvOp = ONNXGemmOp::create(rewriter,
@@ -258,11 +276,28 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv
@@ -341,19 +376,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success();
});
if (failed(computeOp))
return failure();
partialResults.push_back(computeOp.getResult(0));
partialResults.push_back(computeOp->getResult(0));
}
if (hasC) {
@@ -388,14 +429,28 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
if (numOutRows <= 1)
@@ -438,7 +493,14 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
});
cType = cast<RankedTensorType>(c.getType());
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
return failure();
@@ -5,7 +5,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,7 +5,7 @@
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"