Raptor/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
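// Folds a compile-time scale factor into a dense floating-point constant.
// Illustrative sketch with assumed values (not taken from any real model):
//   %b = arith.constant dense<[1.0, 2.0]> : tensor<2xf32>
// scaled by 0.5 materializes a new constant instead of a runtime multiply:
//   %b = arith.constant dense<[0.5, 1.0]> : tensor<2xf32>
// Returns failure when `value` is not an arith.constant with a dense float
// payload, letting callers skip the pattern instead of miscompiling.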
static FailureOr<Value>
materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
if (factor == 1.0f)
return value;
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
if (!denseAttr)
return failure();
  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  // Match the scale's float semantics to the tensor's element type so that
  // multiply() does not assert on non-f32 tensors.
  APFloat scale(factor);
  bool losesInfo = false;
  (void)scale.convert(cast<FloatType>(denseAttr.getElementType()).getFloatSemantics(),
      APFloat::rmNearestTiesToEven,
      &losesInfo);
  bool hadFailure = false;
  for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
    APFloat scaledValue(originalValue);
    // Inexact rounding is expected when scaling floats; only invalid
    // operations and overflow abort the fold.
    APFloat::opStatus status = scaledValue.multiply(scale, APFloat::rmNearestTiesToEven);
    if (status & (APFloat::opInvalidOp | APFloat::opOverflow))
      hadFailure = true;
    scaledValues.push_back(std::move(scaledValue));
  }
if (hadFailure)
return failure();
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
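// GemmToManyGemv splits a multi-row Gemm into one Gemv per output row so that
// each row can be mapped onto the spatial fabric independently. Illustrative
// shapes, assumed for the example: Y[4, N] = alpha * A[4, K] * B[K, N] +
// beta * C becomes four Gemms over [1, K] slices of A, whose [1, N] results
// are concatenated along axis 0 inside a spatial compute region. Alpha and
// beta are folded into B and C up front, so the per-row Gemvs run with
// alpha = beta = 1.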
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
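// GemvToSpatialCompute lowers a single-row Gemm (a Gemv) onto the PIM
// crossbars: B is tiled into crossbarSize x crossbarSize blocks, the input
// row of A is sliced to match, each tile drives a weighted VMM on one
// crossbar, and the partial products are summed per output slice and
// concatenated back into the full output row.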
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
} // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
const int64_t numOutRows = aType.getDimSize(0);
// Only decompose when there are multiple rows to split
if (numOutRows <= 1)
return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
  RankedTensorType cType;
  bool cHasNumOutRows = false;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> gemvOps;
gemvOps.reserve(numOutRows);
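  // Each iteration peels one row of A with a static extract_slice.
  // Illustrative IR for rowIdx == 1, assuming A : tensor<4x8xf32>:
  //   %row = tensor.extract_slice %a[1, 0] [1, 8] [1, 1]
  //       : tensor<4x8xf32> to tensor<1x8xf32>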
  for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
    SmallVector<OpFoldResult> aOffsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> aSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
    SmallVector<OpFoldResult> aStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
    auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, aOffsets, aSizes, aStrides).getResult();
    Value cSlice = c;
    if (hasC) {
      if (cHasNumOutRows) {
        // C carries one bias row per output row: peel the matching row.
        SmallVector<OpFoldResult> cOffsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> cSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
        SmallVector<OpFoldResult> cStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
        cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, cOffsets, cSizes, cStrides).getResult();
      }
      else
        assert(isVectorShape(getTensorShape(c)) && "C should be a vector");
    }
auto gemvOp = ONNXGemmOp::create(rewriter,
loc,
outRowType,
aSlice,
b,
cSlice,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
gemmOp.getTransAAttr(),
gemmOp.getTransBAttr());
gemvOps.push_back(gemvOp.getY());
}
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location gemmLoc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
Value out = gemmOp.getY();
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
float beta = gemmOpAdaptor.getBeta().convertToFloat();
bool transA = gemmOpAdaptor.getTransA();
bool transB = gemmOpAdaptor.getTransB();
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(out.getType());
  RankedTensorType cType;
  // C is the optional bias; when absent it is a value of NoneType.
  bool hasC = !isa<NoneType>(c.getType());
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
gemmLoc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv
return failure();
if (transA) {
auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
}
if (transB) {
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
bType = cast<RankedTensorType>(b.getType());
}
if (alpha != 1.0f) {
auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
alpha = 1.0f;
}
if (hasC && beta != 1.0f) {
auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
beta = 1.0f;
}
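  // Worked example with assumed sizes (crossbarSize = 128, K = 300):
  // ceilIntegerDivideWithRemainder(300, 128) yields 3 slices with a remainder
  // of 44, i.e. two full 128-wide slices plus one 44-wide tail slice.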
  // Slice the K-long input row and the K x N weight matrix by the crossbar
  // width. A's horizontal slices pair with B's vertical slices (both walk the
  // shared contraction dimension K), while C and the output inherit B's
  // horizontal slicing along N.
  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
  auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
  auto bNumVSlices = aNumHSlices;
  auto bLastVSliceSize = aLastHSliceSize;
  auto cNumHSlices = bNumHSlices;
  auto cLastHSliceSize = bLastHSliceSize;
  auto outNumHSlices = cNumHSlices;
  auto outLastHSliceSize = cLastHSliceSize;
  // Each core hosts crossbarCountInCore crossbars, so the K-tiles feeding one
  // output slice are spread over this many cores.
  const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
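  // bTiles is keyed as bTiles[outputHSlice][core][kSlice], mirroring how the
  // per-core slices of A are consumed in the loops below.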
  SmallVector<Value> cHSlices;
  // A 1x1 C is a scalar bias: broadcast it across all N outputs before slicing.
  if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
    c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
  if (hasC)
    cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
  RankedTensorType outHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(crossbarSize)}, outType.getElementType());
  RankedTensorType outLastHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(bLastHSliceSize)}, outType.getElementType());
SmallVector<Value> outHSlices;
outHSlices.reserve(outNumHSlices);
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
RankedTensorType currOutHSliceType = outHSliceType;
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
currOutHSliceType = outLastHSliceType;
SmallVector<Value> partialResults;
partialResults.reserve(coresPerVSlice);
    for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
      // Gather, for this core, the B tiles of output column-block outSliceId
      // that pair with the core's slices of A.
      SmallVector<Value> weights;
      weights.reserve(aHSlices[coreId].size());
      for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
        weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
});
partialResults.push_back(computeOp.getResult(0));
}
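    // Sum the per-core partial products (plus the bias slice, when present)
    // into the final value of this output slice inside one compute region.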
if (hasC) {
Value cHSlice = cHSlices[outSliceId];
partialResults.push_back(cHSlice);
}
auto reduceComputeOp =
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
SmallVector<Value> values(blockArgs.begin(), blockArgs.end());
Value outHSlice = sumTensors(values, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
});
outHSlices.push_back(reduceComputeOp.getResult(0));
}
auto concatComputeOp =
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<GemmToManyGemv, GemvToSpatialCompute>(ctx);
}
} // namespace onnx_mlir