Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
+11 -11
View File
@@ -28,7 +28,7 @@ Block* getHostConstantBlock(Operation* anchorOp) {
return anchorOp->getBlock(); return anchorOp->getBlock();
} }
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) { Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation"); assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp); Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) { for (Operation& op : *hostBlock) {
@@ -42,7 +42,7 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type); return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
} }
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) { Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation"); assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp); Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) { for (Operation& op : *hostBlock) {
@@ -57,28 +57,28 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, R
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult(); return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
} }
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) { Value getOrCreateHostConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder); return getOrCreateHostConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
} }
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) { Value getOrCreateHostIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext()); Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder); return getOrCreateHostConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType() );
} }
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) { Value getOrCreateHostIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext()); Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter); return getOrCreateHostConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
} }
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) { Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext()); Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder); return getOrCreateHostConstant(folder, anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type() );
} }
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) { Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext()); Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder); return getOrCreateHostConstant(folder, anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type() );
} }
Value createAffineApplyOrFoldedConstant( Value createAffineApplyOrFoldedConstant(
@@ -95,7 +95,7 @@ Value createAffineApplyOrFoldedConstant(
SmallVector<Attribute> foldedResults; SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) { if (succeeded(map.constantFold(operandConstants, foldedResults))) {
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front())) if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateHostIndexConstant(anchorOp, constantResult.getInt(), rewriter); return getOrCreateHostIndexConstant(rewriter, anchorOp, constantResult.getInt());
} }
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
+11 -11
View File
@@ -10,25 +10,25 @@ namespace onnx_mlir {
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp); mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp, mlir::Value getOrCreateHostConstant(mlir::OperationFolder& folder,
mlir::Operation* anchorOp,
mlir::Attribute value, mlir::Attribute value,
mlir::Type type, mlir::Type type);
mlir::OperationFolder& folder);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp, mlir::Value getOrCreateHostConstant(mlir::RewriterBase& rewriter,
mlir::Operation* anchorOp,
mlir::Attribute value, mlir::Attribute value,
mlir::Type type, mlir::Type type);
mlir::RewriterBase& rewriter);
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder); mlir::Value getOrCreateHostConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); mlir::Value getOrCreateHostIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter); mlir::Value getOrCreateHostIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder); mlir::Value getOrCreateHostI32Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int32_t value);
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); mlir::Value getOrCreateHostI64Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter, mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter,
mlir::Location loc, mlir::Location loc,
@@ -25,7 +25,9 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Split.cpp Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp
Common/ShapeTilingUtils.cpp Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp Common/WeightMaterialization.cpp
@@ -0,0 +1,23 @@
#include "AttributeUtils.hpp"
#include "mlir/IR/BuiltinAttributes.h"
using namespace mlir;
namespace onnx_mlir {
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
return attr ? getI64Attr(*attr, index) : defaultValue;
}
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
llvm::SmallVector<int64_t> values;
values.reserve(attr.size());
for (Attribute value : attr)
values.push_back(cast<IntegerAttr>(value).getInt());
return values;
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
namespace onnx_mlir {
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
} // namespace onnx_mlir
@@ -1,6 +1,8 @@
#pragma once #pragma once
#include "AttributeUtils.hpp"
#include "ComputeRegionBuilder.hpp" #include "ComputeRegionBuilder.hpp"
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp" #include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -7,9 +7,13 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <limits>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -49,6 +53,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::
template <typename Fn> template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>; using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
struct SpatComputeBatchBodyArgs {
mlir::Value lane;
mlir::ValueRange weights;
mlir::ValueRange inputs;
mlir::ValueRange outputs;
};
} // namespace detail } // namespace detail
template <typename RewriterT> template <typename RewriterT>
@@ -159,6 +170,96 @@ auto createSpatCompute(RewriterT& rewriter,
} }
} }
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
for (mlir::Value weight : weights) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(weight.getLoc());
}
for (mlir::Value input : inputs) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(input.getLoc());
}
for (mlir::Type resultType : resultTypes) {
blockArgTypes.push_back(resultType);
blockArgLocs.push_back(loc);
}
auto* block =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(block);
detail::SpatComputeBatchBodyArgs args {
block->getArgument(0),
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())
};
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
mlir::ArrayRef<mlir::OpFoldResult> offsets,
mlir::ArrayRef<mlir::OpFoldResult> sizes,
mlir::ArrayRef<mlir::OpFoldResult> strides) {
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
}
template <typename BodyFn>
mlir::Value materializeOrComputeUnary(mlir::Value input,
mlir::RankedTensorType resultType,
mlir::PatternRewriter& rewriter,
mlir::Location loc,
BodyFn&& build) {
auto&& buildFn = build;
if (isCompileTimeComputable(input))
return buildFn(input);
auto computeOp =
createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
mlir::Value result = buildFn(computeInput);
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter); mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,104 @@
#include "IndexingUtils.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/APInt.h"
#include <algorithm>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
int64_t normalizedAxis = normalizeAxis(axis, rank);
if (normalizedAxis < 0 || normalizedAxis >= rank)
return failure();
return normalizedAxis;
}
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; ++axis)
normalizedAxes.push_back(axis);
}
else {
normalizedAxes.reserve(axesAttr->size());
for (Attribute attr : *axesAttr)
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
}
return normalizedAxes;
}
SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
return normalizeAxesImpl(std::optional<ArrayAttr>(axesAttr), rank);
}
SmallVector<int64_t> normalizeAxes(std::optional<ArrayAttr> axesAttr, int64_t rank) {
return normalizeAxesImpl(axesAttr, rank);
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
for (int64_t axis : normalizedAxes)
if (axis < 0 || axis >= rank)
return failure();
return normalizedAxes;
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) {
return normalizeAxesChecked(std::optional<ArrayAttr>(axesAttr), rank);
}
Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
}
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
if (multiplier == 0)
return getOrCreateHostIndexConstant(rewriter, anchorOp, 0);
if (multiplier == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
}
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value});
}
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
}
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
return cast<Value>(value);
}
} // namespace onnx_mlir
@@ -0,0 +1,45 @@
#pragma once
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank);
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
int64_t normalizeIndex(int64_t index, int64_t dimSize);
llvm::SmallVector<int64_t> normalizeAxes(mlir::ArrayAttr axesAttr, int64_t rank);
llvm::SmallVector<int64_t> normalizeAxes(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(mlir::ArrayAttr axesAttr, int64_t rank);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::Value createAffineApplyOrConstant(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange operands);
mlir::Value
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier);
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::OpFoldResult value);
} // namespace onnx_mlir
@@ -6,20 +6,21 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <functional>
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "IndexingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) { static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
if (auto attr = dyn_cast<Attribute>(result)) return getOrMaterializeIndexValue(rewriter, loc, result);
return arith::ConstantIndexOp::create(rewriter, loc, cast<IntegerAttr>(attr).getInt()).getResult();
return cast<Value>(result);
} }
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) { static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
@@ -50,6 +51,84 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult(); return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
} }
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); }
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); }
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return failure();
permutation.reserve(permAttr->size());
SmallVector<bool> seen(rank, false);
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
Value transposeMaybeInCompute(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
PatternRewriter& rewriter,
Location loc) {
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
};
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) { static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
auto sourceType = dyn_cast<RankedTensorType>(source.getType()); auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank()) if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
@@ -88,11 +167,8 @@ SmallVector<Value> sliceTensor(
assert("Invalid axis" && axis < shape.size()); assert("Invalid axis" && axis < shape.size());
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
sizes.reserve(shape.size());
for (const auto size : shape)
sizes.push_back(rewriter.getIndexAttr(size));
sizes[axis] = rewriter.getIndexAttr(sliceSize); sizes[axis] = rewriter.getIndexAttr(sliceSize);
long length = shape[axis]; long length = shape[axis];
@@ -276,4 +352,43 @@ Value materializeContiguousTensorSlice(Value source,
return buildLoopNest(buildLoopNest, 0, init); return buildLoopNest(buildLoopNest, 0, init);
} }
Value extractStaticSlice(PatternRewriter& rewriter,
Location loc,
Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets) {
return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()),
getUnitStrides(rewriter, resultType.getRank()))
.getResult();
}
Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<int64_t> resultShape(sourceType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
Value insertStaticSlice(
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto sourceType = cast<RankedTensorType>(source.getType());
return tensor::InsertSliceOp::create(rewriter,
loc,
source,
dest,
offsets,
getStaticSizes(rewriter, sourceType.getShape()),
getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
@@ -11,6 +12,7 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <optional>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
@@ -109,6 +111,33 @@ inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
&& lhsType.getShape() == rhsType.getShape(); && lhsType.getShape() == rhsType.getShape();
} }
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
int64_t getStaticShapeElementCount(mlir::RankedTensorType type);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
mlir::Value transposeMaybeInCompute(mlir::Value value,
mlir::RankedTensorType resultType,
mlir::ArrayRef<int64_t> permutation,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
/// Slices a statically shaped tensor along one axis into contiguous pieces of /// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements. /// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice, llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
@@ -148,4 +177,23 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
mlir::ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc); mlir::Location loc);
mlir::Value extractStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
int64_t axis,
int64_t offset,
int64_t size);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override; ConversionPatternRewriter& rewriter) const override;
}; };
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) { static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
auto biasType = cast<RankedTensorType>(bias.getType()); auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1) if (biasType.getRank() != 1)
@@ -615,10 +613,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return failure(); return failure();
} }
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1; const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
int64_t padHeightBegin = 0; int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0; int64_t padHeightEnd = 0;
@@ -626,10 +624,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
int64_t padWidthEnd = 0; int64_t padWidthEnd = 0;
if (padsAttr) { if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0); padHeightBegin = getI64Attr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1); padWidthBegin = getI64Attr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2); padHeightEnd = getI64Attr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3); padWidthEnd = getI64Attr(*padsAttr, 3);
} }
else { else {
// Compute padding from auto_pad attribute // Compute padding from auto_pad attribute
@@ -13,7 +13,7 @@
#include <limits> #include <limits>
#include <utility> #include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
@@ -58,47 +58,16 @@ static Value transposeForSpatial(Value value,
ArrayRef<int64_t> permutation, ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
if (isCompileTimeComputable(value)) return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return computeOp.getResult(0);
}
static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
}
static Value
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
} }
static Value static Value
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) { multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
if (multiplier == 0) return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
return createIndexConstant(rewriter, 0);
if (multiplier == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value});
} }
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) { static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
if (divisor == 1) return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
return createIndexConstant(rewriter, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value});
} }
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) { static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
@@ -108,11 +77,11 @@ static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatter
static Value createGemmBatchKOffset( static Value createGemmBatchKOffset(
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) { Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
if (numKSlices == 1) if (numKSlices == 1)
return createIndexConstant(rewriter, 0); return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply( return createAffineApplyOrConstant(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
} }
@@ -123,11 +92,11 @@ static Value createGemmBatchHOffset(Value lane,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
if (numOutHSlices == 1) if (numOutHSlices == 1)
return createIndexConstant(rewriter, 0); return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply( return createAffineApplyOrConstant(
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane}); rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
} }
@@ -303,53 +272,37 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
const int64_t laneCount = partialPiecesType.getDimSize(0); const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = spatial::SpatComputeBatch::create(rewriter, auto batchOp = createSpatComputeBatch(
loc, rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
TypeRange {partialPiecesType}, Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
ValueRange {b}, Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
ValueRange {a});
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType}; auto aTileType =
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc); RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
Block* body = auto bTileType = RankedTensorType::get(
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); {static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
rewriter.setInsertionPointToEnd(body); paddedBType.getElementType());
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);
auto lane = batchOp.getLaneArgument(); SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
auto weight = batchOp.getWeightArgument(0); SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
auto input = batchOp.getInputArgument(0); rewriter.getIndexAttr(crossbarSize.getValue())};
auto output = batchOp.getOutputArgument(0); SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
assert(lane && weight && input && output && "malformed Gemm compute_batch body"); Value bTile =
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
.getResult();
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc); SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc); SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); createParallelInsertSliceIntoBatchOutput(
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType()); });
auto bTileType = RankedTensorType::get( assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed");
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())}, return *batchOp;
paddedBType.getElementType());
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc);
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
rewriter.getIndexAttr(crossbarSize.getValue())};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value bTile =
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult();
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> pieceOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
return batchOp;
} }
static Value createDynamicGemmBatchRow( static Value createDynamicGemmBatchRow(
@@ -359,7 +312,7 @@ static Value createDynamicGemmBatchRow(
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
} }
static Value createDynamicGemmBatchColumn( static Value createDynamicGemmBatchColumn(
@@ -479,45 +432,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
const int64_t numOutCols = outType.getDimSize(1); const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1); const int64_t reductionSize = aType.getDimSize(1);
const int64_t laneCount = numOutRows * numOutCols; const int64_t laneCount = numOutRows * numOutCols;
auto batchOp = spatial::SpatComputeBatch::create(rewriter, auto batchOp = createSpatComputeBatch(
loc, rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
TypeRange {scalarPiecesType}, Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
ValueRange {},
ValueRange {a, b});
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType}; auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Block* body = Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); Value bVector = bAlreadyTransposed
rewriter.setInsertionPointToEnd(body); ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
auto lane = batchOp.getLaneArgument(); SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
auto inputA = batchOp.getInputArgument(0); SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto inputB = batchOp.getInputArgument(1); SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
auto output = batchOp.getOutputArgument(0); createParallelInsertSliceIntoBatchOutput(
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body"); rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
});
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc); assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed");
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc); return *batchOp;
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
return batchOp;
} }
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
@@ -540,9 +475,9 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value biasArg = bias ? blockArgs[1] : Value(); Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = createIndexConstant(rewriter, 0); Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = createIndexConstant(rewriter, 1); Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = createIndexConstant(rewriter, laneCount); Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
@@ -587,7 +522,8 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) { Location loc) {
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); return createAffineApplyOrConstant(
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
} }
static Value extractReductionPiece(Value partialPiecesArg, static Value extractReductionPiece(Value partialPiecesArg,
@@ -684,13 +620,13 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
Value paddedOutput = outputInit; Value paddedOutput = outputInit;
if (numOutHSlices == 1) { if (numOutHSlices == 1) {
Value hSlice = createIndexConstant(rewriter, 0); Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
paddedOutput = buildOutputSlice(outputInit, hSlice); paddedOutput = buildOutputSlice(outputInit, hSlice);
} }
else { else {
Value c0 = createIndexConstant(rewriter, 0); Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = createIndexConstant(rewriter, 1); Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices); Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit}); auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(hLoop.getBody()); rewriter.setInsertionPointToStart(hLoop.getBody());
@@ -19,14 +19,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape, static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) { ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty()) if (lhsBatchShape.empty())
@@ -54,15 +46,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
auto buildCollapsed = [&](Value input) -> Value { auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation); return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
}; };
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
if (isCompileTimeComputable(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
} }
static Value static Value
@@ -76,12 +60,10 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
for (size_t dim = 0; dim < batchRank; ++dim) for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim)); reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute = auto buildExpanded = [&](Value input) -> Value {
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) { return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation); };
spatial::SpatYieldOp::create(rewriter, loc, expanded); return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
});
return expandCompute.getResult(0);
} }
static Value extractBatchMatrix(Value value, static Value extractBatchMatrix(Value value,
@@ -100,7 +82,7 @@ static Value extractBatchMatrix(Value value,
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = { SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType()); auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
auto buildMatrix = [&](Value input) -> Value { auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides); Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
@@ -114,14 +96,7 @@ static Value extractBatchMatrix(Value value,
}); });
}; };
if (isCompileTimeComputable(value)) return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
return buildMatrix(value);
auto batchMatrixCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
});
return batchMatrixCompute.getResult(0);
} }
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
@@ -138,18 +113,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
perm = {0, 2, 1}; perm = {0, 2, 1};
} }
auto buildTranspose = [&](Value input) -> Value { return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isCompileTimeComputable(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
} }
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) { static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
@@ -166,10 +130,11 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
perm = {0, 2, 1}; perm = {0, 2, 1};
} }
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { auto transposeCompute =
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, transposed); Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
}); spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0); return transposeCompute.getResult(0);
} }
@@ -203,8 +168,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure(); return failure();
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2) if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
return failure(); return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape()) if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
|| !haveStaticPositiveShape(outType.getShape()))
return failure(); return failure();
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2); SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
@@ -1,9 +1,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <numeric>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
@@ -16,26 +18,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; axis++)
normalizedAxes.push_back(axis);
return normalizedAxes;
}
normalizedAxes.reserve(axesAttr.size());
for (Attribute attr : axesAttr) {
int64_t axis = cast<IntegerAttr>(attr).getInt();
normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
}
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
return normalizedAxes;
}
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) { static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
SmallVector<bool> reducedAxes(rank, false); SmallVector<bool> reducedAxes(rank, false);
for (int64_t axis : axes) { for (int64_t axis : axes) {
@@ -50,6 +32,181 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType); return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
} }
static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? 1 : dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
if (!isReduced)
shape.push_back(dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? dim : 1);
return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding());
}
static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) {
SmallVector<int64_t> shape(leafType.getShape().begin(), leafType.getShape().end());
shape.front() = laneCount;
return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding());
}
static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> keptAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes))
if (!isReduced)
keptAxes.push_back(static_cast<int64_t>(axis));
return keptAxes;
}
static Value computeLaneIndex(Value lane,
int64_t stride,
int64_t dimSize,
ConversionPatternRewriter& rewriter,
Location loc) {
if (dimSize == 1)
return arith::ConstantIndexOp::create(rewriter, loc, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineExpr expr = d0;
if (stride != 1)
expr = expr.floorDiv(stride);
if (dimSize != 1)
expr = expr % dimSize;
return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane});
}
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
ArrayRef<bool> reducedAxes,
RankedTensorType batchType,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
auto sliceType = getReducedSliceType(inputType, reducedAxes);
SmallVector<int64_t> keptAxes = getKeptAxes(reducedAxes);
int64_t laneCount = 1;
SmallVector<int64_t> keptAxisStrides(keptAxes.size(), 1);
for (int64_t index = static_cast<int64_t>(keptAxes.size()) - 1; index >= 0; --index) {
keptAxisStrides[index] = laneCount;
int64_t dimSize = inputType.getDimSize(keptAxes[index]);
if (dimSize <= 0)
return failure();
if (laneCount > std::numeric_limits<int32_t>::max() / dimSize)
return failure();
laneCount *= dimSize;
}
SmallVector<OpFoldResult> sliceOffsets;
SmallVector<OpFoldResult> sliceSizes;
SmallVector<OpFoldResult> insertOffsets;
SmallVector<OpFoldResult> insertSizes(inputType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, inputType.getRank());
sliceOffsets.reserve(inputType.getRank());
sliceSizes.reserve(inputType.getRank());
insertOffsets.reserve(inputType.getRank());
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) {
size_t keptAxisIndex = 0;
sliceOffsets.clear();
sliceSizes.clear();
insertOffsets.clear();
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
sliceOffsets.push_back(rewriter.getIndexAttr(0));
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
continue;
}
Value axisIndex =
computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
++keptAxisIndex;
sliceOffsets.push_back(axisIndex);
sliceSizes.push_back(rewriter.getIndexAttr(1));
}
insertOffsets.push_back(args.lane);
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
Value slice =
tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
});
if (failed(batchOp))
return failure();
return (*batchOp).getResult(0);
}
static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
RankedTensorType keepdimsType,
RankedTensorType compactKeptType,
ArrayRef<bool> reducedAxes,
ConversionPatternRewriter& rewriter,
Location loc) {
auto batchType = cast<RankedTensorType>(batchValue.getType());
if (batchType == keepdimsType)
return batchValue;
SmallVector<ReassociationIndices> collapseToFlat {{}};
for (int64_t axis = 0; axis < batchType.getRank(); ++axis)
collapseToFlat.front().push_back(axis);
SmallVector<ReassociationIndices> expandFlatToCompact(1);
for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis)
expandFlatToCompact.front().push_back(axis);
SmallVector<ReassociationIndices> expandCompactToKeepdims;
ReassociationIndices pendingLeadingReducedAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
if (expandCompactToKeepdims.empty())
pendingLeadingReducedAxes.push_back(axis);
else
expandCompactToKeepdims.back().push_back(axis);
continue;
}
expandCompactToKeepdims.emplace_back();
auto& group = expandCompactToKeepdims.back();
group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
pendingLeadingReducedAxes.clear();
group.push_back(axis);
}
if (!pendingLeadingReducedAxes.empty())
expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
auto reshapeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
Value compact = flat;
if (compactKeptType != flatType)
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
Value keepdims = compact;
if (keepdimsType != compactKeptType)
keepdims =
tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
});
return reshapeCompute.getResult(0);
}
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) { static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
SmallVector<ReassociationIndices> reassociation; SmallVector<ReassociationIndices> reassociation;
ReassociationIndices currentGroup; ReassociationIndices currentGroup;
@@ -72,56 +229,6 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
return reassociation; return reassociation;
} }
static Value
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
});
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isCompileTimeComputable))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes,
int64_t axis,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
if (axis == rank)
return createAverageCompute(input, leafType, rewriter, loc);
if (reducedAxes[axis])
return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
SmallVector<Value> reducedSlices;
reducedSlices.reserve(slices.size());
for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return concatValues(reducedSlices, axis, rewriter, loc);
}
static Value squeezeReducedAxes(Value keepdimsValue, static Value squeezeReducedAxes(Value keepdimsValue,
RankedTensorType resultType, RankedTensorType resultType,
ArrayRef<bool> reducedAxes, ArrayRef<bool> reducedAxes,
@@ -156,16 +263,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType()); auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
if (inputType.getRank() == 0) {
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
return success();
}
SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank()); auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank());
SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputType.getRank()); if (failed(axes))
return failure();
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
if (reducedAxes.empty() && inputType.getRank() != 0) if (reducedAxes.empty() && inputType.getRank() != 0)
return failure(); return failure();
Location loc = reduceMeanOp.getLoc(); Location loc = reduceMeanOp.getLoc();
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType()); RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes);
RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes);
int64_t laneCount = 1;
for (int64_t dim : compactKeptType.getShape())
laneCount *= dim;
RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType);
auto lanePackedKeepdims =
buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc);
if (failed(lanePackedKeepdims))
return failure();
Value reducedKeepdims = Value reducedKeepdims =
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc); buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) { if (reduceMeanOp.getKeepdims() != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims); rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
@@ -23,28 +23,10 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
return cast<IntegerAttr>(arrayAttr[index]).getInt();
}
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType()); auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileType.getRank());
for (int64_t dimSize : tileType.getShape())
sizes.push_back(rewriter.getIndexAttr(dimSize));
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
} }
static Value static Value
@@ -197,12 +179,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
const int64_t inputWidth = xType.getDimSize(3); const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2); const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3); const int64_t outputWidth = outType.getDimSize(3);
const int64_t kernelHeight = getI64(kernelAttr, 0); const int64_t kernelHeight = getI64Attr(kernelAttr, 0);
const int64_t kernelWidth = getI64(kernelAttr, 1); const int64_t kernelWidth = getI64Attr(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1); const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1); const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1); const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1); const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1);
int64_t padTop = 0; int64_t padTop = 0;
int64_t padLeft = 0; int64_t padLeft = 0;
@@ -212,10 +194,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
if (auto padsAttr = poolOp.getPads()) { if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4) if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements."); return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
padTop = getI64(*padsAttr, 0); padTop = getI64Attr(*padsAttr, 0);
padLeft = getI64(*padsAttr, 1); padLeft = getI64Attr(*padsAttr, 1);
padBottom = getI64(*padsAttr, 2); padBottom = getI64Attr(*padsAttr, 2);
padRight = getI64(*padsAttr, 3); padRight = getI64Attr(*padsAttr, 3);
} }
else { else {
StringRef autoPad = poolOp.getAutoPad(); StringRef autoPad = poolOp.getAutoPad();
@@ -13,16 +13,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
static Value buildLoopSoftmaxSlice(Value input, static Value buildLoopSoftmaxSlice(Value input,
Value accumulator, Value accumulator,
RankedTensorType inputType, RankedTensorType inputType,
@@ -36,7 +26,7 @@ static Value buildLoopSoftmaxSlice(Value input,
SmallVector<OpFoldResult> offsets; SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
offsets.reserve(rank); offsets.reserve(rank);
sizes.reserve(rank); sizes.reserve(rank);
@@ -110,44 +100,31 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
if (!inputType || !inputType.hasStaticShape()) if (!inputType || !inputType.hasStaticShape())
return failure(); return failure();
int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank()); auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank());
if (axis < 0 || axis >= inputType.getRank()) if (failed(axis))
return failure(); return failure();
Value input = adaptor.getInput(); Value input = adaptor.getInput();
Value result; Value result;
if (axis == inputType.getRank() - 1) { if (*axis == inputType.getRank() - 1) {
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc()); result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
} }
else { else {
SmallVector<int64_t> permutation; SmallVector<int64_t> permutation;
permutation.reserve(inputType.getRank()); permutation.reserve(inputType.getRank());
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
if (dim != axis) if (dim != *axis)
permutation.push_back(dim); permutation.push_back(dim);
permutation.push_back(axis); permutation.push_back(*axis);
SmallVector<int64_t> inversePermutation = invertPermutation(permutation);
SmallVector<int64_t> inversePermutation(inputType.getRank());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
auto transposedType = RankedTensorType::get( auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
auto preTransposeCompute = Value transposedInput =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) { transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
auto postTransposeCompute = result = transposeMaybeInCompute(
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) { transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
result = postTransposeCompute.getResult(0);
} }
rewriter.replaceOp(softmaxOp, result); rewriter.replaceOp(softmaxOp, result);
@@ -36,6 +36,14 @@ static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp()); return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
} }
struct PromotedOperands {
SmallVector<bool> promoteInput;
SmallVector<Value> newWeights;
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
};
template <typename ComputeOpTy> template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
@@ -48,60 +56,91 @@ static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
return false; return false;
} }
template <typename ComputeOpTy>
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
PromotedOperands promoted;
promoted.promoteInput.assign(compute.getInputs().size(), false);
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
promoted.newInputs.reserve(compute.getInputs().size());
promoted.newInputTypes.reserve(compute.getInputs().size());
promoted.newInputLocs.reserve(compute.getInputs().size());
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
goto keep_input;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
goto keep_input;
promoted.promoteInput[inputIdx] = true;
promoted.newWeights.push_back(input);
needsRewrite = true;
continue;
keep_input:
promoted.newInputs.push_back(input);
promoted.newInputTypes.push_back(input.getType());
promoted.newInputLocs.push_back(input.getLoc());
}
if (!needsRewrite)
return failure();
return promoted;
}
template <typename ComputeOpTy>
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
const PromotedOperands& promoted,
IRRewriter& bodyRewriter,
IRMapping& mapper,
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
PatternRewriter& rewriter) {
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
if (!promoted.promoteInput[oldInputIdx]) {
auto newInputArg = getNewInputArg(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
return success();
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs. // Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> { struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern; using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false); auto promoted = computePromotedOperands(compute);
bool needsRewrite = false; if (failed(promoted))
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote"); return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute); rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute = auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes; SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs; SmallVector<Location> newBlockArgLocs;
for (Value weight : newWeights) { for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType()); newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc()); newBlockArgLocs.push_back(weight.getLoc());
} }
llvm::append_range(newBlockArgTypes, newInputTypes); llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs); llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
auto* newBlock = rewriter.createBlock( auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())}); {static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext()); IRRewriter bodyRewriter(rewriter.getContext());
@@ -115,24 +154,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite"); return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg); mapper.map(*oldWeightArg, *newWeightArg);
} }
size_t newInputIdx = 0; if (failed(mapPromotedInputArguments(
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) { compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
auto oldArg = compute.getInputArgument(oldInputIdx); return failure();
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator()) for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
@@ -156,63 +180,35 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern; using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false); auto promoted = computePromotedOperands(compute);
bool needsRewrite = false; if (failed(promoted))
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote"); return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute); rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute = auto newCompute =
spatial::SpatComputeBatch::create(rewriter, spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(), compute.getLoc(),
compute.getResultTypes(), compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())), rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights, promoted->newWeights,
newInputs); promoted->newInputs);
auto laneArg = compute.getLaneArgument(); auto laneArg = compute.getLaneArgument();
if (!laneArg) if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument"); return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes; SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs; SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults()); newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults()); newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(laneArg->getType()); newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(laneArg->getLoc()); newBlockArgLocs.push_back(laneArg->getLoc());
for (Value weight : newWeights) { for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType()); newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc()); newBlockArgLocs.push_back(weight.getLoc());
} }
llvm::append_range(newBlockArgTypes, newInputTypes); llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs); llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) { for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
auto outputArg = compute.getOutputArgument(resultIndex); auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg) if (!outputArg)
@@ -224,7 +220,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
auto* newBlock = rewriter.createBlock( auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())}); {static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext()); IRRewriter bodyRewriter(rewriter.getContext());
@@ -242,29 +238,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite"); return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg); mapper.map(*oldWeightArg, *newWeightArg);
} }
size_t newInputIdx = 0; if (failed(mapPromotedInputArguments(
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) { compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
auto oldArg = compute.getInputArgument(oldInputIdx); return failure();
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) { for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex); auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg) if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite"); return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); mapper.map(*outputArg,
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
} }
for (Operation& op : oldBlock) for (Operation& op : oldBlock)
@@ -15,24 +15,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static Value
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(inputType.getRank());
for (int64_t dim : inputType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(1);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
}
static Value concatGatherSlices(Value data, static Value concatGatherSlices(Value data,
int64_t axis, int64_t axis,
ArrayRef<int64_t> indices, ArrayRef<int64_t> indices,
@@ -45,7 +27,7 @@ static Value concatGatherSlices(Value data,
int64_t normalizedIndex = normalizeIndex(index, axisDim); int64_t normalizedIndex = normalizeIndex(index, axisDim);
if (normalizedIndex < 0 || normalizedIndex >= axisDim) if (normalizedIndex < 0 || normalizedIndex >= axisDim)
return {}; return {};
slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc)); slices.push_back(extractAxisSlice(rewriter, loc, data, axis, normalizedIndex, /*size=*/1));
} }
if (slices.empty()) if (slices.empty())
return {}; return {};
@@ -96,11 +78,11 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure(); return failure();
int64_t rank = dataType.getRank(); int64_t rank = dataType.getRank();
int64_t axis = normalizeAxis(gatherOp.getAxis(), rank); auto axis = normalizeAxisChecked(gatherOp.getAxis(), rank);
if (axis < 0 || axis >= rank) if (failed(axis))
return failure(); return failure();
int64_t axisDim = dataType.getShape()[axis]; int64_t axisDim = dataType.getShape()[*axis];
if (axisDim <= 0) if (axisDim <= 0)
return failure(); return failure();
@@ -116,7 +98,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
[&](Value data) -> LogicalResult { [&](Value data) -> LogicalResult {
Value result; Value result;
if (indicesType.getRank() == 1) { if (indicesType.getRank() == 1) {
result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc); result = concatGatherSlices(data, *axis, flatIndices, axisDim, rewriter, loc);
} }
else if (indicesType.getRank() == 2) { else if (indicesType.getRank() == 2) {
int64_t rowCount = indicesType.getShape()[0]; int64_t rowCount = indicesType.getShape()[0];
@@ -125,12 +107,13 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
rows.reserve(rowCount); rows.reserve(rowCount);
for (int64_t row = 0; row < rowCount; ++row) { for (int64_t row = 0; row < rowCount; ++row) {
ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth); ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc); Value gatheredRow =
concatGatherSlices(data, *axis, rowIndices, axisDim, rewriter, loc);
if (!gatheredRow) if (!gatheredRow)
return failure(); return failure();
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); rows.push_back(addLeadingGatherDim(gatheredRow, *axis, rewriter, loc));
} }
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows); result = createSpatConcat(rewriter, loc, /*axis=*/*axis, rows);
} }
else { else {
return failure(); return failure();
@@ -14,10 +14,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape, static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape, ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) { SmallVector<ReassociationIndices>& reassociation) {
@@ -106,7 +102,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType()); auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape())) if (!hasStaticPositiveShape(sourceType) || !hasStaticPositiveShape(resultType))
return failure(); return failure();
if (sourceType == resultType) { if (sourceType == resultType) {
@@ -115,17 +111,8 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} }
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult { auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (isCompileTimeComputable(adaptor.getData())) { Value reshaped = materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData())); rewriter.replaceOp(reshapeOp, reshaped);
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
Value reshaped = buildReshape(data);
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
});
rewriter.replaceOp(reshapeOp, computeOp.getResults());
return success(); return success();
}; };
@@ -12,25 +12,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static Value extractSliceAt(
Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(inputType.getRank());
for (int64_t dim : inputType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
SmallVector<int64_t> resultShape(inputType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
}
struct Split : OpConversionPattern<ONNXSplitOp> { struct Split : OpConversionPattern<ONNXSplitOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -41,8 +22,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
return failure(); return failure();
int64_t rank = inputType.getRank(); int64_t rank = inputType.getRank();
int64_t axis = normalizeAxis(splitOp.getAxis(), rank); auto axis = normalizeAxisChecked(splitOp.getAxis(), rank);
if (axis < 0 || axis >= rank) if (failed(axis))
return failure(); return failure();
SmallVector<Value> outputs; SmallVector<Value> outputs;
@@ -58,12 +39,13 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
if (!resultType || !resultType.hasStaticShape()) if (!resultType || !resultType.hasStaticShape())
return failure(); return failure();
resultTypes.push_back(resultType); resultTypes.push_back(resultType);
sliceSizes.push_back(resultType.getShape()[axis]); sliceSizes.push_back(resultType.getShape()[*axis]);
} }
if (isCompileTimeComputable(adaptor.getInput())) { if (isCompileTimeComputable(adaptor.getInput())) {
for (int64_t sliceSize : sliceSizes) { for (int64_t sliceSize : sliceSizes) {
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc())); outputs.push_back(
extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
offset += sliceSize; offset += sliceSize;
} }
rewriter.replaceOp(splitOp, outputs); rewriter.replaceOp(splitOp, outputs);
@@ -76,7 +58,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
runtimeOutputs.reserve(resultTypes.size()); runtimeOutputs.reserve(resultTypes.size());
int64_t runtimeOffset = 0; int64_t runtimeOffset = 0;
for (int64_t sliceSize : sliceSizes) { for (int64_t sliceSize : sliceSizes) {
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc())); runtimeOutputs.push_back(
extractAxisSlice(rewriter, splitOp.getLoc(), input, *axis, runtimeOffset, sliceSize));
runtimeOffset += sliceSize; runtimeOffset += sliceSize;
} }
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs); spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
@@ -4,6 +4,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -29,22 +30,6 @@ static Value createTransposeInit(Value input,
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult(); return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
} }
static SmallVector<int64_t> getTransposePermutation(ONNXTransposeOp transposeOp) {
auto inputType = cast<RankedTensorType>(transposeOp.getData().getType());
SmallVector<int64_t> permutation;
if (auto permAttr = transposeOp.getPermAttr()) {
permutation.reserve(permAttr.size());
for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
permutation.push_back(attr.getInt());
return permutation;
}
permutation.reserve(inputType.getRank());
for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> { struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -56,10 +41,12 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
if (!inputType || !resultType) if (!inputType || !resultType)
return failure(); return failure();
SmallVector<int64_t> permutation = getTransposePermutation(transposeOp); auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank());
Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc()); if (failed(permutation))
return failure();
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
Value transposed = Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation) linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation)
.getResult()[0]; .getResult()[0];
rewriter.replaceOp(transposeOp, transposed); rewriter.replaceOp(transposeOp, transposed);
return success(); return success();
@@ -40,7 +40,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
continue; continue;
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) { if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder)); mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp));
continue; continue;
} }
@@ -218,7 +218,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
continue; continue;
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) { if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder)); blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantFolder, constantOp));
continue; continue;
} }
@@ -230,8 +230,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
PimMemCopyHostToDevOp::create(rewriter, PimMemCopyHostToDevOp::create(rewriter,
loc, loc,
outputBuffer.getType(), outputBuffer.getType(),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder), getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder), getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
outputBuffer, outputBuffer,
input, input,
getTensorSizeInBytesAttr(rewriter, input)) getTensorSizeInBytesAttr(rewriter, input))
@@ -326,7 +326,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
continue; continue;
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) { if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder)); mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp));
continue; continue;
} }
@@ -370,8 +370,8 @@ static Value emitHostCopy(IRRewriter& rewriter,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp(); Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants"); assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder); Value hostTargetOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, hostTargetOffset);
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder); Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, deviceSourceOffset);
return PimMemCopyDevToHostOp::create(rewriter, return PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
@@ -81,7 +81,7 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter,
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder); auto zeroIndex = getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType))); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
if (outputBuffer->getParentOfType<PimCoreBatchOp>()) if (outputBuffer->getParentOfType<PimCoreBatchOp>())
@@ -333,9 +333,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
rewriter, rewriter,
loc, loc,
tensorType, tensorType,
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder), getOrCreateHostIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
getOrCreateHostIndexConstant( getOrCreateHostIndexConstant(constantFolder,
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder), deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize) ),
deviceTensor, deviceTensor,
inputTensor, inputTensor,
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize))); rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
@@ -619,7 +619,7 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
} }
Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) { Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) {
return getOrCreateHostIndexConstant(anchor, value, state.constantFolder); return getOrCreateHostIndexConstant(state.constantFolder, anchor, value);
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@@ -939,7 +939,7 @@ Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, Arr
auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType()); auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType());
auto attr = DenseIntElementsAttr::get(type, elements); auto attr = DenseIntElementsAttr::get(type, elements);
return getOrCreateHostConstant(anchor, attr, type, state.constantFolder); return getOrCreateHostConstant(state.constantFolder, anchor, attr, type);
} }
bool allEqual(ArrayRef<int64_t> values) { bool allEqual(ArrayRef<int64_t> values) {
@@ -113,8 +113,8 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
rewriter, rewriter,
op->getLoc(), op->getLoc(),
originalType, originalType,
getOrCreateHostIndexConstant(op, 0, constantFolder), getOrCreateHostIndexConstant(constantFolder, op, 0),
getOrCreateHostIndexConstant(op, static_cast<int64_t>(resolvedAddress->byteOffset), constantFolder), getOrCreateHostIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset) ),
deviceDst, deviceDst,
getGlobalOp.getResult(), getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes))) rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))