This commit is contained in:
@@ -28,7 +28,7 @@ Block* getHostConstantBlock(Operation* anchorOp) {
|
|||||||
return anchorOp->getBlock();
|
return anchorOp->getBlock();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
|
Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
|
||||||
assert(anchorOp && "expected a valid anchor operation");
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||||
for (Operation& op : *hostBlock) {
|
for (Operation& op : *hostBlock) {
|
||||||
@@ -42,7 +42,7 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O
|
|||||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) {
|
Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
|
||||||
assert(anchorOp && "expected a valid anchor operation");
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||||
for (Operation& op : *hostBlock) {
|
for (Operation& op : *hostBlock) {
|
||||||
@@ -57,28 +57,28 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, R
|
|||||||
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
Value getOrCreateHostConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
|
||||||
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
return getOrCreateHostConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
Value getOrCreateHostIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
|
||||||
Builder builder(anchorOp->getContext());
|
Builder builder(anchorOp->getContext());
|
||||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
return getOrCreateHostConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType() );
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) {
|
Value getOrCreateHostIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
|
||||||
Builder builder(anchorOp->getContext());
|
Builder builder(anchorOp->getContext());
|
||||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter);
|
return getOrCreateHostConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||||
Builder builder(anchorOp->getContext());
|
Builder builder(anchorOp->getContext());
|
||||||
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
return getOrCreateHostConstant(folder, anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type() );
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||||
Builder builder(anchorOp->getContext());
|
Builder builder(anchorOp->getContext());
|
||||||
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
return getOrCreateHostConstant(folder, anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type() );
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createAffineApplyOrFoldedConstant(
|
Value createAffineApplyOrFoldedConstant(
|
||||||
@@ -95,7 +95,7 @@ Value createAffineApplyOrFoldedConstant(
|
|||||||
SmallVector<Attribute> foldedResults;
|
SmallVector<Attribute> foldedResults;
|
||||||
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
|
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
|
||||||
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
|
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
|
||||||
return getOrCreateHostIndexConstant(anchorOp, constantResult.getInt(), rewriter);
|
return getOrCreateHostIndexConstant(rewriter, anchorOp, constantResult.getInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||||
|
|||||||
@@ -10,25 +10,25 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
mlir::Value getOrCreateHostConstant(mlir::OperationFolder& folder,
|
||||||
|
mlir::Operation* anchorOp,
|
||||||
mlir::Attribute value,
|
mlir::Attribute value,
|
||||||
mlir::Type type,
|
mlir::Type type);
|
||||||
mlir::OperationFolder& folder);
|
|
||||||
|
|
||||||
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
mlir::Value getOrCreateHostConstant(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Operation* anchorOp,
|
||||||
mlir::Attribute value,
|
mlir::Attribute value,
|
||||||
mlir::Type type,
|
mlir::Type type);
|
||||||
mlir::RewriterBase& rewriter);
|
|
||||||
|
|
||||||
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter);
|
mlir::Value getOrCreateHostIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostI32Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int32_t value);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostI64Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
|
||||||
|
|
||||||
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter,
|
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Tensor/Split.cpp
|
Patterns/Tensor/Split.cpp
|
||||||
Patterns/Tensor/Transpose.cpp
|
Patterns/Tensor/Transpose.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
|
Common/AttributeUtils.cpp
|
||||||
Common/ComputeRegionBuilder.cpp
|
Common/ComputeRegionBuilder.cpp
|
||||||
|
Common/IndexingUtils.cpp
|
||||||
Common/ShapeTilingUtils.cpp
|
Common/ShapeTilingUtils.cpp
|
||||||
Common/WeightMaterialization.cpp
|
Common/WeightMaterialization.cpp
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
#include "AttributeUtils.hpp"
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
|
||||||
|
|
||||||
|
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
|
||||||
|
return attr ? getI64Attr(*attr, index) : defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
|
||||||
|
llvm::SmallVector<int64_t> values;
|
||||||
|
values.reserve(attr.size());
|
||||||
|
for (Attribute value : attr)
|
||||||
|
values.push_back(cast<IntegerAttr>(value).getInt());
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
|
||||||
|
|
||||||
|
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "AttributeUtils.hpp"
|
||||||
#include "ComputeRegionBuilder.hpp"
|
#include "ComputeRegionBuilder.hpp"
|
||||||
|
#include "IndexingUtils.hpp"
|
||||||
#include "ShapeTilingUtils.hpp"
|
#include "ShapeTilingUtils.hpp"
|
||||||
#include "WeightMaterialization.hpp"
|
#include "WeightMaterialization.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|||||||
@@ -7,9 +7,13 @@
|
|||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <limits>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -49,6 +53,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::
|
|||||||
template <typename Fn>
|
template <typename Fn>
|
||||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||||
|
|
||||||
|
struct SpatComputeBatchBodyArgs {
|
||||||
|
mlir::Value lane;
|
||||||
|
mlir::ValueRange weights;
|
||||||
|
mlir::ValueRange inputs;
|
||||||
|
mlir::ValueRange outputs;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
template <typename RewriterT>
|
template <typename RewriterT>
|
||||||
@@ -159,6 +170,96 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
int64_t laneCount,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||||
|
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||||
|
|
||||||
|
auto batchOp = spatial::SpatComputeBatch::create(
|
||||||
|
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs);
|
||||||
|
|
||||||
|
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||||
|
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||||
|
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||||
|
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||||
|
for (mlir::Value weight : weights) {
|
||||||
|
blockArgTypes.push_back(weight.getType());
|
||||||
|
blockArgLocs.push_back(weight.getLoc());
|
||||||
|
}
|
||||||
|
for (mlir::Value input : inputs) {
|
||||||
|
blockArgTypes.push_back(input.getType());
|
||||||
|
blockArgLocs.push_back(input.getLoc());
|
||||||
|
}
|
||||||
|
for (mlir::Type resultType : resultTypes) {
|
||||||
|
blockArgTypes.push_back(resultType);
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* block =
|
||||||
|
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
detail::SpatComputeBatchBodyArgs args {
|
||||||
|
block->getArgument(0),
|
||||||
|
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
||||||
|
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
||||||
|
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())
|
||||||
|
};
|
||||||
|
|
||||||
|
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
|
std::forward<BodyFn>(body)(args);
|
||||||
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto bodyResult = std::forward<BodyFn>(body)(args);
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
|
rewriter.eraseOp(batchOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
mlir::Value dest,
|
||||||
|
mlir::ArrayRef<mlir::OpFoldResult> offsets,
|
||||||
|
mlir::ArrayRef<mlir::OpFoldResult> sizes,
|
||||||
|
mlir::ArrayRef<mlir::OpFoldResult> strides) {
|
||||||
|
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||||
|
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||||
|
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename BodyFn>
|
||||||
|
mlir::Value materializeOrComputeUnary(mlir::Value input,
|
||||||
|
mlir::RankedTensorType resultType,
|
||||||
|
mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
BodyFn&& build) {
|
||||||
|
auto&& buildFn = build;
|
||||||
|
if (isCompileTimeComputable(input))
|
||||||
|
return buildFn(input);
|
||||||
|
|
||||||
|
auto computeOp =
|
||||||
|
createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
||||||
|
mlir::Value result = buildFn(computeInput);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -0,0 +1,104 @@
|
|||||||
|
#include "IndexingUtils.hpp"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/APInt.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||||
|
|
||||||
|
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
|
||||||
|
int64_t normalizedAxis = normalizeAxis(axis, rank);
|
||||||
|
if (normalizedAxis < 0 || normalizedAxis >= rank)
|
||||||
|
return failure();
|
||||||
|
return normalizedAxis;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
||||||
|
|
||||||
|
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||||
|
SmallVector<int64_t> normalizedAxes;
|
||||||
|
if (!axesAttr) {
|
||||||
|
normalizedAxes.reserve(rank);
|
||||||
|
for (int64_t axis = 0; axis < rank; ++axis)
|
||||||
|
normalizedAxes.push_back(axis);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
normalizedAxes.reserve(axesAttr->size());
|
||||||
|
for (Attribute attr : *axesAttr)
|
||||||
|
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
|
||||||
|
llvm::sort(normalizedAxes);
|
||||||
|
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
||||||
|
}
|
||||||
|
return normalizedAxes;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
|
||||||
|
return normalizeAxesImpl(std::optional<ArrayAttr>(axesAttr), rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> normalizeAxes(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||||
|
return normalizeAxesImpl(axesAttr, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||||
|
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
|
||||||
|
for (int64_t axis : normalizedAxes)
|
||||||
|
if (axis < 0 || axis >= rank)
|
||||||
|
return failure();
|
||||||
|
return normalizedAxes;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) {
|
||||||
|
return normalizeAxesChecked(std::optional<ArrayAttr>(axesAttr), rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||||
|
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
||||||
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||||
|
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
|
||||||
|
if (multiplier == 0)
|
||||||
|
return getOrCreateHostIndexConstant(rewriter, anchorOp, 0);
|
||||||
|
if (multiplier == 1)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
|
||||||
|
}
|
||||||
|
|
||||||
|
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||||
|
if (divisor == 1)
|
||||||
|
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||||
|
}
|
||||||
|
|
||||||
|
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||||
|
if (divisor == 1)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) {
|
||||||
|
if (auto attr = dyn_cast<Attribute>(value))
|
||||||
|
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||||
|
return cast<Value>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Interfaces/FoldInterfaces.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
int64_t normalizeAxis(int64_t axis, int64_t rank);
|
||||||
|
|
||||||
|
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
|
||||||
|
|
||||||
|
int64_t normalizeIndex(int64_t index, int64_t dimSize);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> normalizeAxes(mlir::ArrayAttr axesAttr, int64_t rank);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> normalizeAxes(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||||
|
|
||||||
|
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(mlir::ArrayAttr axesAttr, int64_t rank);
|
||||||
|
|
||||||
|
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||||
|
|
||||||
|
mlir::Value createAffineApplyOrConstant(mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::AffineExpr expr,
|
||||||
|
mlir::ValueRange operands);
|
||||||
|
|
||||||
|
mlir::Value
|
||||||
|
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier);
|
||||||
|
|
||||||
|
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
||||||
|
|
||||||
|
mlir::Value
|
||||||
|
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
||||||
|
|
||||||
|
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::OpFoldResult value);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -6,20 +6,21 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
#include "ShapeTilingUtils.hpp"
|
#include "ShapeTilingUtils.hpp"
|
||||||
|
#include "IndexingUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
if (auto attr = dyn_cast<Attribute>(result))
|
return getOrMaterializeIndexValue(rewriter, loc, result);
|
||||||
return arith::ConstantIndexOp::create(rewriter, loc, cast<IntegerAttr>(attr).getInt()).getResult();
|
|
||||||
return cast<Value>(result);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
@@ -50,6 +51,84 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt
|
|||||||
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
|
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||||
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); }
|
||||||
|
|
||||||
|
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||||
|
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); }
|
||||||
|
|
||||||
|
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||||
|
SmallVector<int64_t> permutedShape;
|
||||||
|
permutedShape.reserve(permutation.size());
|
||||||
|
for (int64_t axis : permutation)
|
||||||
|
permutedShape.push_back(shape[axis]);
|
||||||
|
return permutedShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
|
||||||
|
SmallVector<int64_t> inversePermutation(permutation.size());
|
||||||
|
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||||
|
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||||
|
return inversePermutation;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
|
||||||
|
SmallVector<int64_t> permutation;
|
||||||
|
if (!permAttr) {
|
||||||
|
permutation.reserve(rank);
|
||||||
|
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||||
|
permutation.push_back(dim);
|
||||||
|
return permutation;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
permutation.reserve(permAttr->size());
|
||||||
|
SmallVector<bool> seen(rank, false);
|
||||||
|
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
|
||||||
|
int64_t axis = attr.getInt();
|
||||||
|
if (axis < 0 || axis >= rank || seen[axis])
|
||||||
|
return failure();
|
||||||
|
seen[axis] = true;
|
||||||
|
permutation.push_back(axis);
|
||||||
|
}
|
||||||
|
return permutation;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value transposeMaybeInCompute(Value value,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<int64_t> permutation,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto buildTranspose = [&](Value input) -> Value {
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||||
|
};
|
||||||
|
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||||
|
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
|
||||||
|
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.reserve(shape.size());
|
||||||
|
for (int64_t dim : shape)
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
|
return sizes;
|
||||||
|
}
|
||||||
|
|
||||||
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||||
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
|
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
|
||||||
@@ -88,11 +167,8 @@ SmallVector<Value> sliceTensor(
|
|||||||
assert("Invalid axis" && axis < shape.size());
|
assert("Invalid axis" && axis < shape.size());
|
||||||
|
|
||||||
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
||||||
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
|
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
|
||||||
SmallVector<OpFoldResult> sizes;
|
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
|
||||||
sizes.reserve(shape.size());
|
|
||||||
for (const auto size : shape)
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(size));
|
|
||||||
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
||||||
|
|
||||||
long length = shape[axis];
|
long length = shape[axis];
|
||||||
@@ -276,4 +352,43 @@ Value materializeContiguousTensorSlice(Value source,
|
|||||||
return buildLoopNest(buildLoopNest, 0, init);
|
return buildLoopNest(buildLoopNest, 0, init);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value extractStaticSlice(PatternRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Value source,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<OpFoldResult> offsets) {
|
||||||
|
return tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()),
|
||||||
|
getUnitStrides(rewriter, resultType.getRank()))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value extractAxisSlice(
|
||||||
|
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||||
|
resultShape[axis] = size;
|
||||||
|
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||||
|
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||||
|
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||||
|
sizes[axis] = rewriter.getIndexAttr(size);
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value insertStaticSlice(
|
||||||
|
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
return tensor::InsertSliceOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
source,
|
||||||
|
dest,
|
||||||
|
offsets,
|
||||||
|
getStaticSizes(rewriter, sourceType.getShape()),
|
||||||
|
getUnitStrides(rewriter, sourceType.getRank()))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/IR/ValueRange.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
@@ -11,6 +12,7 @@
|
|||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <optional>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
@@ -109,6 +111,33 @@ inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
|||||||
&& lhsType.getShape() == rhsType.getShape();
|
&& lhsType.getShape() == rhsType.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||||
|
|
||||||
|
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
int64_t getStaticShapeElementCount(mlir::RankedTensorType type);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||||
|
|
||||||
|
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||||
|
int64_t rank);
|
||||||
|
|
||||||
|
mlir::Value transposeMaybeInCompute(mlir::Value value,
|
||||||
|
mlir::RankedTensorType resultType,
|
||||||
|
mlir::ArrayRef<int64_t> permutation,
|
||||||
|
mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||||
/// at most `sliceSize` elements.
|
/// at most `sliceSize` elements.
|
||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
@@ -148,4 +177,23 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
|
|||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|
||||||
|
mlir::Value extractStaticSlice(mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
mlir::RankedTensorType resultType,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||||
|
|
||||||
|
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
int64_t axis,
|
||||||
|
int64_t offset,
|
||||||
|
int64_t size);
|
||||||
|
|
||||||
|
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
mlir::Value dest,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
|
||||||
|
|
||||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||||
if (biasType.getRank() != 1)
|
if (biasType.getRank() != 1)
|
||||||
@@ -615,10 +613,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
|
||||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
|
||||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
|
||||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
|
||||||
|
|
||||||
int64_t padHeightBegin = 0;
|
int64_t padHeightBegin = 0;
|
||||||
int64_t padHeightEnd = 0;
|
int64_t padHeightEnd = 0;
|
||||||
@@ -626,10 +624,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
int64_t padWidthEnd = 0;
|
int64_t padWidthEnd = 0;
|
||||||
|
|
||||||
if (padsAttr) {
|
if (padsAttr) {
|
||||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
padHeightBegin = getI64Attr(*padsAttr, 0);
|
||||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
padWidthBegin = getI64Attr(*padsAttr, 1);
|
||||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
padHeightEnd = getI64Attr(*padsAttr, 2);
|
||||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
padWidthEnd = getI64Attr(*padsAttr, 3);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// Compute padding from auto_pad attribute
|
// Compute padding from auto_pad attribute
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
@@ -58,47 +58,16 @@ static Value transposeForSpatial(Value value,
|
|||||||
ArrayRef<int64_t> permutation,
|
ArrayRef<int64_t> permutation,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (isCompileTimeComputable(value))
|
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
|
||||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
|
||||||
|
|
||||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
|
||||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
|
||||||
});
|
|
||||||
return computeOp.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) {
|
|
||||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
|
||||||
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value
|
|
||||||
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
|
||||||
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
|
||||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
|
||||||
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
|
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
if (multiplier == 0)
|
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
|
||||||
return createIndexConstant(rewriter, 0);
|
|
||||||
if (multiplier == 1)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
||||||
return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
if (divisor == 1)
|
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
|
||||||
return createIndexConstant(rewriter, 0);
|
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
||||||
return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
@@ -108,11 +77,11 @@ static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatter
|
|||||||
static Value createGemmBatchKOffset(
|
static Value createGemmBatchKOffset(
|
||||||
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
|
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
if (numKSlices == 1)
|
if (numKSlices == 1)
|
||||||
return createIndexConstant(rewriter, 0);
|
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
MLIRContext* context = rewriter.getContext();
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
return createAffineApply(
|
return createAffineApplyOrConstant(
|
||||||
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,11 +92,11 @@ static Value createGemmBatchHOffset(Value lane,
|
|||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (numOutHSlices == 1)
|
if (numOutHSlices == 1)
|
||||||
return createIndexConstant(rewriter, 0);
|
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
MLIRContext* context = rewriter.getContext();
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
return createAffineApply(
|
return createAffineApplyOrConstant(
|
||||||
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,53 +272,37 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
|||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
auto batchOp = createSpatComputeBatch(
|
||||||
loc,
|
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
TypeRange {partialPiecesType},
|
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
||||||
ValueRange {b},
|
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||||
ValueRange {a});
|
|
||||||
|
|
||||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType};
|
auto aTileType =
|
||||||
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
|
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||||
Block* body =
|
|
||||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
|
||||||
rewriter.setInsertionPointToEnd(body);
|
|
||||||
|
|
||||||
auto lane = batchOp.getLaneArgument();
|
|
||||||
auto weight = batchOp.getWeightArgument(0);
|
|
||||||
auto input = batchOp.getInputArgument(0);
|
|
||||||
auto output = batchOp.getOutputArgument(0);
|
|
||||||
assert(lane && weight && input && output && "malformed Gemm compute_batch body");
|
|
||||||
|
|
||||||
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
|
|
||||||
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
|
|
||||||
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
|
||||||
|
|
||||||
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
|
||||||
auto bTileType = RankedTensorType::get(
|
auto bTileType = RankedTensorType::get(
|
||||||
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
||||||
paddedBType.getElementType());
|
paddedBType.getElementType());
|
||||||
auto pieceType =
|
auto pieceType =
|
||||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||||
Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc);
|
Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);
|
||||||
|
|
||||||
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
|
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
|
||||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||||
Value bTile =
|
Value bTile =
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult();
|
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||||
|
|
||||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
|
||||||
SmallVector<OpFoldResult> pieceOffsets {*lane, rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides);
|
createParallelInsertSliceIntoBatchOutput(
|
||||||
|
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
});
|
||||||
return batchOp;
|
assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed");
|
||||||
|
return *batchOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createDynamicGemmBatchRow(
|
static Value createDynamicGemmBatchRow(
|
||||||
@@ -359,7 +312,7 @@ static Value createDynamicGemmBatchRow(
|
|||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
MLIRContext* context = rewriter.getContext();
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createDynamicGemmBatchColumn(
|
static Value createDynamicGemmBatchColumn(
|
||||||
@@ -479,45 +432,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
|||||||
const int64_t numOutCols = outType.getDimSize(1);
|
const int64_t numOutCols = outType.getDimSize(1);
|
||||||
const int64_t reductionSize = aType.getDimSize(1);
|
const int64_t reductionSize = aType.getDimSize(1);
|
||||||
const int64_t laneCount = numOutRows * numOutCols;
|
const int64_t laneCount = numOutRows * numOutCols;
|
||||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
auto batchOp = createSpatComputeBatch(
|
||||||
loc,
|
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
TypeRange {scalarPiecesType},
|
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
|
||||||
ValueRange {},
|
|
||||||
ValueRange {a, b});
|
|
||||||
|
|
||||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType};
|
|
||||||
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
|
|
||||||
Block* body =
|
|
||||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
|
||||||
rewriter.setInsertionPointToEnd(body);
|
|
||||||
|
|
||||||
auto lane = batchOp.getLaneArgument();
|
|
||||||
auto inputA = batchOp.getInputArgument(0);
|
|
||||||
auto inputB = batchOp.getInputArgument(1);
|
|
||||||
auto output = batchOp.getOutputArgument(0);
|
|
||||||
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body");
|
|
||||||
|
|
||||||
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc);
|
|
||||||
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc);
|
|
||||||
|
|
||||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
|
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||||
Value bVector = bAlreadyTransposed
|
Value bVector = bAlreadyTransposed
|
||||||
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
|
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||||
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
|
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||||
|
|
||||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
|
||||||
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
|
createParallelInsertSliceIntoBatchOutput(
|
||||||
|
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
});
|
||||||
return batchOp;
|
assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed");
|
||||||
|
return *batchOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||||
@@ -540,9 +475,9 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
|||||||
Value biasArg = bias ? blockArgs[1] : Value();
|
Value biasArg = bias ? blockArgs[1] : Value();
|
||||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||||
Value c0 = createIndexConstant(rewriter, 0);
|
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
Value c1 = createIndexConstant(rewriter, 1);
|
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||||
Value cLaneCount = createIndexConstant(rewriter, laneCount);
|
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||||
rewriter.setInsertionPointToStart(loop.getBody());
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
|
||||||
@@ -587,7 +522,8 @@ static Value createPartialGroupOffset(Value hSlice,
|
|||||||
Location loc) {
|
Location loc) {
|
||||||
MLIRContext* context = rewriter.getContext();
|
MLIRContext* context = rewriter.getContext();
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
return createAffineApplyOrConstant(
|
||||||
|
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value extractReductionPiece(Value partialPiecesArg,
|
static Value extractReductionPiece(Value partialPiecesArg,
|
||||||
@@ -684,13 +620,13 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
|||||||
|
|
||||||
Value paddedOutput = outputInit;
|
Value paddedOutput = outputInit;
|
||||||
if (numOutHSlices == 1) {
|
if (numOutHSlices == 1) {
|
||||||
Value hSlice = createIndexConstant(rewriter, 0);
|
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
Value c0 = createIndexConstant(rewriter, 0);
|
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
Value c1 = createIndexConstant(rewriter, 1);
|
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||||
Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices);
|
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
|
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
|
||||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||||
|
|
||||||
|
|||||||
@@ -19,14 +19,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
|
||||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
|
||||||
}
|
|
||||||
|
|
||||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
|
||||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||||
ArrayRef<int64_t> rhsBatchShape) {
|
ArrayRef<int64_t> rhsBatchShape) {
|
||||||
if (lhsBatchShape.empty())
|
if (lhsBatchShape.empty())
|
||||||
@@ -54,15 +46,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
|
|||||||
auto buildCollapsed = [&](Value input) -> Value {
|
auto buildCollapsed = [&](Value input) -> Value {
|
||||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||||
};
|
};
|
||||||
|
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
|
||||||
if (isCompileTimeComputable(value))
|
|
||||||
return buildCollapsed(value);
|
|
||||||
|
|
||||||
auto collapseCompute =
|
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
|
||||||
});
|
|
||||||
return collapseCompute.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
@@ -76,12 +60,10 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
|
|||||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
|
||||||
auto expandCompute =
|
auto buildExpanded = [&](Value input) -> Value {
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
|
||||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
};
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
||||||
});
|
|
||||||
return expandCompute.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value extractBatchMatrix(Value value,
|
static Value extractBatchMatrix(Value value,
|
||||||
@@ -100,7 +82,7 @@ static Value extractBatchMatrix(Value value,
|
|||||||
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> sizes = {
|
SmallVector<OpFoldResult> sizes = {
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
|
||||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||||
auto buildMatrix = [&](Value input) -> Value {
|
auto buildMatrix = [&](Value input) -> Value {
|
||||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||||
@@ -114,14 +96,7 @@ static Value extractBatchMatrix(Value value,
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
if (isCompileTimeComputable(value))
|
return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
|
||||||
return buildMatrix(value);
|
|
||||||
|
|
||||||
auto batchMatrixCompute =
|
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
|
||||||
});
|
|
||||||
return batchMatrixCompute.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
@@ -138,18 +113,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
|||||||
perm = {0, 2, 1};
|
perm = {0, 2, 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto buildTranspose = [&](Value input) -> Value {
|
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
|
||||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
|
||||||
};
|
|
||||||
|
|
||||||
if (isCompileTimeComputable(value))
|
|
||||||
return buildTranspose(value);
|
|
||||||
|
|
||||||
auto transposeCompute =
|
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
|
||||||
});
|
|
||||||
return transposeCompute.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
@@ -166,7 +130,8 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
|||||||
perm = {0, 2, 1};
|
perm = {0, 2, 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
auto transposeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
});
|
});
|
||||||
@@ -203,8 +168,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||||
return failure();
|
return failure();
|
||||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
|
||||||
|| !haveStaticPositiveShape(outType.getShape()))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
@@ -16,26 +18,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
|
|
||||||
SmallVector<int64_t> normalizedAxes;
|
|
||||||
if (!axesAttr) {
|
|
||||||
normalizedAxes.reserve(rank);
|
|
||||||
for (int64_t axis = 0; axis < rank; axis++)
|
|
||||||
normalizedAxes.push_back(axis);
|
|
||||||
return normalizedAxes;
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizedAxes.reserve(axesAttr.size());
|
|
||||||
for (Attribute attr : axesAttr) {
|
|
||||||
int64_t axis = cast<IntegerAttr>(attr).getInt();
|
|
||||||
normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::sort(normalizedAxes);
|
|
||||||
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
|
||||||
return normalizedAxes;
|
|
||||||
}
|
|
||||||
|
|
||||||
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
|
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
|
||||||
SmallVector<bool> reducedAxes(rank, false);
|
SmallVector<bool> reducedAxes(rank, false);
|
||||||
for (int64_t axis : axes) {
|
for (int64_t axis : axes) {
|
||||||
@@ -50,6 +32,181 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT
|
|||||||
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
|
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
|
||||||
|
SmallVector<int64_t> shape;
|
||||||
|
shape.reserve(inputType.getRank());
|
||||||
|
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||||
|
shape.push_back(isReduced ? 1 : dim);
|
||||||
|
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
|
||||||
|
}
|
||||||
|
|
||||||
|
static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
|
||||||
|
SmallVector<int64_t> shape;
|
||||||
|
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||||
|
if (!isReduced)
|
||||||
|
shape.push_back(dim);
|
||||||
|
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
|
||||||
|
}
|
||||||
|
|
||||||
|
static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef<bool> reducedAxes) {
|
||||||
|
SmallVector<int64_t> shape;
|
||||||
|
shape.reserve(inputType.getRank());
|
||||||
|
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||||
|
shape.push_back(isReduced ? dim : 1);
|
||||||
|
return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding());
|
||||||
|
}
|
||||||
|
|
||||||
|
static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) {
|
||||||
|
SmallVector<int64_t> shape(leafType.getShape().begin(), leafType.getShape().end());
|
||||||
|
shape.front() = laneCount;
|
||||||
|
return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding());
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
|
||||||
|
SmallVector<int64_t> keptAxes;
|
||||||
|
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes))
|
||||||
|
if (!isReduced)
|
||||||
|
keptAxes.push_back(static_cast<int64_t>(axis));
|
||||||
|
return keptAxes;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value computeLaneIndex(Value lane,
|
||||||
|
int64_t stride,
|
||||||
|
int64_t dimSize,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (dimSize == 1)
|
||||||
|
return arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
AffineExpr expr = d0;
|
||||||
|
if (stride != 1)
|
||||||
|
expr = expr.floorDiv(stride);
|
||||||
|
if (dimSize != 1)
|
||||||
|
expr = expr % dimSize;
|
||||||
|
return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane});
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
|
||||||
|
ArrayRef<bool> reducedAxes,
|
||||||
|
RankedTensorType batchType,
|
||||||
|
RankedTensorType leafType,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
|
auto sliceType = getReducedSliceType(inputType, reducedAxes);
|
||||||
|
SmallVector<int64_t> keptAxes = getKeptAxes(reducedAxes);
|
||||||
|
|
||||||
|
int64_t laneCount = 1;
|
||||||
|
SmallVector<int64_t> keptAxisStrides(keptAxes.size(), 1);
|
||||||
|
for (int64_t index = static_cast<int64_t>(keptAxes.size()) - 1; index >= 0; --index) {
|
||||||
|
keptAxisStrides[index] = laneCount;
|
||||||
|
int64_t dimSize = inputType.getDimSize(keptAxes[index]);
|
||||||
|
if (dimSize <= 0)
|
||||||
|
return failure();
|
||||||
|
if (laneCount > std::numeric_limits<int32_t>::max() / dimSize)
|
||||||
|
return failure();
|
||||||
|
laneCount *= dimSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> sliceOffsets;
|
||||||
|
SmallVector<OpFoldResult> sliceSizes;
|
||||||
|
SmallVector<OpFoldResult> insertOffsets;
|
||||||
|
SmallVector<OpFoldResult> insertSizes(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||||
|
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, inputType.getRank());
|
||||||
|
sliceOffsets.reserve(inputType.getRank());
|
||||||
|
sliceSizes.reserve(inputType.getRank());
|
||||||
|
insertOffsets.reserve(inputType.getRank());
|
||||||
|
|
||||||
|
auto batchOp = createSpatComputeBatch(
|
||||||
|
rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
|
size_t keptAxisIndex = 0;
|
||||||
|
sliceOffsets.clear();
|
||||||
|
sliceSizes.clear();
|
||||||
|
insertOffsets.clear();
|
||||||
|
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
|
||||||
|
if (isReduced) {
|
||||||
|
sliceOffsets.push_back(rewriter.getIndexAttr(0));
|
||||||
|
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value axisIndex =
|
||||||
|
computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
|
||||||
|
++keptAxisIndex;
|
||||||
|
sliceOffsets.push_back(axisIndex);
|
||||||
|
sliceSizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
insertOffsets.push_back(args.lane);
|
||||||
|
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
|
||||||
|
|
||||||
|
Value slice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
|
||||||
|
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
|
||||||
|
createParallelInsertSliceIntoBatchOutput(
|
||||||
|
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
|
||||||
|
});
|
||||||
|
if (failed(batchOp))
|
||||||
|
return failure();
|
||||||
|
return (*batchOp).getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
|
||||||
|
RankedTensorType keepdimsType,
|
||||||
|
RankedTensorType compactKeptType,
|
||||||
|
ArrayRef<bool> reducedAxes,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto batchType = cast<RankedTensorType>(batchValue.getType());
|
||||||
|
if (batchType == keepdimsType)
|
||||||
|
return batchValue;
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> collapseToFlat {{}};
|
||||||
|
for (int64_t axis = 0; axis < batchType.getRank(); ++axis)
|
||||||
|
collapseToFlat.front().push_back(axis);
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> expandFlatToCompact(1);
|
||||||
|
for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis)
|
||||||
|
expandFlatToCompact.front().push_back(axis);
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> expandCompactToKeepdims;
|
||||||
|
ReassociationIndices pendingLeadingReducedAxes;
|
||||||
|
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
|
||||||
|
if (isReduced) {
|
||||||
|
if (expandCompactToKeepdims.empty())
|
||||||
|
pendingLeadingReducedAxes.push_back(axis);
|
||||||
|
else
|
||||||
|
expandCompactToKeepdims.back().push_back(axis);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
expandCompactToKeepdims.emplace_back();
|
||||||
|
auto& group = expandCompactToKeepdims.back();
|
||||||
|
group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
|
||||||
|
pendingLeadingReducedAxes.clear();
|
||||||
|
group.push_back(axis);
|
||||||
|
}
|
||||||
|
if (!pendingLeadingReducedAxes.empty())
|
||||||
|
expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
|
||||||
|
|
||||||
|
auto reshapeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
|
||||||
|
auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
|
||||||
|
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
|
||||||
|
Value compact = flat;
|
||||||
|
if (compactKeptType != flatType)
|
||||||
|
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
|
||||||
|
Value keepdims = compact;
|
||||||
|
if (keepdimsType != compactKeptType)
|
||||||
|
keepdims =
|
||||||
|
tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
|
||||||
|
});
|
||||||
|
return reshapeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
|
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
|
||||||
SmallVector<ReassociationIndices> reassociation;
|
SmallVector<ReassociationIndices> reassociation;
|
||||||
ReassociationIndices currentGroup;
|
ReassociationIndices currentGroup;
|
||||||
@@ -72,56 +229,6 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
|
|||||||
return reassociation;
|
return reassociation;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
|
||||||
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
constexpr size_t numInputs = 1;
|
|
||||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
|
|
||||||
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
|
|
||||||
});
|
|
||||||
return computeOp.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
|
||||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
|
||||||
int64_t concatDimSize = 0;
|
|
||||||
for (Value input : inputs)
|
|
||||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
|
||||||
outputShape[axis] = concatDimSize;
|
|
||||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
|
||||||
|
|
||||||
if (llvm::all_of(inputs, isCompileTimeComputable))
|
|
||||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
|
||||||
|
|
||||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
|
||||||
});
|
|
||||||
return concatCompute.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value buildReduceMeanKeepdims(Value input,
|
|
||||||
ArrayRef<bool> reducedAxes,
|
|
||||||
int64_t axis,
|
|
||||||
RankedTensorType leafType,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
|
|
||||||
if (axis == rank)
|
|
||||||
return createAverageCompute(input, leafType, rewriter, loc);
|
|
||||||
|
|
||||||
if (reducedAxes[axis])
|
|
||||||
return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
|
|
||||||
|
|
||||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
|
||||||
SmallVector<Value> reducedSlices;
|
|
||||||
reducedSlices.reserve(slices.size());
|
|
||||||
for (Value slice : slices)
|
|
||||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
|
||||||
|
|
||||||
return concatValues(reducedSlices, axis, rewriter, loc);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||||
RankedTensorType resultType,
|
RankedTensorType resultType,
|
||||||
ArrayRef<bool> reducedAxes,
|
ArrayRef<bool> reducedAxes,
|
||||||
@@ -156,16 +263,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
|||||||
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
|
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
|
||||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
if (inputType.getRank() == 0) {
|
||||||
|
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank());
|
auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank());
|
||||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputType.getRank());
|
if (failed(axes))
|
||||||
|
return failure();
|
||||||
|
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
|
||||||
if (reducedAxes.empty() && inputType.getRank() != 0)
|
if (reducedAxes.empty() && inputType.getRank() != 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = reduceMeanOp.getLoc();
|
Location loc = reduceMeanOp.getLoc();
|
||||||
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
|
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
|
||||||
|
RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes);
|
||||||
|
RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes);
|
||||||
|
int64_t laneCount = 1;
|
||||||
|
for (int64_t dim : compactKeptType.getShape())
|
||||||
|
laneCount *= dim;
|
||||||
|
RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType);
|
||||||
|
|
||||||
|
auto lanePackedKeepdims =
|
||||||
|
buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc);
|
||||||
|
if (failed(lanePackedKeepdims))
|
||||||
|
return failure();
|
||||||
Value reducedKeepdims =
|
Value reducedKeepdims =
|
||||||
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
|
buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
|
||||||
|
|
||||||
if (reduceMeanOp.getKeepdims() != 0) {
|
if (reduceMeanOp.getKeepdims() != 0) {
|
||||||
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
|
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
|
||||||
|
|||||||
@@ -23,28 +23,10 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename ArrayAttrT>
|
|
||||||
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
|
||||||
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename ArrayAttrT>
|
|
||||||
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
|
||||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||||
|
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
|
||||||
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
sizes.reserve(tileType.getRank());
|
|
||||||
for (int64_t dimSize : tileType.getShape())
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
|
||||||
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
|
||||||
|
|
||||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
@@ -197,12 +179,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
const int64_t inputWidth = xType.getDimSize(3);
|
const int64_t inputWidth = xType.getDimSize(3);
|
||||||
const int64_t outputHeight = outType.getDimSize(2);
|
const int64_t outputHeight = outType.getDimSize(2);
|
||||||
const int64_t outputWidth = outType.getDimSize(3);
|
const int64_t outputWidth = outType.getDimSize(3);
|
||||||
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
const int64_t kernelHeight = getI64Attr(kernelAttr, 0);
|
||||||
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
const int64_t kernelWidth = getI64Attr(kernelAttr, 1);
|
||||||
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1);
|
||||||
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1);
|
||||||
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1);
|
||||||
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1);
|
||||||
|
|
||||||
int64_t padTop = 0;
|
int64_t padTop = 0;
|
||||||
int64_t padLeft = 0;
|
int64_t padLeft = 0;
|
||||||
@@ -212,10 +194,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
if (auto padsAttr = poolOp.getPads()) {
|
if (auto padsAttr = poolOp.getPads()) {
|
||||||
if (padsAttr->size() != 4)
|
if (padsAttr->size() != 4)
|
||||||
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||||
padTop = getI64(*padsAttr, 0);
|
padTop = getI64Attr(*padsAttr, 0);
|
||||||
padLeft = getI64(*padsAttr, 1);
|
padLeft = getI64Attr(*padsAttr, 1);
|
||||||
padBottom = getI64(*padsAttr, 2);
|
padBottom = getI64Attr(*padsAttr, 2);
|
||||||
padRight = getI64(*padsAttr, 3);
|
padRight = getI64Attr(*padsAttr, 3);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
StringRef autoPad = poolOp.getAutoPad();
|
StringRef autoPad = poolOp.getAutoPad();
|
||||||
|
|||||||
@@ -13,16 +13,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
|
||||||
|
|
||||||
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
|
||||||
SmallVector<int64_t> permutedShape;
|
|
||||||
permutedShape.reserve(permutation.size());
|
|
||||||
for (int64_t axis : permutation)
|
|
||||||
permutedShape.push_back(shape[axis]);
|
|
||||||
return permutedShape;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value buildLoopSoftmaxSlice(Value input,
|
static Value buildLoopSoftmaxSlice(Value input,
|
||||||
Value accumulator,
|
Value accumulator,
|
||||||
RankedTensorType inputType,
|
RankedTensorType inputType,
|
||||||
@@ -36,7 +26,7 @@ static Value buildLoopSoftmaxSlice(Value input,
|
|||||||
|
|
||||||
SmallVector<OpFoldResult> offsets;
|
SmallVector<OpFoldResult> offsets;
|
||||||
SmallVector<OpFoldResult> sizes;
|
SmallVector<OpFoldResult> sizes;
|
||||||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
|
||||||
offsets.reserve(rank);
|
offsets.reserve(rank);
|
||||||
sizes.reserve(rank);
|
sizes.reserve(rank);
|
||||||
|
|
||||||
@@ -110,44 +100,31 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
if (!inputType || !inputType.hasStaticShape())
|
if (!inputType || !inputType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank());
|
auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank());
|
||||||
if (axis < 0 || axis >= inputType.getRank())
|
if (failed(axis))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value input = adaptor.getInput();
|
Value input = adaptor.getInput();
|
||||||
Value result;
|
Value result;
|
||||||
if (axis == inputType.getRank() - 1) {
|
if (*axis == inputType.getRank() - 1) {
|
||||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
SmallVector<int64_t> permutation;
|
SmallVector<int64_t> permutation;
|
||||||
permutation.reserve(inputType.getRank());
|
permutation.reserve(inputType.getRank());
|
||||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
|
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
|
||||||
if (dim != axis)
|
if (dim != *axis)
|
||||||
permutation.push_back(dim);
|
permutation.push_back(dim);
|
||||||
permutation.push_back(axis);
|
permutation.push_back(*axis);
|
||||||
|
SmallVector<int64_t> inversePermutation = invertPermutation(permutation);
|
||||||
SmallVector<int64_t> inversePermutation(inputType.getRank());
|
|
||||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
|
||||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
|
||||||
|
|
||||||
auto transposedType = RankedTensorType::get(
|
auto transposedType = RankedTensorType::get(
|
||||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||||
auto preTransposeCompute =
|
Value transposedInput =
|
||||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
|
transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
||||||
Value transposed = ONNXTransposeOp::create(
|
|
||||||
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
|
||||||
});
|
|
||||||
Value transposedInput = preTransposeCompute.getResult(0);
|
|
||||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||||
auto postTransposeCompute =
|
result = transposeMaybeInCompute(
|
||||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
||||||
Value transposed = ONNXTransposeOp::create(
|
|
||||||
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
|
||||||
});
|
|
||||||
result = postTransposeCompute.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(softmaxOp, result);
|
rewriter.replaceOp(softmaxOp, result);
|
||||||
|
|||||||
@@ -36,6 +36,14 @@ static bool isDirectConstantValue(Value value) {
|
|||||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct PromotedOperands {
|
||||||
|
SmallVector<bool> promoteInput;
|
||||||
|
SmallVector<Value> newWeights;
|
||||||
|
SmallVector<Value> newInputs;
|
||||||
|
SmallVector<Type> newInputTypes;
|
||||||
|
SmallVector<Location> newInputLocs;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename ComputeOpTy>
|
template <typename ComputeOpTy>
|
||||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
@@ -48,60 +56,91 @@ static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
|
||||||
|
PromotedOperands promoted;
|
||||||
|
promoted.promoteInput.assign(compute.getInputs().size(), false);
|
||||||
|
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
|
||||||
|
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||||
|
promoted.newInputs.reserve(compute.getInputs().size());
|
||||||
|
promoted.newInputTypes.reserve(compute.getInputs().size());
|
||||||
|
promoted.newInputLocs.reserve(compute.getInputs().size());
|
||||||
|
|
||||||
|
bool needsRewrite = false;
|
||||||
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
if (!isWeightLikeComputeOperand(input))
|
||||||
|
goto keep_input;
|
||||||
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||||
|
goto keep_input;
|
||||||
|
promoted.promoteInput[inputIdx] = true;
|
||||||
|
promoted.newWeights.push_back(input);
|
||||||
|
needsRewrite = true;
|
||||||
|
continue;
|
||||||
|
|
||||||
|
keep_input:
|
||||||
|
promoted.newInputs.push_back(input);
|
||||||
|
promoted.newInputTypes.push_back(input.getType());
|
||||||
|
promoted.newInputLocs.push_back(input.getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!needsRewrite)
|
||||||
|
return failure();
|
||||||
|
return promoted;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
|
||||||
|
const PromotedOperands& promoted,
|
||||||
|
IRRewriter& bodyRewriter,
|
||||||
|
IRMapping& mapper,
|
||||||
|
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
|
||||||
|
PatternRewriter& rewriter) {
|
||||||
|
size_t newInputIdx = 0;
|
||||||
|
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||||
|
if (!oldArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
|
||||||
|
if (!promoted.promoteInput[oldInputIdx]) {
|
||||||
|
auto newInputArg = getNewInputArg(newInputIdx++);
|
||||||
|
if (!newInputArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
|
||||||
|
mapper.map(*oldArg, *newInputArg);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||||
|
if (failed(clonedValue))
|
||||||
|
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||||
|
mapper.map(*oldArg, *clonedValue);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
auto promoted = computePromotedOperands(compute);
|
||||||
bool needsRewrite = false;
|
if (failed(promoted))
|
||||||
Block& oldBlock = compute.getBody().front();
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (!isWeightLikeComputeOperand(input))
|
|
||||||
continue;
|
|
||||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
|
||||||
continue;
|
|
||||||
promoteInput[inputIdx] = true;
|
|
||||||
needsRewrite = true;
|
|
||||||
}
|
|
||||||
if (!needsRewrite)
|
|
||||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||||
|
Block& oldBlock = compute.getBody().front();
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute);
|
rewriter.setInsertionPointAfter(compute);
|
||||||
|
|
||||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
|
||||||
SmallVector<Value> newInputs;
|
|
||||||
SmallVector<Type> newInputTypes;
|
|
||||||
SmallVector<Location> newInputLocs;
|
|
||||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
|
||||||
newInputs.reserve(compute.getInputs().size());
|
|
||||||
newInputTypes.reserve(compute.getInputs().size());
|
|
||||||
newInputLocs.reserve(compute.getInputs().size());
|
|
||||||
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (promoteInput[inputIdx]) {
|
|
||||||
newWeights.push_back(input);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
newInputs.push_back(input);
|
|
||||||
newInputTypes.push_back(input.getType());
|
|
||||||
newInputLocs.push_back(input.getLoc());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||||
SmallVector<Type> newBlockArgTypes;
|
SmallVector<Type> newBlockArgTypes;
|
||||||
SmallVector<Location> newBlockArgLocs;
|
SmallVector<Location> newBlockArgLocs;
|
||||||
for (Value weight : newWeights) {
|
for (Value weight : promoted->newWeights) {
|
||||||
newBlockArgTypes.push_back(weight.getType());
|
newBlockArgTypes.push_back(weight.getType());
|
||||||
newBlockArgLocs.push_back(weight.getLoc());
|
newBlockArgLocs.push_back(weight.getLoc());
|
||||||
}
|
}
|
||||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
||||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
||||||
auto* newBlock = rewriter.createBlock(
|
auto* newBlock = rewriter.createBlock(
|
||||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
IRRewriter bodyRewriter(rewriter.getContext());
|
IRRewriter bodyRewriter(rewriter.getContext());
|
||||||
@@ -115,24 +154,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||||
mapper.map(*oldWeightArg, *newWeightArg);
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
}
|
}
|
||||||
size_t newInputIdx = 0;
|
if (failed(mapPromotedInputArguments(
|
||||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
||||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
return failure();
|
||||||
if (!oldArg)
|
|
||||||
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
|
|
||||||
if (!promoteInput[oldInputIdx]) {
|
|
||||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
|
||||||
if (!newInputArg)
|
|
||||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
|
|
||||||
mapper.map(*oldArg, *newInputArg);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
|
||||||
if (failed(clonedValue))
|
|
||||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
|
||||||
mapper.map(*oldArg, *clonedValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (Operation& op : oldBlock.without_terminator())
|
for (Operation& op : oldBlock.without_terminator())
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
@@ -156,63 +180,35 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
auto promoted = computePromotedOperands(compute);
|
||||||
bool needsRewrite = false;
|
if (failed(promoted))
|
||||||
Block& oldBlock = compute.getBody().front();
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (!isWeightLikeComputeOperand(input))
|
|
||||||
continue;
|
|
||||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
|
||||||
continue;
|
|
||||||
promoteInput[inputIdx] = true;
|
|
||||||
needsRewrite = true;
|
|
||||||
}
|
|
||||||
if (!needsRewrite)
|
|
||||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||||
|
Block& oldBlock = compute.getBody().front();
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute);
|
rewriter.setInsertionPointAfter(compute);
|
||||||
|
|
||||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
|
||||||
SmallVector<Value> newInputs;
|
|
||||||
SmallVector<Type> newInputTypes;
|
|
||||||
SmallVector<Location> newInputLocs;
|
|
||||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
|
||||||
newInputs.reserve(compute.getInputs().size());
|
|
||||||
newInputTypes.reserve(compute.getInputs().size());
|
|
||||||
newInputLocs.reserve(compute.getInputs().size());
|
|
||||||
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (promoteInput[inputIdx]) {
|
|
||||||
newWeights.push_back(input);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
newInputs.push_back(input);
|
|
||||||
newInputTypes.push_back(input.getType());
|
|
||||||
newInputLocs.push_back(input.getLoc());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatComputeBatch::create(rewriter,
|
spatial::SpatComputeBatch::create(rewriter,
|
||||||
compute.getLoc(),
|
compute.getLoc(),
|
||||||
compute.getResultTypes(),
|
compute.getResultTypes(),
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||||
newWeights,
|
promoted->newWeights,
|
||||||
newInputs);
|
promoted->newInputs);
|
||||||
auto laneArg = compute.getLaneArgument();
|
auto laneArg = compute.getLaneArgument();
|
||||||
if (!laneArg)
|
if (!laneArg)
|
||||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||||
SmallVector<Type> newBlockArgTypes;
|
SmallVector<Type> newBlockArgTypes;
|
||||||
SmallVector<Location> newBlockArgLocs;
|
SmallVector<Location> newBlockArgLocs;
|
||||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults());
|
||||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
|
||||||
newBlockArgTypes.push_back(laneArg->getType());
|
newBlockArgTypes.push_back(laneArg->getType());
|
||||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||||
for (Value weight : newWeights) {
|
for (Value weight : promoted->newWeights) {
|
||||||
newBlockArgTypes.push_back(weight.getType());
|
newBlockArgTypes.push_back(weight.getType());
|
||||||
newBlockArgLocs.push_back(weight.getLoc());
|
newBlockArgLocs.push_back(weight.getLoc());
|
||||||
}
|
}
|
||||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
||||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
||||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||||
if (!outputArg)
|
if (!outputArg)
|
||||||
@@ -224,7 +220,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
auto* newBlock = rewriter.createBlock(
|
auto* newBlock = rewriter.createBlock(
|
||||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
IRRewriter bodyRewriter(rewriter.getContext());
|
IRRewriter bodyRewriter(rewriter.getContext());
|
||||||
@@ -242,29 +238,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||||
mapper.map(*oldWeightArg, *newWeightArg);
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
}
|
}
|
||||||
size_t newInputIdx = 0;
|
if (failed(mapPromotedInputArguments(
|
||||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
||||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
return failure();
|
||||||
if (!oldArg)
|
|
||||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
|
|
||||||
if (!promoteInput[oldInputIdx]) {
|
|
||||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
|
||||||
if (!newInputArg)
|
|
||||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
|
|
||||||
mapper.map(*oldArg, *newInputArg);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
|
||||||
if (failed(clonedValue))
|
|
||||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
|
||||||
mapper.map(*oldArg, *clonedValue);
|
|
||||||
}
|
|
||||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||||
if (!outputArg)
|
if (!outputArg)
|
||||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||||
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
mapper.map(*outputArg,
|
||||||
|
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (Operation& op : oldBlock)
|
for (Operation& op : oldBlock)
|
||||||
|
|||||||
@@ -15,24 +15,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
|
||||||
|
|
||||||
static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
|
||||||
|
|
||||||
static Value
|
|
||||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
|
||||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
|
||||||
sizes.reserve(inputType.getRank());
|
|
||||||
for (int64_t dim : inputType.getShape())
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
|
||||||
sizes[axis] = rewriter.getIndexAttr(1);
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value concatGatherSlices(Value data,
|
static Value concatGatherSlices(Value data,
|
||||||
int64_t axis,
|
int64_t axis,
|
||||||
ArrayRef<int64_t> indices,
|
ArrayRef<int64_t> indices,
|
||||||
@@ -45,7 +27,7 @@ static Value concatGatherSlices(Value data,
|
|||||||
int64_t normalizedIndex = normalizeIndex(index, axisDim);
|
int64_t normalizedIndex = normalizeIndex(index, axisDim);
|
||||||
if (normalizedIndex < 0 || normalizedIndex >= axisDim)
|
if (normalizedIndex < 0 || normalizedIndex >= axisDim)
|
||||||
return {};
|
return {};
|
||||||
slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc));
|
slices.push_back(extractAxisSlice(rewriter, loc, data, axis, normalizedIndex, /*size=*/1));
|
||||||
}
|
}
|
||||||
if (slices.empty())
|
if (slices.empty())
|
||||||
return {};
|
return {};
|
||||||
@@ -96,11 +78,11 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
int64_t rank = dataType.getRank();
|
int64_t rank = dataType.getRank();
|
||||||
int64_t axis = normalizeAxis(gatherOp.getAxis(), rank);
|
auto axis = normalizeAxisChecked(gatherOp.getAxis(), rank);
|
||||||
if (axis < 0 || axis >= rank)
|
if (failed(axis))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
int64_t axisDim = dataType.getShape()[axis];
|
int64_t axisDim = dataType.getShape()[*axis];
|
||||||
if (axisDim <= 0)
|
if (axisDim <= 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -116,7 +98,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
|||||||
[&](Value data) -> LogicalResult {
|
[&](Value data) -> LogicalResult {
|
||||||
Value result;
|
Value result;
|
||||||
if (indicesType.getRank() == 1) {
|
if (indicesType.getRank() == 1) {
|
||||||
result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc);
|
result = concatGatherSlices(data, *axis, flatIndices, axisDim, rewriter, loc);
|
||||||
}
|
}
|
||||||
else if (indicesType.getRank() == 2) {
|
else if (indicesType.getRank() == 2) {
|
||||||
int64_t rowCount = indicesType.getShape()[0];
|
int64_t rowCount = indicesType.getShape()[0];
|
||||||
@@ -125,12 +107,13 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
|||||||
rows.reserve(rowCount);
|
rows.reserve(rowCount);
|
||||||
for (int64_t row = 0; row < rowCount; ++row) {
|
for (int64_t row = 0; row < rowCount; ++row) {
|
||||||
ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
|
ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
|
||||||
Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc);
|
Value gatheredRow =
|
||||||
|
concatGatherSlices(data, *axis, rowIndices, axisDim, rewriter, loc);
|
||||||
if (!gatheredRow)
|
if (!gatheredRow)
|
||||||
return failure();
|
return failure();
|
||||||
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
rows.push_back(addLeadingGatherDim(gatheredRow, *axis, rewriter, loc));
|
||||||
}
|
}
|
||||||
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
|
result = createSpatConcat(rewriter, loc, /*axis=*/*axis, rows);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -14,10 +14,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
|
||||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
|
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
|
||||||
ArrayRef<int64_t> resultShape,
|
ArrayRef<int64_t> resultShape,
|
||||||
SmallVector<ReassociationIndices>& reassociation) {
|
SmallVector<ReassociationIndices>& reassociation) {
|
||||||
@@ -106,7 +102,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
|
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
|
||||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
|
if (!hasStaticPositiveShape(sourceType) || !hasStaticPositiveShape(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (sourceType == resultType) {
|
if (sourceType == resultType) {
|
||||||
@@ -115,17 +111,8 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||||
if (isCompileTimeComputable(adaptor.getData())) {
|
Value reshaped = materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
|
||||||
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
|
rewriter.replaceOp(reshapeOp, reshaped);
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto computeOp = createSpatCompute<1>(
|
|
||||||
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
|
|
||||||
Value reshaped = buildReshape(data);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
|
|
||||||
});
|
|
||||||
rewriter.replaceOp(reshapeOp, computeOp.getResults());
|
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -12,25 +12,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
|
||||||
|
|
||||||
static Value extractSliceAt(
|
|
||||||
Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
|
||||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
|
||||||
sizes.reserve(inputType.getRank());
|
|
||||||
for (int64_t dim : inputType.getShape())
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
|
||||||
sizes[axis] = rewriter.getIndexAttr(size);
|
|
||||||
SmallVector<int64_t> resultShape(inputType.getShape());
|
|
||||||
resultShape[axis] = size;
|
|
||||||
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Split : OpConversionPattern<ONNXSplitOp> {
|
struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -41,8 +22,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
int64_t rank = inputType.getRank();
|
int64_t rank = inputType.getRank();
|
||||||
int64_t axis = normalizeAxis(splitOp.getAxis(), rank);
|
auto axis = normalizeAxisChecked(splitOp.getAxis(), rank);
|
||||||
if (axis < 0 || axis >= rank)
|
if (failed(axis))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<Value> outputs;
|
SmallVector<Value> outputs;
|
||||||
@@ -58,12 +39,13 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
|||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
resultTypes.push_back(resultType);
|
resultTypes.push_back(resultType);
|
||||||
sliceSizes.push_back(resultType.getShape()[axis]);
|
sliceSizes.push_back(resultType.getShape()[*axis]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isCompileTimeComputable(adaptor.getInput())) {
|
if (isCompileTimeComputable(adaptor.getInput())) {
|
||||||
for (int64_t sliceSize : sliceSizes) {
|
for (int64_t sliceSize : sliceSizes) {
|
||||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
outputs.push_back(
|
||||||
|
extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
|
||||||
offset += sliceSize;
|
offset += sliceSize;
|
||||||
}
|
}
|
||||||
rewriter.replaceOp(splitOp, outputs);
|
rewriter.replaceOp(splitOp, outputs);
|
||||||
@@ -76,7 +58,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
|||||||
runtimeOutputs.reserve(resultTypes.size());
|
runtimeOutputs.reserve(resultTypes.size());
|
||||||
int64_t runtimeOffset = 0;
|
int64_t runtimeOffset = 0;
|
||||||
for (int64_t sliceSize : sliceSizes) {
|
for (int64_t sliceSize : sliceSizes) {
|
||||||
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
|
runtimeOutputs.push_back(
|
||||||
|
extractAxisSlice(rewriter, splitOp.getLoc(), input, *axis, runtimeOffset, sliceSize));
|
||||||
runtimeOffset += sliceSize;
|
runtimeOffset += sliceSize;
|
||||||
}
|
}
|
||||||
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
|
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -29,22 +30,6 @@ static Value createTransposeInit(Value input,
|
|||||||
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
|
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<int64_t> getTransposePermutation(ONNXTransposeOp transposeOp) {
|
|
||||||
auto inputType = cast<RankedTensorType>(transposeOp.getData().getType());
|
|
||||||
SmallVector<int64_t> permutation;
|
|
||||||
if (auto permAttr = transposeOp.getPermAttr()) {
|
|
||||||
permutation.reserve(permAttr.size());
|
|
||||||
for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
|
|
||||||
permutation.push_back(attr.getInt());
|
|
||||||
return permutation;
|
|
||||||
}
|
|
||||||
|
|
||||||
permutation.reserve(inputType.getRank());
|
|
||||||
for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
|
|
||||||
permutation.push_back(dim);
|
|
||||||
return permutation;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -56,10 +41,12 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
|||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<int64_t> permutation = getTransposePermutation(transposeOp);
|
auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank());
|
||||||
Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc());
|
if (failed(permutation))
|
||||||
|
return failure();
|
||||||
|
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||||
Value transposed =
|
Value transposed =
|
||||||
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation)
|
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation)
|
||||||
.getResult()[0];
|
.getResult()[0];
|
||||||
rewriter.replaceOp(transposeOp, transposed);
|
rewriter.replaceOp(transposeOp, transposed);
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||||
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,7 +218,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||||
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantFolder, constantOp));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,8 +230,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
|||||||
PimMemCopyHostToDevOp::create(rewriter,
|
PimMemCopyHostToDevOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputBuffer.getType(),
|
outputBuffer.getType(),
|
||||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
|
||||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
|
||||||
outputBuffer,
|
outputBuffer,
|
||||||
input,
|
input,
|
||||||
getTensorSizeInBytesAttr(rewriter, input))
|
getTensorSizeInBytesAttr(rewriter, input))
|
||||||
|
|||||||
@@ -326,7 +326,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||||
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,8 +370,8 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
|||||||
OperationFolder& constantFolder) {
|
OperationFolder& constantFolder) {
|
||||||
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
|
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
|
||||||
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
|
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
|
||||||
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder);
|
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, hostTargetOffset);
|
||||||
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder);
|
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, deviceSourceOffset);
|
||||||
return PimMemCopyDevToHostOp::create(rewriter,
|
return PimMemCopyDevToHostOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter,
|
|||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||||
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
|
auto zeroIndex = getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
|
||||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||||
|
|
||||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
||||||
@@ -333,9 +333,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
|||||||
rewriter,
|
rewriter,
|
||||||
loc,
|
loc,
|
||||||
tensorType,
|
tensorType,
|
||||||
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
|
getOrCreateHostIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
|
||||||
getOrCreateHostIndexConstant(
|
getOrCreateHostIndexConstant(constantFolder,
|
||||||
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
|
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize) ),
|
||||||
deviceTensor,
|
deviceTensor,
|
||||||
inputTensor,
|
inputTensor,
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||||
|
|||||||
@@ -619,7 +619,7 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) {
|
Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) {
|
||||||
return getOrCreateHostIndexConstant(anchor, value, state.constantFolder);
|
return getOrCreateHostIndexConstant(state.constantFolder, anchor, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
@@ -939,7 +939,7 @@ Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, Arr
|
|||||||
|
|
||||||
auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType());
|
auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType());
|
||||||
auto attr = DenseIntElementsAttr::get(type, elements);
|
auto attr = DenseIntElementsAttr::get(type, elements);
|
||||||
return getOrCreateHostConstant(anchor, attr, type, state.constantFolder);
|
return getOrCreateHostConstant(state.constantFolder, anchor, attr, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool allEqual(ArrayRef<int64_t> values) {
|
bool allEqual(ArrayRef<int64_t> values) {
|
||||||
|
|||||||
@@ -113,8 +113,8 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
|||||||
rewriter,
|
rewriter,
|
||||||
op->getLoc(),
|
op->getLoc(),
|
||||||
originalType,
|
originalType,
|
||||||
getOrCreateHostIndexConstant(op, 0, constantFolder),
|
getOrCreateHostIndexConstant(constantFolder, op, 0),
|
||||||
getOrCreateHostIndexConstant(op, static_cast<int64_t>(resolvedAddress->byteOffset), constantFolder),
|
getOrCreateHostIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset) ),
|
||||||
deviceDst,
|
deviceDst,
|
||||||
getGlobalOp.getResult(),
|
getGlobalOp.getResult(),
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||||
|
|||||||
Reference in New Issue
Block a user