#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/SmallVector.h"

#include <cassert>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Returns true when `value` is the result of an ONNXNoneOp, i.e. the optional
/// operand was omitted.
///
/// Block arguments have no defining op; `isa_and_present` tolerates the null
/// pointer that `getDefiningOp()` returns for them, so they count as present.
static bool isNoneOperand(Value value) {
  return llvm::isa_and_present<ONNXNoneOp>(value.getDefiningOp());
}

/// Folds the scalar `factor` into a constant floating-point tensor.
///
/// \param value    Tensor to scale. Unless `factor` is exactly 1, it must be
///                 the result of an `arith.constant` holding a
///                 DenseFPElementsAttr.
/// \param factor   Multiplier applied to every element.
/// \param rewriter Rewriter used to create the replacement constant.
/// \param loc      Location attached to the replacement constant.
/// \returns `value` itself when `factor == 1`; a new `arith.constant` holding
///          the scaled elements otherwise; failure when `value` is not such a
///          constant or when scaling produces an invalid operation or an
///          overflow.
static FailureOr<Value> materializeScaledConstantTensor(
    Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
  // Exact comparison is intended: only the identity scale is a no-op.
  if (factor == 1.0f)
    return value;

  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (!constantOp)
    return failure();

  auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
  if (!denseAttr)
    return failure();

  // Rounding (opInexact) and gradual underflow are normal outcomes of
  // floating-point arithmetic; only these statuses make the result unusable.
  constexpr unsigned fatalStatus = APFloat::opInvalidOp | APFloat::opOverflow;

  // APFloat arithmetic requires both operands to share the same semantics, so
  // bring the single-precision scale into the element type's semantics first.
  const llvm::fltSemantics& semantics = cast<FloatType>(denseAttr.getElementType()).getFloatSemantics();
  APFloat scale(factor);
  bool losesInfo = false;
  if (static_cast<unsigned>(scale.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo)) & fatalStatus)
    return failure();

  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  for (APFloat scaledValue : denseAttr.getValues<APFloat>()) {
    if (static_cast<unsigned>(scaledValue.multiply(scale, APFloat::rmNearestTiesToEven)) & fatalStatus)
      return failure();
    scaledValues.push_back(std::move(scaledValue));
  }

  auto scaledAttr = DenseFPElementsAttr::get(cast<ShapedType>(denseAttr.getType()), scaledValues);
  return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}

/// Splits a multi-row Gemm into one single-row Gemm per row of A and
/// concatenates the row results along axis 0.
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};

/// Lowers a Gemm whose A (and C, when present) is vector-shaped to Spatial
/// compute ops built from crossbar-sized slices and tiles.
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};

} // namespace

LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
                                              ONNXGemmOpAdaptor gemmOpAdaptor,
                                              ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());

  const bool hasC = !isNoneOperand(c);

  auto aType = cast<RankedTensorType>(a.getType());
  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
  assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());

  const int64_t numOutRows = aType.getDimSize(0);

  // Only decompose when there are multiple rows to split
  if (numOutRows <= 1)
    return failure();

  // Fold alpha into B so that every generated row Gemm can use alpha == 1.
  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;

  RankedTensorType cType = nullptr;
  bool cHasNumOutRows = false;
  if (hasC) {
    // Fold beta into C (before any reshaping, while C is still the constant).
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());

    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices> {{0, 1}});
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
    cHasNumOutRows = cType.getDimSize(0) == numOutRows;
  }

  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
  auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());

  // Slice sizes and strides do not depend on the row index.
  const SmallVector<OpFoldResult> aSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
  const SmallVector<OpFoldResult> unitStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};

  SmallVector<Value> gemvOps;
  gemvOps.reserve(numOutRows);
  for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
    const SmallVector<OpFoldResult> rowOffsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
    auto aSlice =
        tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, rowOffsets, aSizes, unitStrides).getResult();

    // Without C the original (none) operand is forwarded unchanged.
    Value cSlice = c;
    if (hasC) {
      if (cHasNumOutRows) {
        // C has one row per output row: take the matching row.
        const SmallVector<OpFoldResult> cSizes = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
        auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
        cSlice =
            tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, rowOffsets, cSizes, unitStrides).getResult();
      }
      else {
        // Otherwise the same C is passed to every row Gemm.
        assert("C should be a vector" && isVectorShape(getTensorShape(c)));
      }
    }

    // alpha and beta were already folded into B and C above.
    auto gemvOp = ONNXGemmOp::create(rewriter,
                                     loc,
                                     outRowType,
                                     aSlice,
                                     b,
                                     cSlice,
                                     rewriter.getF32FloatAttr(1.0f),
                                     rewriter.getF32FloatAttr(1.0f),
                                     gemmOp.getTransAAttr(),
                                     gemmOp.getTransBAttr());
    gemvOps.push_back(gemvOp.getY());
  }

  // Stack the per-row results back into the full output.
  auto concatComputeOp =
      createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
        auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
        spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
      });

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}

LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
                                                    ONNXGemmOpAdaptor gemmOpAdaptor,
                                                    ConversionPatternRewriter& rewriter) const {
  Location gemmLoc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  Value out = gemmOp.getY();

  const float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
  const float beta = gemmOpAdaptor.getBeta().convertToFloat();
  const bool transA = gemmOpAdaptor.getTransA();
  const bool transB = gemmOpAdaptor.getTransB();

  auto aType = cast<RankedTensorType>(a.getType());
  auto bType = cast<RankedTensorType>(b.getType());
  auto outType = cast<RankedTensorType>(out.getType());

  const bool hasC = !isNoneOperand(c);

  // Not a gemv
  if (!isVectorShape(aType.getShape()))
    return failure();

  // Fold alpha into B while B is still the original constant: once a transpose
  // is inserted, B is no longer defined by an arith.constant and cannot be
  // scaled. Scaling is element-wise, so doing it first is equivalent.
  auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;

  RankedTensorType cType = nullptr;
  if (hasC) {
    // Same reasoning for beta and C: scale before the expand_shape below.
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());

    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(
          rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices> {{0, 1}});
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);

    // Not a gemv
    if (!isVectorShape(cType.getShape()))
      return failure();
  }

  assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
         && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());

  if (transA) {
    auto aShape = aType.getShape();
    const int64_t transposedShape[] = {aShape[1], aShape[0]};
    auto transposedType = aType.cloneWith(ArrayRef<int64_t>(transposedShape), aType.getElementType());
    a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
    // Keep the cached type in sync; the slice counts below read its dimensions.
    aType = cast<RankedTensorType>(a.getType());
  }
  if (transB) {
    auto bShape = bType.getShape();
    const int64_t transposedShape[] = {bShape[1], bShape[0]};
    auto transposedType = bType.cloneWith(ArrayRef<int64_t>(transposedShape), bType.getElementType());
    b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
    bType = cast<RankedTensorType>(b.getType());
  }

  // Horizontal slicing of A (its columns) and of B's columns, in units of the
  // crossbar size. The number of vertical slices of B equals the number of
  // horizontal slices of A; the output is sliced like B's columns.
  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
  auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
  auto bNumVSlices = aNumHSlices;
  auto outNumHSlices = bNumHSlices;
  auto outLastHSliceSize = bLastHSliceSize;

  const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());

  // aHSlices is indexed by core id; bTiles by [output slice][core][A slice].
  auto aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
  auto bTiles = tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);

  SmallVector<Value> cHSlices;
  // A 1x1 C is broadcast to the output width before slicing.
  if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
    c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
  if (hasC)
    cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);

  RankedTensorType outHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(crossbarSize)}, outType.getElementType());
  RankedTensorType outLastHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(bLastHSliceSize)}, outType.getElementType());

  SmallVector<Value> outHSlices;
  outHSlices.reserve(outNumHSlices);
  for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
    // The last slice uses the remainder width when that remainder is non-zero.
    RankedTensorType currOutHSliceType = outHSliceType;
    if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
      currOutHSliceType = outLastHSliceType;

    SmallVector<Value> partialResults;
    partialResults.reserve(coresPerVSlice);
    for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
      // One weight tile per A slice handled by this core.
      SmallVector<Value> weights;
      weights.reserve(aHSlices[coreId].size());
      for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
        weights.push_back(bTiles[outSliceId][coreId][aSliceId]);

      auto computeOp = createSpatCompute(
          rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
            SmallVector<Value> vmmOutputs;
            vmmOutputs.reserve(aHSlicesArgs.size());
            for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
              vmmOutputs.push_back(
                  spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
            assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");

            Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
            spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
          });

      partialResults.push_back(computeOp.getResult(0));
    }

    // The bias slice is summed together with the per-core partial results.
    if (hasC) {
      Value cHSlice = cHSlices[outSliceId];
      partialResults.push_back(cHSlice);
    }

    auto reduceComputeOp =
        createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
          SmallVector<Value> values(blockArgs.begin(), blockArgs.end());
          Value outHSlice = sumTensors(values, rewriter);
          spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
        });

    outHSlices.push_back(reduceComputeOp.getResult(0));
  }

  // Join the output slices along the column axis.
  auto concatComputeOp =
      createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
        auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
        spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
      });

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}

/// Registers the Gemm lowering patterns defined in this file.
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<GemmToManyGemv>(ctx);
  patterns.insert<GemvToSpatialCompute>(ctx);
}

} // namespace onnx_mlir