less affine code and better affine helpers
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-06-29 14:34:31 +02:00
parent f492400eda
commit 4a98e88e97
15 changed files with 173 additions and 142 deletions
@@ -5,5 +5,6 @@
#include "MatrixProductLowering.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -77,65 +77,4 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewri
return slicesPerCore;
}
Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<int64_t> resultShape(sourceType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
Location loc,
Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
auto sourceType = cast<RankedTensorType>(source.getType());
size_t rank = static_cast<size_t>(sourceType.getRank());
bool isIdentitySlice =
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
&& strides.size() == rank;
if (isIdentitySlice) {
ArrayRef<int64_t> sourceShape = sourceType.getShape();
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
isIdentitySlice = false;
break;
}
}
}
if (isIdentitySlice)
return source;
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
}
Value insertStaticSlice(
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto sourceType = cast<RankedTensorType>(source.getType());
return tensor::InsertSliceOp::create(rewriter,
loc,
source,
dest,
offsets,
getStaticSizes(rewriter, sourceType.getShape()),
getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
} // namespace onnx_mlir
@@ -28,21 +28,4 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets,
llvm::ArrayRef<mlir::OpFoldResult> sizes,
llvm::ArrayRef<mlir::OpFoldResult> strides);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
} // namespace onnx_mlir
@@ -1184,48 +1184,6 @@ static Value createZeroPaddedTensor(Value value,
return padOp.getResult();
}
static Value affineAddConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) {
if (offset == 0)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
}
static Value affineMulConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
if (factor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
}
static Value affineFloorDivConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(divisor > 0 && "expected positive affine floordiv divisor");
if (divisor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
}
static Value affineModConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
assert(modulus > 0 && "expected positive affine mod divisor");
if (modulus == 1)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
}
static Value createConvInputPatch(Value input,
RankedTensorType patchType,
Value batchIndex,
@@ -2316,11 +2274,10 @@ static Value createIm2colRows(const ConvLoweringState& state,
ValueRange {im2colInit},
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value im2colAcc = iterArgs.front();
Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
Value batchIndex =
affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
affineAddFloorDivConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
Value batchPatchIndex =
affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
affineAddModConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
Value inputHeightOffset =