Refactor ONNXToSpatial Common and diagnostics
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
@@ -0,0 +1,39 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||
if (tensors.size() == 1)
|
||||
return tensors[0];
|
||||
|
||||
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
||||
SmallVector<Value> tensors2;
|
||||
tensors2.reserve(tensors.size() / 2);
|
||||
|
||||
auto* currTensors = &tensors1;
|
||||
auto* nextTensors = &tensors2;
|
||||
while (currTensors->size() > 1) {
|
||||
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
||||
Value a = (*currTensors)[i];
|
||||
Value b = (*currTensors)[i + 1];
|
||||
rewriter.setInsertionPointAfterValue(b);
|
||||
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||
nextTensors->push_back(addedValue);
|
||||
}
|
||||
if (currTensors->size() % 2 == 1)
|
||||
nextTensors->push_back(currTensors->back());
|
||||
std::swap(currTensors, nextTensors);
|
||||
nextTensors->clear();
|
||||
}
|
||||
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
||||
return (*currTensors)[0];
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,153 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
template <size_t>
|
||||
using ValueArg = mlir::Value;
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
struct InvokeWithBlockArgsResult;
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||
};
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||
|
||||
template <typename Fn>
|
||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename RewriterT>
|
||||
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
||||
assert(!inputs.empty() && "spat.concat requires at least one input");
|
||||
if (inputs.size() == 1)
|
||||
return inputs.front();
|
||||
|
||||
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
||||
auto outputShape = llvm::to_vector(firstType.getShape());
|
||||
int64_t concatDimSize = 0;
|
||||
bool concatDimDynamic = false;
|
||||
|
||||
for (mlir::Value input : inputs) {
|
||||
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
||||
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
||||
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
||||
concatDimDynamic = true;
|
||||
else
|
||||
concatDimSize += inputType.getDimSize(axis);
|
||||
}
|
||||
|
||||
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
||||
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||
}
|
||||
|
||||
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
||||
/// the body callback reports failure.
|
||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
||||
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,96 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
assert("Invalid axis" && axis < shape.size());
|
||||
|
||||
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (const auto size : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(size));
|
||||
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
||||
|
||||
long length = shape[axis];
|
||||
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(numSlices);
|
||||
|
||||
for (int64_t i = 0; i < numSlices; i++) {
|
||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
slices.push_back(slice);
|
||||
}
|
||||
|
||||
return slices;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
||||
assert("Not a vector" && isVectorShape(shape));
|
||||
size_t axis = shape[0] != 1 ? 0 : 1;
|
||||
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
|
||||
}
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>>
|
||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
||||
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
||||
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
||||
size_t coreId = sliceId / crossbarCountInCore;
|
||||
slicesPerCore[coreId].push_back(slices[sliceId]);
|
||||
}
|
||||
return slicesPerCore;
|
||||
}
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
|
||||
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
|
||||
|
||||
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
|
||||
size_t numHSlices = hSlices.size();
|
||||
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
|
||||
Value hSlice = hSlices[hSliceId];
|
||||
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
|
||||
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
|
||||
size_t coreId = vSliceId / crossbarCountInCore;
|
||||
Value vSlice = vSlices[vSliceId];
|
||||
tiles[hSliceId][coreId].push_back(vSlice);
|
||||
}
|
||||
}
|
||||
return tiles;
|
||||
}
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(1);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageN(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return 1 + (ac - 1) / bc;
|
||||
}
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[1] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||
assert(isVectorShape(shape));
|
||||
return shape[0] != 1 ? shape[0] : shape[1];
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
/// at most `sliceSize` elements.
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
size_t axis,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||
/// current PIM target geometry.
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
/// Tiles a matrix first across output columns and then across input rows so it
|
||||
/// can be assigned to crossbars grouped by core.
|
||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||
tileMatrix(mlir::Value& matrixToTile,
|
||||
int64_t hSliceSize,
|
||||
int64_t vSliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,114 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isWeightLikeComputeOperand(Value value) {
|
||||
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (hasWeightAlways(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
value = transposeOp.getData();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||
if (auto mapped = mapper.lookupOrNull(value))
|
||||
return cast<Value>(mapped);
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!tensorType || !tensorType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(tensorType.getRank());
|
||||
for (int64_t dim : tensorType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
|
||||
auto referencedValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
||||
mapper.map(value, referencedValue.getResult());
|
||||
return referencedValue.getResult();
|
||||
}
|
||||
|
||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||
return failure();
|
||||
|
||||
IRMapping localMapper;
|
||||
for (Value operand : definingOp->getOperands()) {
|
||||
if (auto mapped = mapper.lookupOrNull(operand)) {
|
||||
localMapper.map(operand, cast<Value>(mapped));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isWeightLikeComputeOperand(operand)) {
|
||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||
if (failed(clonedOperand))
|
||||
return failure();
|
||||
localMapper.map(operand, *clonedOperand);
|
||||
continue;
|
||||
}
|
||||
|
||||
localMapper.map(operand, operand);
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
|
||||
auto mapped = mapper.lookupOrNull(value);
|
||||
if (!mapped)
|
||||
return failure();
|
||||
return cast<Value>(mapped);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns true when a matrix-valued compute operand is ultimately backed by a
|
||||
/// weight-marked constant/view chain and can be promoted into weights.
|
||||
bool isWeightLikeComputeOperand(mlir::Value value);
|
||||
|
||||
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
|
||||
/// compute body while reusing already-materialized intermediate values.
|
||||
llvm::FailureOr<mlir::Value>
|
||||
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user