c77ffa9c56
support for tensors of index values
603 lines
24 KiB
C++
603 lines
24 KiB
C++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/Location.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Returns `value` multiplied element-wise by `factor`, materialized as a new
/// arith.constant.
///
/// When `factor` is exactly 1 the input is returned unchanged and no IR is
/// created. Otherwise `value` must be produced by an arith.constant holding a
/// dense floating-point attribute. Fails (without creating IR) when it is not,
/// or when converting `factor` to the element type or scaling any element is
/// an invalid operation or overflows. Ordinary rounding of the products is
/// accepted.
static FailureOr<Value>
materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
  if (factor == 1.0f)
    return value;

  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (!constantOp)
    return failure();

  auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
  if (!denseAttr)
    return failure();

  // Rounding (opInexact/opUnderflow) is expected for almost every product; only
  // genuinely bad results make the fold unusable.
  auto isFatal = [](APFloat::opStatus status) {
    return (status & (APFloat::opInvalidOp | APFloat::opOverflow)) != 0;
  };

  // APFloat arithmetic requires both operands to share the same semantics, so
  // convert the f32 factor to the tensor's element type (f16, bf16, f64, ...).
  const auto& semantics = cast<FloatType>(denseAttr.getElementType()).getFloatSemantics();
  APFloat scale(factor);
  bool losesInfo = false;
  if (isFatal(scale.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo)))
    return failure();

  SmallVector<APFloat> scaledValues;
  auto appendScaled = [&](APFloat element) {
    if (isFatal(element.multiply(scale, APFloat::rmNearestTiesToEven)))
      return false;
    scaledValues.push_back(std::move(element));
    return true;
  };

  if (denseAttr.isSplat()) {
    // A single value is re-broadcast to the full shape as a splat attribute.
    if (!appendScaled(denseAttr.getSplatValue<APFloat>()))
      return failure();
  }
  else {
    scaledValues.reserve(denseAttr.getNumElements());
    for (const APFloat& element : denseAttr.getValues<APFloat>())
      if (!appendScaled(element))
        return failure();
  }

  ShapedType type = denseAttr.getType();
  auto scaledAttr = DenseFPElementsAttr::get(type, scaledValues);
  return arith::ConstantOp::create(rewriter, loc, type, scaledAttr).getResult();
}
|
|
|
|
/// Transposes `value` into `resultType` using `permutation` via onnx.Transpose.
///
/// A host-foldable `value` is transposed in place. Otherwise the transpose is
/// wrapped in a single-input spatial compute op, and that op's first result
/// is returned.
static Value transposeForSpatial(Value value,
                                 RankedTensorType resultType,
                                 ArrayRef<int64_t> permutation,
                                 ConversionPatternRewriter& rewriter,
                                 Location loc) {
  auto permAttr = rewriter.getI64ArrayAttr(permutation);
  auto buildTranspose = [&](Value source) -> Value {
    return ONNXTransposeOp::create(rewriter, loc, resultType, source, permAttr);
  };

  if (!isHostFoldableValue(value)) {
    auto wrapper = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
      spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
    });
    return wrapper.getResult(0);
  }

  return buildTranspose(value);
}
|
|
|
|
/// Reshapes the (rank-1, per the callers) bias `value` into `resultType` with
/// a tensor.expand_shape whose single reassociation group {0, 1} maps source
/// dimension 0 onto both result dimensions.
///
/// A host-foldable `value` is reshaped in place. Otherwise the reshape is
/// wrapped in a single-input spatial compute op, and that op's first result is
/// returned.
static Value
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  SmallVector<ReassociationIndices> reassociation {{0, 1}};
  auto buildExpand = [&](Value source) -> Value {
    return tensor::ExpandShapeOp::create(rewriter, loc, resultType, source, reassociation);
  };

  if (!isHostFoldableValue(value)) {
    auto wrapper = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
      spatial::SpatYieldOp::create(rewriter, loc, buildExpand(input));
    });
    return wrapper.getResult(0);
  }

  return buildExpand(value);
}
|
|
|
|
/// Conversion pattern that splits a multi-row onnx.Gemm into one single-row
/// onnx.Gemm per row of A and concatenates the per-row results along axis 0.
struct GemmToManyGemv : public OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern<ONNXGemmOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
|
|
|
/// Conversion pattern that lowers a vector-shaped onnx.Gemm (a Gemv) to
/// crossbar-sized spatial VMM computations whose partial sums are reduced and
/// concatenated along axis 1.
struct GemvToSpatialCompute : public OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern<ONNXGemmOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
|
|
|
/// Conversion pattern that lowers a multi-row onnx.Gemm fitting in a single
/// crossbar tile to one spatial batch compute with one lane per row of A.
struct GemmToSpatialComputeBatch : public OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern<ONNXGemmOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
|
|
|
/// Splits the rank-2 `matrix` (of type `matrixType`, [R, C]) into R values of
/// type [1, C] with a spatial::SpatExtractRowsOp.
///
/// A host-foldable `matrix` is split in place. Otherwise the split is wrapped
/// in a spatial compute op. If `matrix` is produced by a chain of
/// single-operand tensor.extract_slice / tensor.expand_shape /
/// tensor.collapse_shape / onnx.Transpose ops, that chain is cloned into the
/// compute body, so the compute consumes the chain's root value directly.
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
                                                    RankedTensorType matrixType,
                                                    ConversionPatternRewriter& rewriter,
                                                    Location loc) {
  const int64_t numRows = matrixType.getDimSize(0);
  auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
  SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);

  // Emits one SpatExtractRowsOp on `source` and returns all of its row results.
  auto buildRowSlices = [&](Value source) {
    auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), source);
    return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
  };

  if (isHostFoldableValue(matrix))
    return buildRowSlices(matrix);

  // Walk up through view-like producers. chainOps is collected consumer-first,
  // and the walk stops at anything else (block argument, multi-operand op,
  // spatial compute, ...). That value becomes the compute's input.
  SmallVector<Operation*> chainOps;
  Value rootValue = matrix;
  while (Operation* definingOp = rootValue.getDefiningOp()) {
    if (definingOp->getNumOperands() != 1
        || !isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
      break;
    chainOps.push_back(definingOp);
    rootValue = definingOp->getOperand(0);
  }

  auto sliceCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootValue}, [&](Value input) {
        Value transformedMatrix = input;
        if (!chainOps.empty()) {
          // Replay the chain producer-first on the block argument.
          IRMapping mapper;
          mapper.map(rootValue, input);
          for (Operation* chainOp : llvm::reverse(chainOps))
            rewriter.clone(*chainOp, mapper);
          transformedMatrix = mapper.lookup(matrix);
        }
        spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
      });

  return SmallVector<Value>(sliceCompute->result_begin(), sliceCompute->result_end());
}
|
|
|
|
} // namespace
|
|
|
|
/// Decomposes a multi-row Gemm Y = alpha * A * op(B) + beta * C (transA=false)
/// into one single-row onnx.Gemm per row of A.
///
/// alpha and beta are folded into constant B and C through
/// materializeScaledConstantTensor, so every per-row Gemm uses alpha = beta = 1.
/// A row-specific bias (one row per output row) is split alongside A.
/// Otherwise a vector-like bias (or the None value) is shared by every row.
/// The per-row results are concatenated along axis 0 inside a spatial compute.
///
/// Single-row products are left to GemvToSpatialCompute (which also supports
/// transA). All shape/bias preconditions are checked before IR is created.
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
                                              ONNXGemmOpAdaptor gemmOpAdaptor,
                                              ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  // C may be a block argument, which has no defining op; plain isa<> asserts on null.
  const bool hasC = !llvm::isa_and_nonnull<ONNXNoneOp>(c.getDefiningOp());

  auto aType = cast<RankedTensorType>(a.getType());
  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }

  // Only decompose when there are multiple output rows to split. Single-row
  // products, including transA ones, are lowered by GemvToSpatialCompute, so
  // they must not be reported as errors here.
  if (outType.getDimSize(0) <= 1)
    return rewriter.notifyMatchFailure(gemmOp, "single-row Gemm is lowered as a Gemv");

  if (gemmOpAdaptor.getTransA()) {
    gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
    return failure();
  }

  // With transA=false, A is [M, K] and each of its M rows yields one output row.
  const int64_t numOutRows = aType.getDimSize(0);

  // Validate the bias before creating any IR.
  RankedTensorType cType = nullptr;
  bool expandC = false;
  bool cHasNumOutRows = false;
  if (hasC) {
    cType = cast<RankedTensorType>(c.getType());
    if (!cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
    // A rank-1 bias [N] is handled as rank-2 [1, N] for uniform handling.
    if (cType.getRank() == 1) {
      cType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      expandC = true;
    }
    if (cType.getRank() != 2) {
      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
      return failure();
    }
    cHasNumOutRows = cType.getDimSize(0) == numOutRows;
    // A bias that is not split per row is passed unchanged to every row.
    if (!cHasNumOutRows && !isVectorShape(cType.getShape())) {
      gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
      return failure();
    }
  }

  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return rewriter.notifyMatchFailure(gemmOp, "alpha != 1 requires a constant B that can be scaled");
  b = *scaledB;

  if (hasC) {
    // Scale before expanding: scaling only folds values defined by arith.constant.
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return rewriter.notifyMatchFailure(gemmOp, "beta != 1 requires a constant C that can be scaled");
    c = *scaledC;
    if (expandC)
      c = expandRankOneBias(c, cType, rewriter, loc);
  }

  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
  SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
  SmallVector<Value> cSlices;
  if (cHasNumOutRows)
    cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);

  SmallVector<Value> gemvOps;
  gemvOps.reserve(static_cast<size_t>(numOutRows));
  for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
    const auto row = static_cast<size_t>(rowIdx);
    // A row-specific bias contributes its own row; otherwise C (possibly None) is shared.
    Value cSlice = cHasNumOutRows ? cSlices[row] : c;
    auto gemvOp = ONNXGemmOp::create(rewriter,
                                     loc,
                                     outRowType,
                                     aSlices[row],
                                     b,
                                     cSlice,
                                     rewriter.getF32FloatAttr(1.0f),
                                     rewriter.getF32FloatAttr(1.0f),
                                     gemmOp.getTransAAttr(),
                                     gemmOp.getTransBAttr());
    gemvOps.push_back(gemvOp.getY());
  }

  auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
  });

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
|
|
|
|
/// Lowers a Gemm whose A operand (and C, when present) is vector-shaped to
/// crossbar-sized spatial VMMs.
///
/// alpha and beta are folded into constant B and C with
/// materializeScaledConstantTensor *before* any transpose or reshape. That
/// helper only folds values defined by an arith.constant, which no longer
/// holds once transposeForSpatial / expandRankOneBias have run. A transA/transB
/// operand is then transposed with transposeForSpatial.
///
/// A is split with sliceVectorPerCrossbarPerCore and B with
/// tileMatrix(b, crossbarSize, crossbarSize). For every horizontal output
/// slice, one SpatCompute per core sums that core's VMM products. A second
/// SpatCompute adds up the per-core partial sums plus the matching slice of C.
/// The output slices are concatenated along axis 1.
///
/// All preconditions are checked before any IR is created, so a non-Gemv op is
/// rejected without touching the IR.
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
                                                    ONNXGemmOpAdaptor gemmOpAdaptor,
                                                    ConversionPatternRewriter& rewriter) const {
  Location gemmLoc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  Value out = gemmOp.getY();

  const float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
  const float beta = gemmOpAdaptor.getBeta().convertToFloat();
  const bool transA = gemmOpAdaptor.getTransA();
  const bool transB = gemmOpAdaptor.getTransB();

  auto aType = cast<RankedTensorType>(a.getType());
  auto bType = cast<RankedTensorType>(b.getType());
  auto outType = cast<RankedTensorType>(out.getType());

  // ---- Validation: no IR is created in this section. ----

  // C may be a block argument, which has no defining op; plain isa<> asserts on null.
  const bool hasC = !llvm::isa_and_nonnull<ONNXNoneOp>(c.getDefiningOp());
  RankedTensorType cType = nullptr;
  bool expandC = false;
  if (hasC) {
    cType = cast<RankedTensorType>(c.getType());
    if (!cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
    // A rank-1 bias [N] is handled as rank-2 [1, N] for uniform handling.
    if (cType.getRank() == 1) {
      cType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      expandC = true;
    }
    if (cType.getRank() != 2) {
      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
      return failure();
    }
  }

  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!bType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }

  if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
    return rewriter.notifyMatchFailure(gemmOp, "not a Gemv: A and C must be vector-shaped");

  // ---- Rewrite. ----

  // Scale first (element-wise scaling commutes with transpose and reshape).
  auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
  if (failed(scaledB))
    return rewriter.notifyMatchFailure(gemmOp, "alpha != 1 requires a constant B that can be scaled");
  b = *scaledB;

  if (hasC) {
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
    if (failed(scaledC))
      return rewriter.notifyMatchFailure(gemmOp, "beta != 1 requires a constant C that can be scaled");
    c = *scaledC;
    if (expandC)
      c = expandRankOneBias(c, cType, rewriter, gemmLoc);
  }

  if (transA) {
    auto aShape = aType.getShape();
    auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
    a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
    aType = cast<RankedTensorType>(a.getType());
  }
  if (transB) {
    auto bShape = bType.getShape();
    auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
    b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
    bType = cast<RankedTensorType>(b.getType());
  }

  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
  auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
  // B is cut into one vertical slice per horizontal slice of A; C and the
  // output share B's horizontal slicing.
  const auto bNumVSlices = aNumHSlices;
  const auto outNumHSlices = bNumHSlices;
  const auto outLastHSliceSize = bLastHSliceSize;

  const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());

  DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
  DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
      tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);

  SmallVector<Value> cHSlices;
  if (hasC) {
    // A [1, 1] bias is broadcast to the full output width before slicing.
    if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
      c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
    cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
  }

  RankedTensorType outHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
  RankedTensorType outLastHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(outLastHSliceSize)}, outType.getElementType());

  SmallVector<Value> outHSlices;
  outHSlices.reserve(outNumHSlices);
  for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
    // Only the last slice can be narrower, and only if the width is not a multiple of crossbarSize.
    RankedTensorType currOutHSliceType = outHSliceType;
    if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
      currOutHSliceType = outLastHSliceType;

    SmallVector<Value> partialResults;
    partialResults.reserve(coresPerVSlice);
    for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
      // No other insertion into aHSlices happens while this reference is live.
      SmallVector<Value>& coreInputs = aHSlices[coreId];

      SmallVector<Value> weights;
      weights.reserve(coreInputs.size());
      for (size_t aSliceId = 0; aSliceId < coreInputs.size(); aSliceId++)
        weights.push_back(bTiles[outSliceId][coreId][aSliceId]);

      auto computeOp =
          spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, coreInputs);
      SmallVector<Type> blockArgTypes;
      SmallVector<Location> blockArgLocs;
      blockArgTypes.reserve(weights.size() + coreInputs.size());
      blockArgLocs.reserve(weights.size() + coreInputs.size());
      for (Value weight : weights) {
        blockArgTypes.push_back(weight.getType());
        blockArgLocs.push_back(gemmLoc);
      }
      for (Value input : coreInputs) {
        blockArgTypes.push_back(input.getType());
        blockArgLocs.push_back(gemmLoc);
      }
      Block* body =
          rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
      rewriter.setInsertionPointToEnd(body);

      SmallVector<Value> vmmOutputs;
      vmmOutputs.reserve(coreInputs.size());
      for (auto aHSliceId : llvm::seq<size_t>(0, coreInputs.size())) {
        auto weightArg = computeOp.getWeightArgument(aHSliceId);
        auto inputArg = computeOp.getInputArgument(aHSliceId);
        if (!weightArg || !inputArg)
          return failure();
        vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
      }
      if (vmmOutputs.empty()) {
        gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
        return failure();
      }

      Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
      spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
      rewriter.setInsertionPointAfter(computeOp);

      partialResults.push_back(computeOp->getResult(0));
    }

    if (hasC)
      partialResults.push_back(cHSlices[outSliceId]);

    auto reduceComputeOp =
        createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
          SmallVector<Value> values(blockArgs.begin(), blockArgs.end());
          Value outHSlice = sumTensors(values, rewriter);
          spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
        });

    outHSlices.push_back(reduceComputeOp.getResult(0));
  }

  auto concatComputeOp =
      createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
        spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
      });

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
|
|
|
|
/// Lowers a multi-row Gemm whose reduction size K and output width N both fit
/// in one crossbar to a single spatial::SpatComputeBatch with one lane per row
/// of A.
///
/// B is alpha-scaled via materializeScaledConstantTensor and, for transB,
/// transposed. Each lane extracts its row of A and multiplies it with B via
/// SpatVMMOp. When a bias is present, the lane adds the shared bias (beta-scaled,
/// rank-1 expanded to [1, N], [1, 1] broadcast). It then writes its [1, N]
/// result into row `lane` of the packed output.
///
/// Single-row products, transA, multi-tile shapes and row-specific biases are
/// left to the other Gemm patterns. All bias/shape preconditions are checked
/// before any IR is created.
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
                                                         ONNXGemmOpAdaptor gemmOpAdaptor,
                                                         ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  // C may be a block argument, which has no defining op; plain isa<> asserts on null.
  const bool hasC = !llvm::isa_and_nonnull<ONNXNoneOp>(c.getDefiningOp());

  auto aType = cast<RankedTensorType>(a.getType());
  auto bType = cast<RankedTensorType>(b.getType());
  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!bType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }

  // Single-row products (including transA ones) are lowered by
  // GemvToSpatialCompute, so they must not be reported as errors here.
  if (outType.getDimSize(0) <= 1)
    return rewriter.notifyMatchFailure(gemmOp, "single-row Gemm is lowered as a Gemv");
  // This specialization slices rows of A, so it needs A in [M, K] layout.
  if (gemmOpAdaptor.getTransA())
    return rewriter.notifyMatchFailure(gemmOp, "requires transA=false before batch Gemm lowering");

  const int64_t numOutRows = aType.getDimSize(0);
  const auto crossbarDim = static_cast<int64_t>(crossbarSize.getValue());

  // Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
  if (aType.getDimSize(1) > crossbarDim || outType.getDimSize(1) > crossbarDim)
    return failure();

  // Validate the bias before creating any IR.
  RankedTensorType cType = nullptr;
  bool expandC = false;
  if (hasC) {
    cType = cast<RankedTensorType>(c.getType());
    if (!cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
    // A rank-1 bias [N] is handled as rank-2 [1, N].
    if (cType.getRank() == 1) {
      cType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      expandC = true;
    }
    if (cType.getRank() != 2) {
      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
      return failure();
    }
    // Row-specific bias can't share a single template body; fall through to GemmToManyGemv
    if (cType.getDimSize(0) == numOutRows)
      return failure();
  }

  // Scale before transposing: scaling only folds values defined by arith.constant.
  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;
  bType = cast<RankedTensorType>(b.getType());

  if (gemmOpAdaptor.getTransB()) {
    auto bShape = bType.getShape();
    auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
    b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
    bType = cast<RankedTensorType>(b.getType());
  }

  Value sharedBias;
  if (hasC) {
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    if (expandC)
      c = expandRankOneBias(c, cType, rewriter, loc);
    if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
      c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
    sharedBias = c;
  }

  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
  auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
  auto batchOp = spatial::SpatComputeBatch::create(rewriter,
                                                   loc,
                                                   TypeRange {outType},
                                                   rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
                                                   ValueRange {b},
                                                   ValueRange {a});

  // Block arguments: lane index, weight (B), packed input (A), packed output.
  SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
  SmallVector<Location> blockArgLocs(4, loc);
  Block* body =
      rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
  rewriter.setInsertionPointToEnd(body);

  auto lane = batchOp.getLaneArgument();
  auto weight = batchOp.getWeightArgument(0);
  auto packedInput = batchOp.getInputArgument(0);
  auto packedOutput = batchOp.getOutputArgument(0);
  if (!lane || !weight || !packedInput || !packedOutput)
    return failure();

  SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
  SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  Value row =
      tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
          .getResult();

  Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
  Value laneResult = vmmResult;
  // NOTE(review): sharedBias is defined outside batchOp and used directly in its
  // body rather than passed as an operand. Confirm SpatComputeBatch is not
  // IsolatedFromAbove and that later per-core lowering tolerates the capture.
  if (sharedBias)
    laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();

  auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
  rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
  SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
  tensor::ParallelInsertSliceOp::create(
      rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
  rewriter.setInsertionPointAfter(batchOp);

  rewriter.replaceOp(gemmOp, batchOp.getResults());
  return success();
}
|
|
|
|
/// Registers the Gemm lowering patterns into `patterns`.
/// GemmToSpatialComputeBatch is registered with benefit 2 so it is preferred
/// over GemmToManyGemv and GemvToSpatialCompute, which use the default benefit.
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  constexpr unsigned kBatchBenefit = 2;
  patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(kBatchBenefit));
  patterns.insert<GemmToManyGemv, GemvToSpatialCompute>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|