normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-30 16:37:28 +02:00
parent 7c3943bd06
commit ab63498f3f
14 changed files with 340 additions and 278 deletions
@@ -1,11 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/APInt.h"
#include <algorithm>
#include "IndexingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
using namespace mlir;
@@ -47,46 +42,4 @@ FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> ax
return normalizedAxes;
}
Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
}
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
if (multiplier == 0)
return getOrCreateIndexConstant(rewriter, anchorOp, 0);
if (multiplier == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
}
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
}
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
}
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
return getOrCreateIndexConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
return cast<Value>(value);
}
} // namespace onnx_mlir
@@ -1,11 +1,7 @@
#pragma once
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
@@ -21,21 +17,4 @@ int64_t normalizeIndex(int64_t index, int64_t dimSize);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange operands);
mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter,
mlir::Operation* anchorOp,
mlir::Value value,
int64_t multiplier);
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::OpFoldResult value);
} // namespace onnx_mlir
@@ -11,6 +11,7 @@
#include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -27,8 +28,7 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
}
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(),
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
}
static bool isStaticTensorResult(Operation* op) {
@@ -14,6 +14,7 @@
#include <utility>
#include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
@@ -60,8 +61,11 @@ static Value createGemmBatchKOffset(
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
return createOrFoldAffineApply(rewriter,
loc,
(d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(),
ValueRange {lane},
rewriter.getInsertionBlock()->getParentOp());
}
static Value createGemmBatchHOffset(Value lane,
@@ -75,8 +79,11 @@ static Value createGemmBatchHOffset(Value lane,
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
return createOrFoldAffineApply(rewriter,
loc,
d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(),
ValueRange {lane},
rewriter.getInsertionBlock()->getParentOp());
}
static Value
@@ -259,7 +266,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
ValueRange {b},
ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value row =
onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutRows, rewriter.getInsertionBlock()->getParentOp());
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
@@ -297,7 +305,8 @@ createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewri
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
return createOrFoldAffineApply(
rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
}
static Value extractDynamicGemmBColumn(
@@ -429,7 +438,8 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
Value column =
onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -477,7 +487,8 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front();
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols);
Value column =
onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -515,8 +526,11 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) {
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
return createOrFoldAffineApply(rewriter,
loc,
d0 * (numKSlices * numOutRows) + kSlice * numOutRows,
ValueRange {hSlice},
rewriter.getInsertionBlock()->getParentOp());
}
static Value extractReductionPiece(Value partialPiecesArg,
@@ -597,8 +611,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
Value hOffset = onnx_mlir::multiplyIndexByConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
Value hOffset = onnx_mlir::affineMulConst(
rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice =
@@ -7,6 +7,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -276,7 +277,8 @@ static Value extractBatchedBTile(Value b,
static Value getBatchLaneIndex(
Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
return floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices);
return affineFloorDivConst(
rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
}
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
@@ -300,16 +302,15 @@ static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
ValueRange {b},
ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value outerLane = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value outerLane = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
Value sliceLane = modIndexByConstant(rewriter, loc, outerLane, numKSlices * numOutHSlices);
Value kSlice = modIndexByConstant(rewriter, loc, sliceLane, numKSlices);
Value hSlice = floorDivIndexByConstant(rewriter, loc, sliceLane, numKSlices);
Value kOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice, crossbarSize.getValue());
Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
Value sliceLane = affineModConst(rewriter, loc, outerLane, numKSlices * numOutHSlices, anchorOp);
Value kSlice = affineModConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value hSlice = affineFloorDivConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value kOffset = affineMulConst(rewriter, loc, kSlice, crossbarSize.getValue(), anchorOp);
Value hOffset = affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), anchorOp);
auto aTileType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
@@ -445,10 +446,11 @@ static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
ValueRange {},
ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) {
Value batch = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols);
Value batchLane = modIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols);
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols);
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batch = affineFloorDivConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value batchLane = affineModConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -491,10 +493,11 @@ static Value createBatchedDynamicOutputCompute(Value scalarPieces,
Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front();
Value batch = floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols);
Value batchLane = modIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols);
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols);
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar = tensor::ExtractSliceOp::create(
@@ -542,10 +545,9 @@ static Value extractBatchedReductionPiece(Value partialPiecesArg,
int64_t numOutRows,
PatternRewriter& rewriter,
Location loc) {
Value batchOffset = multiplyIndexByConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), batch, numOutRows * numKSlices * numOutHSlices);
Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, numKSlices * numOutRows);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batchOffset = affineMulConst(rewriter, loc, batch, numOutRows * numKSlices * numOutHSlices, anchorOp);
Value hOffset = affineMulConst(rewriter, loc, hSlice, numKSlices * numOutRows, anchorOp);
Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);
Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset);
Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset);
@@ -631,7 +633,7 @@ static Value createBatchedReductionCompute(Value partialPieces,
{2}
});
Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
SmallVector<OpFoldResult> outputSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
@@ -7,6 +7,7 @@
#include <algorithm>
#include <numeric>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -82,7 +83,7 @@ computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternR
expr = expr.floorDiv(stride);
if (dimSize != 1)
expr = expr % dimSize;
return createAffineApplyOrFoldedConstant(rewriter, loc, expr, ValueRange {lane});
return createOrFoldAffineApply(rewriter, loc, expr, ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
}
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,