finish helper refactoring
Validate Operations / validate-operations (push) Has been cancelled

use uniqued constant helpers everywhere
materialize transposed constants directly
This commit is contained in:
NiccoloN
2026-05-29 17:05:45 +02:00
parent 819d8af0f7
commit 8bb0babf1b
32 changed files with 300 additions and 467 deletions
@@ -10,6 +10,7 @@
#include "ShapeTilingUtils.hpp"
#include "IndexingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
@@ -19,10 +20,6 @@ using namespace mlir;
namespace onnx_mlir {
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
return getOrMaterializeIndexValue(rewriter, loc, result);
}
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
APInt lhsConst;
if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero())
@@ -43,11 +40,12 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt
return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult();
if (factorConst.isZero())
return arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
if (factorConst.isOne())
return value;
auto factorValue = arith::ConstantIndexOp::create(rewriter, loc, factorConst.getSExtValue()).getResult();
auto factorValue =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), factorConst.getSExtValue());
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
}
@@ -61,8 +59,6 @@ int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); }
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
@@ -226,49 +222,6 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
return slicesPerCore;
}
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
size_t numHSlices = hSlices.size();
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
Value hSlice = hSlices[hSliceId];
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
size_t coreId = vSliceId / crossbarCountInCore;
Value vSlice = vSlices[vSliceId];
tiles[hSliceId][coreId].push_back(vSlice);
}
}
return tiles;
}
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto buildBroadcast = [&](Value input) -> Value {
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
if (isCompileTimeComputable(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
});
return broadcastCompute.getResult(0);
}
Value materializeContiguousTensorSlice(Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets,
@@ -294,7 +247,7 @@ Value materializeContiguousTensorSlice(Value source,
Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
SmallVector<Value> zeroIndices(resultType.getRank());
for (Value& zeroIndex : zeroIndices)
zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
zeroIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
SmallVector<Value> resultIndices;
resultIndices.reserve(resultType.getRank());
@@ -304,7 +257,7 @@ Value materializeContiguousTensorSlice(Value source,
SmallVector<Value> sourceIndices;
sourceIndices.reserve(resultType.getRank());
for (unsigned idx = 0; idx < resultType.getRank(); ++idx) {
Value offsetValue = getIndexValue(offsets[idx], rewriter, loc);
Value offsetValue = getOrMaterializeIndexValue(rewriter, offsets[idx]);
Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc);
sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc));
}
@@ -337,8 +290,8 @@ Value materializeContiguousTensorSlice(Value source,
}
Value lower = zeroIndices[dim];
Value upper = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult();
Value step = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult();
Value upper = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody());
resultIndices.push_back(loop.getInductionVar());
@@ -352,17 +305,6 @@ Value materializeContiguousTensorSlice(Value source,
return buildLoopNest(buildLoopNest, 0, init);
}
Value extractStaticSlice(PatternRewriter& rewriter,
Location loc,
Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets) {
return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()),
getUnitStrides(rewriter, resultType.getRank()))
.getResult();
}
Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType());