Refactor ONNXToSpatial Common and diagnostics
@@ -8,10 +8,9 @@
 #include "llvm/ADT/SmallVector.h"
 
-#include <cassert>
-
 #include "src/Accelerators/PIM/Common/PimCommon.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 
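The hunks below call two helpers from the new Diagnostics.hpp header, whose definitions are not part of this diff. A minimal sketch consistent with the call sites (helper names are taken from the diff; signatures, namespace nesting, and message wording are guesses):

```cpp
// Sketch only: signatures inferred from the call sites in this commit.
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

namespace pim {

// Reports that `valueName` must have a static shape to be lowered.
inline mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(
    mlir::Operation *op, llvm::StringRef valueName) {
  return op->emitOpError() << "requires a static shape for " << valueName;
}

// Reports that `valueName` has a rank outside the supported set.
inline mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(
    mlir::Operation *op, llvm::StringRef valueName, int64_t actualRank,
    llvm::ArrayRef<int64_t> supportedRanks) {
  mlir::InFlightDiagnostic diag = op->emitOpError()
      << "has unsupported rank " << actualRank << " for " << valueName
      << "; supported ranks:";
  for (int64_t rank : supportedRanks)
    diag << " " << rank;
  return diag;
}

} // namespace pim
```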
@@ -136,13 +135,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
   Value b = gemmOpAdaptor.getB();
   Value c = gemmOpAdaptor.getC();
 
-  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
+  if (gemmOpAdaptor.getTransA()) {
+    gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
+    return failure();
+  }
 
   bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
 
   auto aType = cast<RankedTensorType>(a.getType());
   auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
-  assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
+  if (!aType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
+    return failure();
+  }
+  if (!outType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
+    return failure();
+  }
 
   const int64_t numOutRows = aType.getDimSize(0);
 
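One caveat this hunk does not touch: `isa<ONNXNoneOp>(c.getDefiningOp())` assumes C is produced by an op. If C can reach the pattern as a block argument, `getDefiningOp()` returns null and the cast asserts. A type-based test would sidestep that; a sketch:

```cpp
// Sketch: optional ONNX inputs carry NoneType, so a type test works even
// when `c` is a block argument and c.getDefiningOp() would be null.
bool hasC = !mlir::isa<mlir::NoneType>(c.getType());
```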
@@ -175,7 +184,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
       });
       cType = expandedType;
     }
-    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
+    if (!cType.hasStaticShape()) {
+      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
+      return failure();
+    }
+    if (cType.getRank() != 2) {
+      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
+      return failure();
+    }
     cHasNumOutRows = cType.getDimSize(0) == numOutRows;
   }
 
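The `});` and `cType = expandedType;` context above close a rank-1-to-rank-2 expansion of C that is elided from this hunk. Assuming the surrounding code follows the tensor-dialect idiom, it could look roughly like this (a sketch; `loc` and the {1, N} target shape are assumptions, not shown in the diff):

```cpp
// Sketch: reshape a rank-1 bias tensor<N> to tensor<1xN> so the rank-2
// checks below apply uniformly.
if (cType.getRank() == 1) {
  auto expandedType = RankedTensorType::get(
      {1, cType.getDimSize(0)}, cType.getElementType());
  SmallVector<ReassociationIndices> reassociation = {{0, 1}};
  c = tensor::ExpandShapeOp::create(
          rewriter, loc, expandedType, c, reassociation)
          .getResult();
  cType = expandedType;
}
```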
@@ -199,8 +215,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
       auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
       cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
     }
-    else
-      assert("C should be a vector" && isVectorShape(getTensorShape(c)));
+    else if (!isVectorShape(getTensorShape(c))) {
+      gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
+      return failure();
+    }
   }
 
   auto gemvOp = ONNXGemmOp::create(rewriter,
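`isVectorShape` and `getTensorShape` come from the conversion's Common header and are not shown in this diff. A plausible reading of the check, as a sketch only:

```cpp
// Sketch: "vector-like" here plausibly means rank 1, or rank 2 with a unit
// leading dimension (a single row), so one bias row can be shared across
// all decomposed Gemv rows.
inline bool isVectorShape(llvm::ArrayRef<int64_t> shape) {
  if (shape.size() == 1)
    return true;
  return shape.size() == 2 && shape[0] == 1;
}
```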
@@ -258,11 +276,28 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
       });
       cType = expandedType;
     }
-    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
+    if (!cType.hasStaticShape()) {
+      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
+      return failure();
+    }
+    if (cType.getRank() != 2) {
+      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
+      return failure();
+    }
   }
 
-  assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
-      && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
+  if (!aType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
+    return failure();
+  }
+  if (!bType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
+    return failure();
+  }
+  if (!outType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
+    return failure();
+  }
 
   if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
     // Not a gemv
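The three input checks above repeat the same test-diagnose-fail step. They could be folded into one loop without changing behavior; a sketch, reusing the assumed diagnostic helper (the hasC-guarded check on `cType` stays separate because it only applies when C is present):

```cpp
// Sketch: table-drive the static-shape checks instead of repeating them.
std::pair<mlir::RankedTensorType, llvm::StringRef> checkedTypes[] = {
    {aType, "Gemm input A"}, {bType, "Gemm input B"}, {outType, "Gemm result"}};
for (auto [type, name] : checkedTypes) {
  if (!type.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, name);
    return failure();
  }
}
```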
@@ -341,19 +376,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
         weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
 
     auto computeOp = createSpatCompute(
-        rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
+        rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
           SmallVector<Value> vmmOutputs;
           vmmOutputs.reserve(aHSlicesArgs.size());
           for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
             vmmOutputs.push_back(
                 spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
-          assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
+          if (vmmOutputs.empty()) {
+            gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
+            return failure();
+          }
 
           Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
           spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
+          return success();
         });
+    if (failed(computeOp))
+      return failure();
 
-    partialResults.push_back(computeOp.getResult(0));
+    partialResults.push_back(computeOp->getResult(0));
   }
 
   if (hasC) {
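This hunk changes the body callback to return LogicalResult, then checks `failed(computeOp)` and dereferences with `computeOp->getResult(0)`, which implies `createSpatCompute` now returns a `FailureOr` wrapper rather than the op itself. Its updated declaration plausibly looks like this (sketch; parameter names are assumed, the helper lives outside this diff):

```cpp
// Sketch: builds a spatial compute op, invokes `bodyBuilder` inside its
// region, and propagates the builder's failure back to the caller.
mlir::FailureOr<mlir::Operation *> createSpatCompute(
    mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type resultType,
    mlir::ValueRange weights, mlir::ValueRange inputs,
    llvm::function_ref<mlir::LogicalResult(mlir::ValueRange)> bodyBuilder);
```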
@@ -388,14 +429,28 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
   Value b = gemmOpAdaptor.getB();
   Value c = gemmOpAdaptor.getC();
 
-  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
+  if (gemmOpAdaptor.getTransA()) {
+    gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
+    return failure();
+  }
 
   bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
 
   auto aType = cast<RankedTensorType>(a.getType());
   auto bType = cast<RankedTensorType>(b.getType());
   auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
-  assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
+  if (!aType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
+    return failure();
+  }
+  if (!bType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
+    return failure();
+  }
+  if (!outType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
+    return failure();
+  }
 
   const int64_t numOutRows = aType.getDimSize(0);
   if (numOutRows <= 1)
@@ -438,7 +493,14 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
       });
       cType = cast<RankedTensorType>(c.getType());
     }
-    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
+    if (!cType.hasStaticShape()) {
+      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
+      return failure();
+    }
+    if (cType.getRank() != 2) {
+      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
+      return failure();
+    }
     // Row-specific bias can't share a single template body; fall through to GemmToManyGemv
     if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
       return failure();
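The fall-through comment above relies on GemmToManyGemv also being registered, so it can pick up the row-specific-bias case this pattern rejects. One way that registration could be ordered (a sketch; the actual pass setup is outside this diff, and it assumes the pattern constructors forward a PatternBenefit):

```cpp
// Sketch: give the whole-batch lowering a higher benefit so it is tried
// first; when it returns failure(), the driver falls back to the
// row-by-row GemmToManyGemv decomposition. The populate function name is
// hypothetical.
void populateONNXToSpatialGemmPatterns(mlir::RewritePatternSet &patterns) {
  mlir::MLIRContext *ctx = patterns.getContext();
  patterns.add<GemmToSpatialComputeBatch>(ctx, /*benefit=*/2);
  patterns.add<GemmToManyGemv>(ctx, /*benefit=*/1);
  patterns.add<GemvToSpatialCompute>(ctx, /*benefit=*/1);
}
```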