normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-30 16:37:28 +02:00
parent 7c3943bd06
commit ab63498f3f
14 changed files with 340 additions and 278 deletions
@@ -14,6 +14,7 @@
#include <utility>
#include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
@@ -60,8 +61,11 @@ static Value createGemmBatchKOffset(
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
return createOrFoldAffineApply(rewriter,
loc,
(d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(),
ValueRange {lane},
rewriter.getInsertionBlock()->getParentOp());
}
static Value createGemmBatchHOffset(Value lane,
@@ -75,8 +79,11 @@ static Value createGemmBatchHOffset(Value lane,
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
return createOrFoldAffineApply(rewriter,
loc,
d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(),
ValueRange {lane},
rewriter.getInsertionBlock()->getParentOp());
}
static Value
@@ -259,7 +266,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
ValueRange {b},
ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value row =
onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutRows, rewriter.getInsertionBlock()->getParentOp());
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
@@ -297,7 +305,8 @@ createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewri
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
return createOrFoldAffineApply(
rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
}
static Value extractDynamicGemmBColumn(
@@ -429,7 +438,8 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
Value column =
onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -477,7 +487,8 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front();
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols);
Value column =
onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -515,8 +526,11 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) {
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
return createOrFoldAffineApply(rewriter,
loc,
d0 * (numKSlices * numOutRows) + kSlice * numOutRows,
ValueRange {hSlice},
rewriter.getInsertionBlock()->getParentOp());
}
static Value extractReductionPiece(Value partialPiecesArg,
@@ -597,8 +611,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
Value hOffset = onnx_mlir::multiplyIndexByConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
Value hOffset = onnx_mlir::affineMulConst(
rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice =
@@ -7,6 +7,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -276,7 +277,8 @@ static Value extractBatchedBTile(Value b,
static Value getBatchLaneIndex(
Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
return floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices);
return affineFloorDivConst(
rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
}
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
@@ -300,16 +302,15 @@ static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
ValueRange {b},
ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value outerLane = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value outerLane = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
Value sliceLane = modIndexByConstant(rewriter, loc, outerLane, numKSlices * numOutHSlices);
Value kSlice = modIndexByConstant(rewriter, loc, sliceLane, numKSlices);
Value hSlice = floorDivIndexByConstant(rewriter, loc, sliceLane, numKSlices);
Value kOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice, crossbarSize.getValue());
Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
Value sliceLane = affineModConst(rewriter, loc, outerLane, numKSlices * numOutHSlices, anchorOp);
Value kSlice = affineModConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value hSlice = affineFloorDivConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value kOffset = affineMulConst(rewriter, loc, kSlice, crossbarSize.getValue(), anchorOp);
Value hOffset = affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), anchorOp);
auto aTileType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
@@ -445,10 +446,11 @@ static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
ValueRange {},
ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) {
Value batch = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols);
Value batchLane = modIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols);
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols);
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batch = affineFloorDivConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value batchLane = affineModConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -491,10 +493,11 @@ static Value createBatchedDynamicOutputCompute(Value scalarPieces,
Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front();
Value batch = floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols);
Value batchLane = modIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols);
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols);
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar = tensor::ExtractSliceOp::create(
@@ -542,10 +545,9 @@ static Value extractBatchedReductionPiece(Value partialPiecesArg,
int64_t numOutRows,
PatternRewriter& rewriter,
Location loc) {
Value batchOffset = multiplyIndexByConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), batch, numOutRows * numKSlices * numOutHSlices);
Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, numKSlices * numOutRows);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batchOffset = affineMulConst(rewriter, loc, batch, numOutRows * numKSlices * numOutHSlices, anchorOp);
Value hOffset = affineMulConst(rewriter, loc, hSlice, numKSlices * numOutRows, anchorOp);
Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);
Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset);
Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset);
@@ -631,7 +633,7 @@ static Value createBatchedReductionCompute(Value partialPieces,
{2}
});
Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
SmallVector<OpFoldResult> outputSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
@@ -7,6 +7,7 @@
#include <algorithm>
#include <numeric>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -82,7 +83,7 @@ computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternR
expr = expr.floorDiv(stride);
if (dimSize != 1)
expr = expr % dimSize;
return createAffineApplyOrFoldedConstant(rewriter, loc, expr, ValueRange {lane});
return createOrFoldAffineApply(rewriter, loc, expr, ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
}
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,