Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -0,0 +1,23 @@
#include "AttributeUtils.hpp"
#include "mlir/IR/BuiltinAttributes.h"
using namespace mlir;
namespace onnx_mlir {
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
return attr ? getI64Attr(*attr, index) : defaultValue;
}
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
llvm::SmallVector<int64_t> values;
values.reserve(attr.size());
for (Attribute value : attr)
values.push_back(cast<IntegerAttr>(value).getInt());
return values;
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
namespace onnx_mlir {
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
} // namespace onnx_mlir
@@ -1,6 +1,8 @@
#pragma once
#include "AttributeUtils.hpp"
#include "ComputeRegionBuilder.hpp"
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -7,9 +7,13 @@
#include <cassert>
#include <cstddef>
#include <limits>
#include <type_traits>
#include <utility>
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
@@ -49,6 +53,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
struct SpatComputeBatchBodyArgs {
mlir::Value lane;
mlir::ValueRange weights;
mlir::ValueRange inputs;
mlir::ValueRange outputs;
};
} // namespace detail
template <typename RewriterT>
@@ -159,6 +170,96 @@ auto createSpatCompute(RewriterT& rewriter,
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
for (mlir::Value weight : weights) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(weight.getLoc());
}
for (mlir::Value input : inputs) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(input.getLoc());
}
for (mlir::Type resultType : resultTypes) {
blockArgTypes.push_back(resultType);
blockArgLocs.push_back(loc);
}
auto* block =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(block);
detail::SpatComputeBatchBodyArgs args {
block->getArgument(0),
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())
};
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
mlir::ArrayRef<mlir::OpFoldResult> offsets,
mlir::ArrayRef<mlir::OpFoldResult> sizes,
mlir::ArrayRef<mlir::OpFoldResult> strides) {
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
}
template <typename BodyFn>
mlir::Value materializeOrComputeUnary(mlir::Value input,
mlir::RankedTensorType resultType,
mlir::PatternRewriter& rewriter,
mlir::Location loc,
BodyFn&& build) {
auto&& buildFn = build;
if (isCompileTimeComputable(input))
return buildFn(input);
auto computeOp =
createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
mlir::Value result = buildFn(computeInput);
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -0,0 +1,104 @@
#include "IndexingUtils.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/APInt.h"
#include <algorithm>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
int64_t normalizedAxis = normalizeAxis(axis, rank);
if (normalizedAxis < 0 || normalizedAxis >= rank)
return failure();
return normalizedAxis;
}
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; ++axis)
normalizedAxes.push_back(axis);
}
else {
normalizedAxes.reserve(axesAttr->size());
for (Attribute attr : *axesAttr)
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
}
return normalizedAxes;
}
SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
return normalizeAxesImpl(std::optional<ArrayAttr>(axesAttr), rank);
}
SmallVector<int64_t> normalizeAxes(std::optional<ArrayAttr> axesAttr, int64_t rank) {
return normalizeAxesImpl(axesAttr, rank);
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
for (int64_t axis : normalizedAxes)
if (axis < 0 || axis >= rank)
return failure();
return normalizedAxes;
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) {
return normalizeAxesChecked(std::optional<ArrayAttr>(axesAttr), rank);
}
Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
}
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
if (multiplier == 0)
return getOrCreateHostIndexConstant(rewriter, anchorOp, 0);
if (multiplier == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
}
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value});
}
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
}
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
return cast<Value>(value);
}
} // namespace onnx_mlir
@@ -0,0 +1,45 @@
#pragma once
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank);
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
int64_t normalizeIndex(int64_t index, int64_t dimSize);
llvm::SmallVector<int64_t> normalizeAxes(mlir::ArrayAttr axesAttr, int64_t rank);
llvm::SmallVector<int64_t> normalizeAxes(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(mlir::ArrayAttr axesAttr, int64_t rank);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::Value createAffineApplyOrConstant(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange operands);
mlir::Value
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier);
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::OpFoldResult value);
} // namespace onnx_mlir
@@ -6,20 +6,21 @@
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <functional>
#include "ShapeTilingUtils.hpp"
#include "IndexingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
if (auto attr = dyn_cast<Attribute>(result))
return arith::ConstantIndexOp::create(rewriter, loc, cast<IntegerAttr>(attr).getInt()).getResult();
return cast<Value>(result);
return getOrMaterializeIndexValue(rewriter, loc, result);
}
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
@@ -50,6 +51,84 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
}
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); }
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); }
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return failure();
permutation.reserve(permAttr->size());
SmallVector<bool> seen(rank, false);
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
Value transposeMaybeInCompute(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
PatternRewriter& rewriter,
Location loc) {
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
};
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
@@ -88,11 +167,8 @@ SmallVector<Value> sliceTensor(
assert("Invalid axis" && axis < shape.size());
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (const auto size : shape)
sizes.push_back(rewriter.getIndexAttr(size));
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
sizes[axis] = rewriter.getIndexAttr(sliceSize);
long length = shape[axis];
@@ -276,4 +352,43 @@ Value materializeContiguousTensorSlice(Value source,
return buildLoopNest(buildLoopNest, 0, init);
}
Value extractStaticSlice(PatternRewriter& rewriter,
Location loc,
Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets) {
return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()),
getUnitStrides(rewriter, resultType.getRank()))
.getResult();
}
Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<int64_t> resultShape(sourceType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
Value insertStaticSlice(
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto sourceType = cast<RankedTensorType>(source.getType());
return tensor::InsertSliceOp::create(rewriter,
loc,
source,
dest,
offsets,
getStaticSizes(rewriter, sourceType.getShape()),
getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
} // namespace onnx_mlir
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
@@ -11,6 +12,7 @@
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
@@ -109,6 +111,33 @@ inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
&& lhsType.getShape() == rhsType.getShape();
}
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
int64_t getStaticShapeElementCount(mlir::RankedTensorType type);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
mlir::Value transposeMaybeInCompute(mlir::Value value,
mlir::RankedTensorType resultType,
mlir::ArrayRef<int64_t> permutation,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
@@ -148,4 +177,23 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value extractStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
int64_t axis,
int64_t offset,
int64_t size);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
} // namespace onnx_mlir