simplify affine maps to constants where possible
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-27 16:39:27 +02:00
parent 1a5d7d2a3f
commit 4bdaa57656
2 changed files with 65 additions and 5 deletions
@@ -2,6 +2,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
@@ -72,9 +73,37 @@ static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t va
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
}
/// Tries to read `value` as a compile-time integer.
///
/// The value is first checked for being produced by an `arith::ConstantIndexOp`;
/// otherwise the generic `m_ConstantInt` matcher is tried. Returns the
/// sign-extended integer on success and `std::nullopt` when neither check
/// matches.
static std::optional<int64_t> getConstantIndexValue(Value value) {
  // Direct check on the defining op before falling back to the matcher.
  auto indexOp = value.getDefiningOp<arith::ConstantIndexOp>();
  if (indexOp)
    return indexOp.value();

  APInt matchedInt;
  if (!matchPattern(value, m_ConstantInt(&matchedInt)))
    return std::nullopt;
  return matchedInt.getSExtValue();
}
/// Materializes `expr` applied to `operands` (one affine dimension per operand,
/// no symbols).
///
/// When every operand resolves to a constant through `getConstantIndexValue`
/// and the resulting map constant-folds to an integer attribute, the result is
/// emitted through `createIndexConstant` instead of an `affine.apply` op. In
/// every other case an `affine::AffineApplyOp` is created at `loc` and its
/// result is returned.
static Value
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
  AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);

  // Shared fallback: build the affine.apply op when folding is not possible.
  auto emitApply = [&]() -> Value { return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); };

  // Collect one index attribute per operand; give up on the first non-constant.
  SmallVector<Attribute> constantAttrs;
  constantAttrs.reserve(operands.size());
  for (Value operand : operands) {
    std::optional<int64_t> known = getConstantIndexValue(operand);
    if (!known)
      return emitApply();
    constantAttrs.push_back(rewriter.getIndexAttr(*known));
  }

  SmallVector<Attribute> folded;
  if (!succeeded(map.constantFold(constantAttrs, folded)))
    return emitApply();

  // The map has a single result, so only the first folded attribute is relevant.
  if (auto foldedInt = dyn_cast<IntegerAttr>(folded.front()))
    return createIndexConstant(rewriter, foldedInt.getInt());
  return emitApply();
}
@@ -115,7 +144,15 @@ static Value createGemmBatchKOffset(
}
static Value createGemmBatchHOffset(
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
Value lane,
int64_t numOutRows,
int64_t numKSlices,
int64_t numOutHSlices,
ConversionPatternRewriter& rewriter,
Location loc) {
if (numOutHSlices == 1)
return createIndexConstant(rewriter, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(
@@ -290,6 +327,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
RankedTensorType partialPiecesType,
int64_t numOutRows,
int64_t numKSlices,
int64_t numOutHSlices,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t laneCount = partialPiecesType.getDimSize(0);
@@ -314,7 +352,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
auto bTileType = RankedTensorType::get(
@@ -610,7 +648,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
auto partialPiecesType =
RankedTensorType::get({laneCount64, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
auto batchOp = createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, rewriter, loc);
auto batchOp =
createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
auto reductionCompute = createReductionCompute(
batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);