Refactor ONNXToSpatial Common and diagnostics
This commit is contained in:
@@ -21,7 +21,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Tensor/Reshape.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
Patterns/Tensor/Split.cpp
|
Patterns/Tensor/Split.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
Common.cpp
|
Common/ComputeRegionBuilder.cpp
|
||||||
|
Common/ShapeTilingUtils.cpp
|
||||||
|
Common/WeightMaterialization.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -1,305 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageN(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
using HSliceId = size_t;
|
|
||||||
using CoreId = size_t;
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr C ceilIntegerDivide(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return 1 + (ac - 1) / bc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[0] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[1] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
|
||||||
assert(isVectorShape(shape));
|
|
||||||
return shape[0] != 1 ? shape[0] : shape[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
inline auto getTensorShape(mlir::Value tensor) {
|
|
||||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool isWeightLikeComputeOperand(mlir::Value value) {
|
|
||||||
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
|
|
||||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
|
|
||||||
|
|
||||||
while (auto* definingOp = value.getDefiningOp()) {
|
|
||||||
if (!visited.insert(definingOp).second)
|
|
||||||
return false;
|
|
||||||
if (hasWeightAlways(definingOp))
|
|
||||||
return true;
|
|
||||||
|
|
||||||
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
|
|
||||||
value = extractSliceOp.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = expandShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = collapseShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
|
|
||||||
value = transposeOp.getData();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
|
||||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
|
||||||
return std::forward<Fn>(fn)(values[Is]...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t>
|
|
||||||
using ValueArg = mlir::Value;
|
|
||||||
|
|
||||||
template <typename Fn, typename Seq>
|
|
||||||
struct InvokeWithBlockArgsResult;
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
|
||||||
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Fn, typename Seq>
|
|
||||||
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
|
||||||
|
|
||||||
template <typename Fn>
|
|
||||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
template <typename RewriterT>
|
|
||||||
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
|
||||||
assert(!inputs.empty() && "spat.concat requires at least one input");
|
|
||||||
if (inputs.size() == 1)
|
|
||||||
return inputs.front();
|
|
||||||
|
|
||||||
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
|
||||||
auto outputShape = llvm::to_vector(firstType.getShape());
|
|
||||||
int64_t concatDimSize = 0;
|
|
||||||
bool concatDimDynamic = false;
|
|
||||||
|
|
||||||
for (mlir::Value input : inputs) {
|
|
||||||
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
|
||||||
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
|
||||||
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
|
||||||
concatDimDynamic = true;
|
|
||||||
else
|
|
||||||
concatDimSize += inputType.getDimSize(axis);
|
|
||||||
}
|
|
||||||
|
|
||||||
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
|
||||||
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
|
||||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::TypeRange resultTypes,
|
|
||||||
mlir::ValueRange weights,
|
|
||||||
mlir::ValueRange inputs,
|
|
||||||
BodyFn&& body) {
|
|
||||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
|
||||||
for (mlir::Value input : inputs)
|
|
||||||
block->addArgument(input.getType(), loc);
|
|
||||||
|
|
||||||
computeOp.getBody().push_back(block);
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
|
|
||||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
|
||||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
|
||||||
auto bodyResult =
|
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
|
||||||
if (mlir::failed(bodyResult)) {
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
rewriter.eraseOp(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
|
||||||
}
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return computeOp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename RewriterT, typename BodyFn>
|
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::TypeRange resultTypes,
|
|
||||||
mlir::ValueRange weights,
|
|
||||||
mlir::ValueRange inputs,
|
|
||||||
BodyFn&& body) {
|
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
|
||||||
for (mlir::Value input : inputs)
|
|
||||||
block->addArgument(input.getType(), loc);
|
|
||||||
|
|
||||||
computeOp.getBody().push_back(block);
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
|
|
||||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
|
||||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
|
||||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
|
||||||
if (mlir::failed(bodyResult)) {
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
rewriter.eraseOp(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
|
||||||
}
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
|
||||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return computeOp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
|
||||||
size_t axis,
|
|
||||||
int64_t sliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
|
||||||
int64_t sliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
|
||||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
|
||||||
tileMatrix(mlir::Value& matrixToTile,
|
|
||||||
int64_t hSliceSize,
|
|
||||||
int64_t vSliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location& loc);
|
|
||||||
|
|
||||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
|
||||||
int64_t length,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
8
src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp
Normal file
8
src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
#include "ComputeRegionBuilder.hpp"
|
||||||
|
#include "ShapeTilingUtils.hpp"
|
||||||
|
#include "WeightMaterialization.hpp"
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "ComputeRegionBuilder.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||||
|
if (tensors.size() == 1)
|
||||||
|
return tensors[0];
|
||||||
|
|
||||||
|
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
||||||
|
SmallVector<Value> tensors2;
|
||||||
|
tensors2.reserve(tensors.size() / 2);
|
||||||
|
|
||||||
|
auto* currTensors = &tensors1;
|
||||||
|
auto* nextTensors = &tensors2;
|
||||||
|
while (currTensors->size() > 1) {
|
||||||
|
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
||||||
|
Value a = (*currTensors)[i];
|
||||||
|
Value b = (*currTensors)[i + 1];
|
||||||
|
rewriter.setInsertionPointAfterValue(b);
|
||||||
|
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||||
|
nextTensors->push_back(addedValue);
|
||||||
|
}
|
||||||
|
if (currTensors->size() % 2 == 1)
|
||||||
|
nextTensors->push_back(currTensors->back());
|
||||||
|
std::swap(currTensors, nextTensors);
|
||||||
|
nextTensors->clear();
|
||||||
|
}
|
||||||
|
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
||||||
|
return (*currTensors)[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
153
src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp
Normal file
153
src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||||
|
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||||
|
return std::forward<Fn>(fn)(values[Is]...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t>
|
||||||
|
using ValueArg = mlir::Value;
|
||||||
|
|
||||||
|
template <typename Fn, typename Seq>
|
||||||
|
struct InvokeWithBlockArgsResult;
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||||
|
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Fn, typename Seq>
|
||||||
|
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||||
|
|
||||||
|
template <typename Fn>
|
||||||
|
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <typename RewriterT>
|
||||||
|
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
||||||
|
assert(!inputs.empty() && "spat.concat requires at least one input");
|
||||||
|
if (inputs.size() == 1)
|
||||||
|
return inputs.front();
|
||||||
|
|
||||||
|
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
||||||
|
auto outputShape = llvm::to_vector(firstType.getShape());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
bool concatDimDynamic = false;
|
||||||
|
|
||||||
|
for (mlir::Value input : inputs) {
|
||||||
|
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
||||||
|
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
||||||
|
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
||||||
|
concatDimDynamic = true;
|
||||||
|
else
|
||||||
|
concatDimSize += inputType.getDimSize(axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
||||||
|
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
||||||
|
/// the body callback reports failure.
|
||||||
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto bodyResult =
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
||||||
|
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
|
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,24 +1,10 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Location.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
#include "ShapeTilingUtils.hpp"
|
||||||
#include <optional>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "Common.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -107,31 +93,4 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
|
|||||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
} // namespace onnx_mlir
|
||||||
if (tensors.size() == 1)
|
|
||||||
return tensors[0];
|
|
||||||
|
|
||||||
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
|
||||||
SmallVector<Value> tensors2;
|
|
||||||
tensors2.reserve(tensors.size() / 2);
|
|
||||||
|
|
||||||
auto* currTensors = &tensors1;
|
|
||||||
auto* nextTensors = &tensors2;
|
|
||||||
while (currTensors->size() > 1) {
|
|
||||||
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
|
||||||
Value a = (*currTensors)[i];
|
|
||||||
Value b = (*currTensors)[i + 1];
|
|
||||||
rewriter.setInsertionPointAfterValue(b);
|
|
||||||
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
|
||||||
nextTensors->push_back(addedValue);
|
|
||||||
}
|
|
||||||
if (currTensors->size() % 2 == 1)
|
|
||||||
nextTensors->push_back(currTensors->back());
|
|
||||||
std::swap(currTensors, nextTensors);
|
|
||||||
nextTensors->clear();
|
|
||||||
}
|
|
||||||
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
|
||||||
return (*currTensors)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
143
src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp
Normal file
143
src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageN(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
using HSliceId = size_t;
|
||||||
|
using CoreId = size_t;
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr C ceilIntegerDivide(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return 1 + (ac - 1) / bc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[0] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[1] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||||
|
assert(isVectorShape(shape));
|
||||||
|
return shape[0] != 1 ? shape[0] : shape[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
inline auto getTensorShape(mlir::Value tensor) {
|
||||||
|
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||||
|
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||||
|
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||||
|
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||||
|
/// at most `sliceSize` elements.
|
||||||
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
|
size_t axis,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||||
|
/// current PIM target geometry.
|
||||||
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
|
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
|
/// Tiles a matrix first across output columns and then across input rows so it
|
||||||
|
/// can be assigned to crossbars grouped by core.
|
||||||
|
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||||
|
tileMatrix(mlir::Value& matrixToTile,
|
||||||
|
int64_t hSliceSize,
|
||||||
|
int64_t vSliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location& loc);
|
||||||
|
|
||||||
|
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||||
|
int64_t length,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
|
#include "WeightMaterialization.hpp"
|
||||||
|
#include "ShapeTilingUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool isWeightLikeComputeOperand(Value value) {
|
||||||
|
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
|
||||||
|
while (auto* definingOp = value.getDefiningOp()) {
|
||||||
|
if (!visited.insert(definingOp).second)
|
||||||
|
return false;
|
||||||
|
if (hasWeightAlways(definingOp))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
value = extractSliceOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||||
|
value = transposeOp.getData();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||||
|
if (auto mapped = mapper.lookupOrNull(value))
|
||||||
|
return cast<Value>(mapped);
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
||||||
|
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!tensorType || !tensorType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
||||||
|
sizes.reserve(tensorType.getRank());
|
||||||
|
for (int64_t dim : tensorType.getShape())
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
|
|
||||||
|
auto referencedValue =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
||||||
|
mapper.map(value, referencedValue.getResult());
|
||||||
|
return referencedValue.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
IRMapping localMapper;
|
||||||
|
for (Value operand : definingOp->getOperands()) {
|
||||||
|
if (auto mapped = mapper.lookupOrNull(operand)) {
|
||||||
|
localMapper.map(operand, cast<Value>(mapped));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isWeightLikeComputeOperand(operand)) {
|
||||||
|
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||||
|
if (failed(clonedOperand))
|
||||||
|
return failure();
|
||||||
|
localMapper.map(operand, *clonedOperand);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
localMapper.map(operand, operand);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
||||||
|
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||||
|
mapper.map(oldResult, newResult);
|
||||||
|
|
||||||
|
auto mapped = mapper.lookupOrNull(value);
|
||||||
|
if (!mapped)
|
||||||
|
return failure();
|
||||||
|
return cast<Value>(mapped);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Returns true when a matrix-valued compute operand is ultimately backed by a
|
||||||
|
/// weight-marked constant/view chain and can be promoted into weights.
|
||||||
|
bool isWeightLikeComputeOperand(mlir::Value value);
|
||||||
|
|
||||||
|
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
|
||||||
|
/// compute body while reusing already-materialized intermediate values.
|
||||||
|
llvm::FailureOr<mlir::Value>
|
||||||
|
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -12,14 +12,13 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common/Common.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
@@ -32,8 +31,6 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
bool haveSameStaticShape(Value lhs, Value rhs);
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||||
@@ -50,7 +47,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -186,7 +183,10 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
@@ -287,64 +287,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
|
||||||
if (auto mapped = mapper.lookupOrNull(value))
|
|
||||||
return cast<Value>(mapped);
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
|
||||||
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
|
||||||
if (!tensorType || !tensorType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
|
||||||
sizes.reserve(tensorType.getRank());
|
|
||||||
for (int64_t dim : tensorType.getShape())
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
|
||||||
|
|
||||||
auto referencedValue =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
|
||||||
mapper.map(value, referencedValue.getResult());
|
|
||||||
return referencedValue.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
IRMapping localMapper;
|
|
||||||
for (Value operand : definingOp->getOperands()) {
|
|
||||||
if (auto mapped = mapper.lookupOrNull(operand)) {
|
|
||||||
localMapper.map(operand, cast<Value>(mapped));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isWeightLikeComputeOperand(operand)) {
|
|
||||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
|
||||||
if (failed(clonedOperand))
|
|
||||||
return failure();
|
|
||||||
localMapper.map(operand, *clonedOperand);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
localMapper.map(operand, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
|
||||||
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
|
||||||
mapper.map(oldResult, newResult);
|
|
||||||
|
|
||||||
auto mapped = mapper.lookupOrNull(value);
|
|
||||||
if (!mapped)
|
|
||||||
return failure();
|
|
||||||
return cast<Value>(mapped);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool sourceOpernadHasWeightAlways(Operation* op) {
|
|
||||||
if (op == nullptr)
|
if (op == nullptr)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@@ -416,30 +359,32 @@ bool sourceOpernadHasWeightAlways(Operation* op) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
op->dump();
|
op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
|
||||||
llvm_unreachable("Global instruction not handle in func");
|
return failure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
while (source == nullptr);
|
while (source == nullptr);
|
||||||
|
|
||||||
if (hasWeightAlways(source))
|
return hasWeightAlways(source);
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO what we want to keep in global?
|
// TODO what we want to keep in global?
|
||||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
bool keep = true;
|
bool keep = true;
|
||||||
while (keep) {
|
while (keep) {
|
||||||
keep = false;
|
keep = false;
|
||||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||||
|
|
||||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
||||||
instruction)
|
instruction)
|
||||||
|| isa<func::ReturnOp>(instruction)
|
|| isa<func::ReturnOp>(instruction))
|
||||||
|| sourceOpernadHasWeightAlways(&instruction))
|
continue;
|
||||||
|
|
||||||
|
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
|
||||||
|
if (failed(weightBacked))
|
||||||
|
return failure();
|
||||||
|
if (*weightBacked)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
||||||
@@ -456,6 +401,7 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
|
|||||||
@@ -7,11 +7,10 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -370,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
auto wType = cast<RankedTensorType>(w.getType());
|
auto wType = cast<RankedTensorType>(w.getType());
|
||||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||||
|
|
||||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
if (!xType.hasStaticShape()) {
|
||||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||||
|
return failure();
|
||||||
// We need to understand what is group
|
}
|
||||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
if (!wType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (xType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (wType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (outType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (convOp.getGroup() != 1) {
|
||||||
|
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t batchSize = xType.getDimSize(0);
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
@@ -391,6 +413,19 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
const auto dilationsAttr = convOp.getDilations();
|
const auto dilationsAttr = convOp.getDilations();
|
||||||
const auto padsAttr = convOp.getPads();
|
const auto padsAttr = convOp.getPads();
|
||||||
|
|
||||||
|
if (stridesAttr && stridesAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (padsAttr && padsAttr->size() != 4) {
|
||||||
|
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||||
@@ -431,6 +466,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
padWidthBegin = totalPadW - padWidthEnd;
|
padWidthBegin = totalPadW - padWidthEnd;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||||
|
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
// "NOTSET" or "VALID" -> all pads stay 0
|
// "NOTSET" or "VALID" -> all pads stay 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,8 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -15,13 +16,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> strides(shape.size(), 1);
|
|
||||||
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
|
|
||||||
strides[i] = strides[i + 1] * shape[i + 1];
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||||
|
|||||||
@@ -8,10 +8,9 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -136,13 +135,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
Value b = gemmOpAdaptor.getB();
|
Value b = gemmOpAdaptor.getB();
|
||||||
Value c = gemmOpAdaptor.getC();
|
Value c = gemmOpAdaptor.getC();
|
||||||
|
|
||||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
if (gemmOpAdaptor.getTransA()) {
|
||||||
|
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
|
|
||||||
auto aType = cast<RankedTensorType>(a.getType());
|
auto aType = cast<RankedTensorType>(a.getType());
|
||||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
if (!aType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t numOutRows = aType.getDimSize(0);
|
const int64_t numOutRows = aType.getDimSize(0);
|
||||||
|
|
||||||
@@ -175,7 +184,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
});
|
});
|
||||||
cType = expandedType;
|
cType = expandedType;
|
||||||
}
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,8 +215,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||||
}
|
}
|
||||||
else
|
else if (!isVectorShape(getTensorShape(c))) {
|
||||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||||
@@ -258,11 +276,28 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
});
|
});
|
||||||
cType = expandedType;
|
cType = expandedType;
|
||||||
}
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
if (!aType.hasStaticShape()) {
|
||||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!bType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
||||||
// Not a gemv
|
// Not a gemv
|
||||||
@@ -341,19 +376,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||||
|
|
||||||
auto computeOp = createSpatCompute(
|
auto computeOp = createSpatCompute(
|
||||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
|
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
||||||
SmallVector<Value> vmmOutputs;
|
SmallVector<Value> vmmOutputs;
|
||||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||||
vmmOutputs.push_back(
|
vmmOutputs.push_back(
|
||||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
if (vmmOutputs.empty()) {
|
||||||
|
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||||
|
return success();
|
||||||
});
|
});
|
||||||
|
if (failed(computeOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
partialResults.push_back(computeOp.getResult(0));
|
partialResults.push_back(computeOp->getResult(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
@@ -388,14 +429,28 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
Value b = gemmOpAdaptor.getB();
|
Value b = gemmOpAdaptor.getB();
|
||||||
Value c = gemmOpAdaptor.getC();
|
Value c = gemmOpAdaptor.getC();
|
||||||
|
|
||||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
if (gemmOpAdaptor.getTransA()) {
|
||||||
|
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
|
|
||||||
auto aType = cast<RankedTensorType>(a.getType());
|
auto aType = cast<RankedTensorType>(a.getType());
|
||||||
auto bType = cast<RankedTensorType>(b.getType());
|
auto bType = cast<RankedTensorType>(b.getType());
|
||||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
|
if (!aType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!bType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t numOutRows = aType.getDimSize(0);
|
const int64_t numOutRows = aType.getDimSize(0);
|
||||||
if (numOutRows <= 1)
|
if (numOutRows <= 1)
|
||||||
@@ -438,7 +493,14 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
});
|
});
|
||||||
cType = cast<RankedTensorType>(c.getType());
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
}
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
|
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
|
||||||
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
|
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|||||||
@@ -6,13 +6,12 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -31,8 +30,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
|||||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
template <typename PoolOp>
|
||||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
static FailureOr<Value>
|
||||||
|
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||||
|
if (values.empty()) {
|
||||||
|
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
return createSpatConcat(rewriter, loc, axis, values);
|
return createSpatConcat(rewriter, loc, axis, values);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename ReduceOp>
|
template <typename ReduceOp>
|
||||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
static FailureOr<Value>
|
||||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
|
||||||
|
if (windowValues.empty()) {
|
||||||
|
op->emitOpError("pool window resolved to zero valid elements");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Value reduced = windowValues.front();
|
Value reduced = windowValues.front();
|
||||||
for (Value value : windowValues.drop_front())
|
for (Value value : windowValues.drop_front())
|
||||||
@@ -60,9 +68,12 @@ static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location lo
|
|||||||
return reduced;
|
return reduced;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static FailureOr<Value>
|
||||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
if (divisor <= 0) {
|
||||||
|
op->emitOpError("AveragePool divisor must be positive");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
if (divisor == 1)
|
if (divisor == 1)
|
||||||
return reducedWindow;
|
return reducedWindow;
|
||||||
|
|
||||||
@@ -70,7 +81,7 @@ scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value redu
|
|||||||
double scale = 1.0 / static_cast<double>(divisor);
|
double scale = 1.0 / static_cast<double>(divisor);
|
||||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PoolOp>
|
template <typename PoolOp>
|
||||||
@@ -209,28 +220,45 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
if (windowValues.empty())
|
if (windowValues.empty())
|
||||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||||
|
|
||||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
|
||||||
|
if (failed(reducedWindow))
|
||||||
|
return failure();
|
||||||
|
Value reducedWindowValue = *reducedWindow;
|
||||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||||
const int64_t divisor =
|
const int64_t divisor =
|
||||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
|
||||||
|
if (failed(scaledWindow))
|
||||||
|
return failure();
|
||||||
|
reducedWindowValue = *scaledWindow;
|
||||||
}
|
}
|
||||||
|
|
||||||
outputChannelTiles.push_back(reducedWindow);
|
outputChannelTiles.push_back(reducedWindowValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
|
||||||
|
if (failed(rowPixel))
|
||||||
|
return failure();
|
||||||
|
rowPixels.push_back(*rowPixel);
|
||||||
}
|
}
|
||||||
|
|
||||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
|
||||||
|
if (failed(row))
|
||||||
|
return failure();
|
||||||
|
rows.push_back(*row);
|
||||||
}
|
}
|
||||||
|
|
||||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
|
||||||
|
if (failed(batchResult))
|
||||||
|
return failure();
|
||||||
|
batchResults.push_back(*batchResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
if (failed(pooledOutput))
|
||||||
|
return failure();
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
if (failed(computeOp))
|
if (failed(computeOp))
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Patterns.hpp"
|
#include "Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
|||||||
@@ -22,7 +22,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user