Compare commits

2 Commits

Author SHA1 Message Date
NiccoloN 5b9bb0c191 refactor spatial ops
Validate Operations / validate-operations (push) Successful in 24m55s
2026-05-04 14:19:30 +02:00
NiccoloN f789954ad7 Refactor ONNXToSpatial Common and diagnostics 2026-05-04 13:42:43 +02:00
30 changed files with 2068 additions and 1875 deletions
+1 -1
View File
@@ -21,7 +21,7 @@
#include <utility> #include <utility>
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -18,7 +18,9 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Reshape.cpp Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp Patterns/Tensor/Split.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
Common.cpp Common/ComputeRegionBuilder.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
-305
View File
@@ -1,305 +0,0 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <type_traits>
#include <utility>
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool isWeightLikeComputeOperand(mlir::Value value) {
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <typename RewriterT>
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
assert(!inputs.empty() && "spat.concat requires at least one input");
if (inputs.size() == 1)
return inputs.front();
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
auto outputShape = llvm::to_vector(firstType.getShape());
int64_t concatDimSize = 0;
bool concatDimDynamic = false;
for (mlir::Value input : inputs) {
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
concatDimDynamic = true;
else
concatDimSize += inputType.getDimSize(axis);
}
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
}; // namespace onnx_mlir
@@ -0,0 +1,8 @@
#pragma once
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "ComputeRegionBuilder.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
@@ -0,0 +1,39 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
} // namespace onnx_mlir
@@ -0,0 +1,153 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <typename RewriterT>
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
assert(!inputs.empty() && "spat.concat requires at least one input");
if (inputs.size() == 1)
return inputs.front();
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
auto outputShape = llvm::to_vector(firstType.getShape());
int64_t concatDimSize = 0;
bool concatDimDynamic = false;
for (mlir::Value input : inputs) {
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
concatDimDynamic = true;
else
concatDimSize += inputType.getDimSize(axis);
}
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,24 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include <cassert> #include "ShapeTilingUtils.hpp"
#include <optional>
#include <utility>
#include "Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -107,31 +93,4 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
return tensor::SplatOp::create(rewriter, loc, type, elementValue); return tensor::SplatOp::create(rewriter, loc, type, elementValue);
} }
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) { } // namespace onnx_mlir
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
}; // namespace onnx_mlir
@@ -0,0 +1,143 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape();
}
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it
/// can be assigned to crossbars grouped by core.
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -0,0 +1,114 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "WeightMaterialization.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
/// Returns true when a matrix-valued compute operand is ultimately backed by a
/// weight-marked constant/view chain and can be promoted into weights.
bool isWeightLikeComputeOperand(mlir::Value value);
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
/// compute body while reusing already-materialized intermediate values.
llvm::FailureOr<mlir::Value>
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
} // namespace onnx_mlir
@@ -12,14 +12,13 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <fstream> #include <fstream>
#include <iterator> #include <iterator>
#include <utility> #include <utility>
#include "Common.hpp" #include "Common/Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -32,8 +31,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace { namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
@@ -50,7 +47,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private: private:
void annotateWeightsConstants(func::FuncOp funcOp) const; void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp); LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp); LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
}; };
@@ -186,7 +183,10 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc); if (failed(encapsulateGlobalInstruction(*entryFunc))) {
signalPassFailure();
return;
}
if (failed(promoteConstantInputsToWeights(*entryFunc))) { if (failed(promoteConstantInputsToWeights(*entryFunc))) {
signalPassFailure(); signalPassFailure();
@@ -287,64 +287,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
return false; return false;
} }
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) { static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
bool sourceOpernadHasWeightAlways(Operation* op) {
if (op == nullptr) if (op == nullptr)
return false; return false;
@@ -416,30 +359,32 @@ bool sourceOpernadHasWeightAlways(Operation* op) {
return res; return res;
} }
else { else {
op->dump(); op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
llvm_unreachable("Global instruction not handle in func"); return failure();
} }
} }
while (source == nullptr); while (source == nullptr);
if (hasWeightAlways(source)) return hasWeightAlways(source);
return true;
return false;
} }
// TODO what we want to keep in global? // TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
bool keep = true; bool keep = true;
while (keep) { while (keep) {
keep = false; keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>( if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
instruction) instruction)
|| isa<func::ReturnOp>(instruction) || isa<func::ReturnOp>(instruction))
|| sourceOpernadHasWeightAlways(&instruction)) continue;
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
if (failed(weightBacked))
return failure();
if (*weightBacked)
continue; continue;
keep |= encapsulateSlice(rewriter, loc, &instruction); keep |= encapsulateSlice(rewriter, loc, &instruction);
@@ -456,6 +401,7 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
keep |= encapsulateConcat(rewriter, loc, &instruction); keep |= encapsulateConcat(rewriter, loc, &instruction);
} }
} }
return success();
} }
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
@@ -7,11 +7,10 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -370,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
auto wType = cast<RankedTensorType>(w.getType()); auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType()); auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape()); if (!xType.hasStaticShape()) {
assert("Only support 2D convolution" && xType.getRank() == 4); pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
// We need to understand what is group }
assert("Only support group=1" && convOp.getGroup() == 1); if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0); const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1); const int64_t numChannelsIn = xType.getDimSize(1);
@@ -391,6 +413,19 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const auto dilationsAttr = convOp.getDilations(); const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads(); const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
@@ -431,6 +466,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
padWidthBegin = totalPadW - padWidthEnd; padWidthBegin = totalPadW - padWidthEnd;
} }
} }
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0 // "NOTSET" or "VALID" -> all pads stay 0
} }
@@ -5,7 +5,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,13 +16,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
strides[i] = strides[i + 1] * shape[i + 1];
return strides;
}
static DenseElementsAttr getDenseConstantAttr(Value value) { static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>()) if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue()); return dyn_cast<DenseElementsAttr>(constantOp.getValue());
@@ -8,10 +8,9 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -136,13 +135,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB(); Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC(); Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA()); if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp()); bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType()); auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType()); auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape()); if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0); const int64_t numOutRows = aType.getDimSize(0);
@@ -175,7 +184,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
}); });
cType = expandedType; cType = expandedType;
} }
assert("Only support rank 2 tensor for C" && cType.getRank() == 2); if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
cHasNumOutRows = cType.getDimSize(0) == numOutRows; cHasNumOutRows = cType.getDimSize(0) == numOutRows;
} }
@@ -199,8 +215,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult(); cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
} }
else else if (!isVectorShape(getTensorShape(c))) {
assert("C should be a vector" && isVectorShape(getTensorShape(c))); gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
return failure();
}
} }
auto gemvOp = ONNXGemmOp::create(rewriter, auto gemvOp = ONNXGemmOp::create(rewriter,
@@ -258,11 +276,28 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
}); });
cType = expandedType; cType = expandedType;
} }
assert("Only support rank 2 tensor for C" && cType.getRank() == 2); if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
} }
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() if (!aType.hasStaticShape()) {
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape()))) if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv // Not a gemv
@@ -341,19 +376,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute( auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size()); vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back( vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
Value partialVmmSum = sumTensors(vmmOutputs, rewriter); Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success();
}); });
if (failed(computeOp))
return failure();
partialResults.push_back(computeOp.getResult(0)); partialResults.push_back(computeOp->getResult(0));
} }
if (hasC) { if (hasC) {
@@ -388,14 +429,28 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB(); Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC(); Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA()); if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp()); bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType()); auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType()); auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType()); auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape()); if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0); const int64_t numOutRows = aType.getDimSize(0);
if (numOutRows <= 1) if (numOutRows <= 1)
@@ -438,7 +493,14 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
}); });
cType = cast<RankedTensorType>(c.getType()); cType = cast<RankedTensorType>(c.getType());
} }
assert("Only support rank 2 tensor for C" && cType.getRank() == 2); if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv // Row-specific bias can't share a single template body; fall through to GemmToManyGemv
if (cType.getDimSize(0) == numOutRows && numOutRows > 1) if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
return failure(); return failure();
@@ -5,7 +5,7 @@
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,7 +5,7 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -6,13 +6,12 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include <optional> #include <optional>
#include <type_traits> #include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -31,8 +30,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue; return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
} }
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) { template <typename PoolOp>
assert(!values.empty() && "Expected at least one value to concatenate."); static FailureOr<Value>
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
if (values.empty()) {
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
return failure();
}
return createSpatConcat(rewriter, loc, axis, values); return createSpatConcat(rewriter, loc, axis, values);
} }
@@ -51,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
} }
template <typename ReduceOp> template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) { static FailureOr<Value>
assert(!windowValues.empty() && "Expected at least one pool window value."); reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
if (windowValues.empty()) {
op->emitOpError("pool window resolved to zero valid elements");
return failure();
}
Value reduced = windowValues.front(); Value reduced = windowValues.front();
for (Value value : windowValues.drop_front()) for (Value value : windowValues.drop_front())
@@ -60,9 +68,12 @@ static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location lo
return reduced; return reduced;
} }
static Value static FailureOr<Value>
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) { scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive."); if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive");
return failure();
}
if (divisor == 1) if (divisor == 1)
return reducedWindow; return reducedWindow;
@@ -70,7 +81,7 @@ scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value redu
double scale = 1.0 / static_cast<double>(divisor); double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale)); auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr); Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor); return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
} }
template <typename PoolOp> template <typename PoolOp>
@@ -209,28 +220,45 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
if (windowValues.empty()) if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements."); return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues); auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
if (failed(reducedWindow))
return failure();
Value reducedWindowValue = *reducedWindow;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) { if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1; const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor = const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size()); countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor); auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
if (failed(scaledWindow))
return failure();
reducedWindowValue = *scaledWindow;
} }
outputChannelTiles.push_back(reducedWindow); outputChannelTiles.push_back(reducedWindowValue);
} }
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles)); auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
if (failed(rowPixel))
return failure();
rowPixels.push_back(*rowPixel);
} }
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels)); auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
if (failed(row))
return failure();
rows.push_back(*row);
} }
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows)); auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
if (failed(batchResult))
return failure();
batchResults.push_back(*batchResult);
} }
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults); auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput); if (failed(pooledOutput))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
return success(); return success();
}); });
if (failed(computeOp)) if (failed(computeOp))
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,7 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,7 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,7 +5,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,7 +5,7 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,7 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -29,7 +29,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "Conversion/ONNXToSpatial/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Patterns.hpp" #include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
+3
View File
@@ -4,6 +4,9 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps add_pim_library(SpatialOps
Channels.cpp Channels.cpp
SpatialOps.cpp SpatialOps.cpp
SpatialOpsAsm.cpp
SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
File diff suppressed because it is too large Load Diff
+912
View File
@@ -0,0 +1,912 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/LogicalResult.h"
#include <string>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
enum class ListDelimiter {
Square,
Paren
};
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
}
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
}
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "[" : "(");
}
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
}
template <typename EntryT, typename ParseEntryFn>
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
}
template <typename IntT>
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
if (parser.parseLSquare())
return failure();
if (succeeded(parser.parseOptionalRSquare()))
return success();
while (true) {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
if ((last - first) % step != 0)
return parser.emitError(parser.getCurrentLocation(),
"range end must be reachable from start using the given step");
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
if (succeeded(parser.parseOptionalRSquare()))
break;
if (parser.parseComma())
return failure();
}
return success();
}
template <typename RangeT, typename PrintEntryFn>
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
size_t runEnd = index + 1;
while (runEnd < entries.size() && entries[runEnd] == entries[index])
++runEnd;
if (index != 0)
printer << ", ";
printEntry(entries[index]);
size_t runLength = runEnd - index;
if (runLength > 1)
printer << " x" << runLength;
index = runEnd;
}
}
template <typename IntT>
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printer << "[";
for (size_t index = 0; index < values.size();) {
if (index != 0)
printer << ", ";
auto findEqualRunEnd = [&](size_t start) {
size_t end = start + 1;
while (end < values.size() && values[end] == values[start])
++end;
return end;
};
size_t firstRunEnd = findEqualRunEnd(index);
size_t repeatCount = firstRunEnd - index;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[index];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[index]);
if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount;
if (progressionEnd > firstRunEnd && progressionValueCount >= 3) {
printer << values[index] << " to " << lastValue;
if (step != 1)
printer << " by " << step;
if (repeatCount > 1)
printer << " x" << repeatCount;
index = progressionEnd;
continue;
}
if (repeatCount > 1) {
printer << values[index] << " x" << repeatCount;
index = firstRunEnd;
continue;
}
printer << values[index];
index = firstRunEnd;
}
printer << "]";
}
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
printCloseDelimiter(printer, delimiter);
}
static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
printCloseDelimiter(printer, delimiter);
}
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::UnresolvedOperand lastOperand;
if (parser.parseOperand(lastOperand))
return failure();
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
operands.push_back({firstOperand.location, firstOperand.name, number});
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
operands.push_back(firstOperand);
}
return success();
}
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
OpAsmParser::UnresolvedOperand firstOperand;
if (parser.parseOperand(firstOperand))
return failure();
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}
static ParseResult parseCompressedOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
}
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
return success();
}
static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
Type firstType;
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
if (!firstTypeResult.has_value()) {
if (allowEmpty)
return success();
return parser.emitError(parser.getCurrentLocation(), "expected type");
}
if (failed(*firstTypeResult))
return failure();
auto appendType = [&](Type type) -> ParseResult {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
types.push_back(type);
return success();
};
if (appendType(firstType))
return failure();
while (succeeded(parser.parseOptionalComma())) {
Type nextType;
if (parser.parseType(nextType) || appendType(nextType))
return failure();
}
return success();
}
static void printChannelMetadata(OpAsmPrinter& printer,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
printer << " channels ";
printCompressedIntegerList(printer, channelIds);
printer << " from ";
printCompressedIntegerList(printer, sourceCoreIds);
printer << " to ";
printCompressedIntegerList(printer, targetCoreIds);
}
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
return parser.getBuilder().getDenseI64ArrayAttr(values);
}
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
return parser.getBuilder().getDenseI32ArrayAttr(values);
}
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static void buildImplicitRegionArgs(OpAsmParser& parser,
ArrayRef<Type> inputTypes,
SmallVectorImpl<std::string>& generatedNames,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
generatedNames.reserve(inputTypes.size());
arguments.reserve(inputTypes.size());
for (auto [index, inputType] : llvm::enumerate(inputTypes)) {
generatedNames.push_back("arg" + std::to_string(index + 1));
OpAsmParser::Argument arg;
arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0};
arg.type = inputType;
arguments.push_back(arg);
}
}
} // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getOutputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputs;
SmallVector<Type> outputTypes;
OpAsmParser::UnresolvedOperand firstOutput;
OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
if (firstOutputResult.has_value()) {
if (failed(*firstOutputResult))
return failure();
if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, outputs))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (outputs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInput().getType());
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
}
ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<Type> outputTypes;
if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parser.parseType(inputType) || parser.parseArrow()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (parser.resolveOperand(input, inputType, result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
void SpatConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis();
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
int64_t axis = 0;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
Type outputType;
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseType(outputType))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (result.attributes.get("axis"))
return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");
result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " core_id " << coreIdAttr.getInt();
printer.printOptionalAttrDict((*this)->getAttrs(),
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<std::string> generatedArgNames;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t coreId = 0;
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id"));
if (hasCoreId && parser.parseInteger(coreId))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"core_id cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreId)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
return parser.parseRegion(*body, regionArgs);
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
printer << " core_ids ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int32_t laneCount = 0;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<std::string> generatedArgNames;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
return failure();
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids"));
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
return parser.parseRegion(*body, regionArgs);
}
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getResultTypes());
}
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<Type> outputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputTypes);
return success();
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutput().getType());
}
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputType);
return success();
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,35 @@
#include "mlir/IR/Block.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front();
if (!llvm::hasSingleElement(block))
return failure();
auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
if (!yieldOp)
return failure();
for (Value yieldedValue : yieldOp.getOperands()) {
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArg.getOwner() == &block) {
results.push_back(getOperand(blockArg.getArgNumber()));
continue;
}
}
results.push_back(yieldedValue);
}
return success();
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,433 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitter->emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vectorM = vectorShape[0];
int64_t vector1 = vectorShape[1];
if (vectorM != M || vector1 != 1)
return emitter->emitError("vector shape must be (M, 1)");
int64_t outputN = outputShape[0];
int64_t output1 = outputShape[1];
if (outputN != N || output1 != 1)
return emitter->emitError("output shape must be (N, 1)");
return success();
}
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
return emitter->emitError("matrix, vector and output must have rank 4");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
int64_t matrix1First = matrixShape[2];
int64_t matrix1Second = matrixShape[3];
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
int64_t vector1First = vectorShape[0];
int64_t vectorM = vectorShape[1];
int64_t vector1Second = vectorShape[2];
int64_t vector1Third = vectorShape[3];
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
// This is ok, it was caused by the simplification of the concat error.
}
else {
return emitter->emitError("vector shape must be (1, M, 1, 1)");
}
}
int64_t output1First = outputShape[0];
int64_t outputN = outputShape[1];
int64_t output1Second = outputShape[2];
int64_t output1Third = outputShape[3];
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
return emitter->emitError("output shape must be (1, N, 1, 1)");
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
if (auto computeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp()))
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
if (auto batchOp = dyn_cast<SpatComputeBatch>(weightedOp->getParentOp())) {
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
return failure();
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto batchOp = op->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return failure();
return batchOp.getLaneCount();
}
static LogicalResult verifyManyChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
size_t valueCount) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelIds.size() != valueCount)
return op->emitError("channel metadata length must match the number of values");
return success();
}
static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
if (types.empty())
return op->emitError() << kind << " must carry at least one value";
Type firstType = types.front();
for (Type type : types.drop_front())
if (type != firstType)
return op->emitError() << kind << " values must all have the same type";
return success();
}
static LogicalResult verifyBatchChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.size() != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount");
return success();
}
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return op->emitError("body must terminate with spat.yield");
if (outputTypes.empty()) {
if (yieldOp.getNumOperands() != 0)
return op->emitError("body yield must be empty when compute_batch has no results");
}
else {
if (yieldOp.getNumOperands() != 1)
return op->emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputTypes[0])
return op->emitError("body yield type must match output type");
}
for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatWeightedVMMOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
if (auto wmvm = dyn_cast<SpatWeightedMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
}
return success();
}
} // namespace
LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
if (matrixShape.size() == 2)
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
if (matrixShape.size() == 4)
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
return emitError("matrix rank must be 2 or 4");
}
LogicalResult SpatWeightedVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vector1 = vectorShape[0];
int64_t vectorN = vectorShape[1];
if (vectorN != N || vector1 != 1)
return emitError("vector shape must be (1, N)");
int64_t output1 = outputShape[0];
int64_t outputM = outputShape[1];
if (outputM != M || output1 != 1)
return emitError("output shape must be (1, M)");
return success();
}
LogicalResult SpatVAddOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatVMaxOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatExtractRowsOp::verify() {
auto inputType = dyn_cast<ShapedType>(getInput().getType());
if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
return emitError("input must be a rank-2 shaped type");
int64_t numRows = inputType.getShape()[0];
int64_t numCols = inputType.getShape()[1];
Type elementType = inputType.getElementType();
if (numRows >= 0 && static_cast<int64_t>(getNumResults()) != numRows)
return emitError("number of outputs must match the number of input rows");
for (Type output : getResultTypes()) {
auto outputType = dyn_cast<ShapedType>(output);
if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
return emitError("outputs must all be rank-2 shaped types");
if (outputType.getElementType() != elementType)
return emitError("output element types must match input element type");
auto outputShape = outputType.getShape();
if (outputShape[0] != 1)
return emitError("each output must have exactly one row");
if (numCols >= 0 && outputShape[1] != numCols)
return emitError("output column count must match input column count");
}
return success();
}
LogicalResult SpatConcatOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
if (!outputType || !outputType.hasRank())
return emitError("output must be a ranked shaped type");
int64_t axis = getAxis();
int64_t rank = outputType.getRank();
if (axis < 0 || axis >= rank)
return emitError("axis must be within the output rank");
int64_t concatenatedDimSize = 0;
bool concatenatedDimDynamic = false;
Type outputElementType = outputType.getElementType();
for (Value input : getInputs()) {
auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType || !inputType.hasRank())
return emitError("inputs must be ranked shaped types");
if (inputType.getRank() != rank)
return emitError("all inputs must have the same rank as the output");
if (inputType.getElementType() != outputElementType)
return emitError("all inputs must have the same element type as the output");
for (int64_t dim = 0; dim < rank; ++dim) {
if (dim == axis)
continue;
int64_t inputDim = inputType.getDimSize(dim);
int64_t outputDim = outputType.getDimSize(dim);
if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim)
return emitError("non-concatenated dimensions must match the output shape");
}
int64_t inputConcatDim = inputType.getDimSize(axis);
if (ShapedType::isDynamic(inputConcatDim)) {
concatenatedDimDynamic = true;
continue;
}
concatenatedDimSize += inputConcatDim;
}
int64_t outputConcatDim = outputType.getDimSize(axis);
if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim)
return emitError("output concatenated dimension must equal the sum of input sizes");
return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("ComputeOp must have a single yield operation");
auto resultTypes = getResultTypes();
auto yieldTypes = yieldOp->getOperandTypes();
if (resultTypes.size() != yieldTypes.size())
return emitError("ComputeOp must have same number of results as yieldOp operands");
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
auto resultType = std::get<0>(it);
auto yieldType = std::get<1>(it);
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
return emitError("ComputeOp output must be of the same type as yieldOp operand");
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
return emitError("ComputeOp output must have the same encoding as yieldOp operand");
}
else {
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one");
}
}
else if (dyn_cast<RankedTensorType>(yieldType)) {
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one");
}
}
}
for (auto arg : block.getArguments())
if (arg.use_empty())
return emitError("ComputeOp block argument is not used");
return success();
}
LogicalResult SpatChannelSendManyOp::verify() {
if (failed(verifyManyChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
}
LogicalResult SpatChannelReceiveManyOp::verify() {
if (failed(verifyManyChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatComputeBatch::verify() {
int32_t count = getLaneCount();
if (count <= 0)
return emitError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count);
if (getWeights().size() % laneCountSz != 0)
return emitError("number of weights must be a multiple of laneCount");
if (!getInputs().empty() && getInputs().size() != laneCountSz)
return emitError("number of inputs must be either 0 or laneCount");
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
return emitError("number of outputs must be either 0 or laneCount");
size_t weightsPerLane = getWeights().size() / laneCountSz;
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
Type weightType = getWeights()[weightIndex].getType();
for (size_t lane = 1; lane < laneCountSz; ++lane)
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
return emitError("corresponding weights across lanes must have the same type");
}
if (!getInputs().empty()) {
Type inputType = getInputs()[0].getType();
for (Value in : getInputs().drop_front())
if (in.getType() != inputType)
return emitError("all inputs must have the same type");
}
if (!getOutputs().empty()) {
Type outputType = getOutputs()[0].getType();
for (Value out : getOutputs().drop_front())
if (out.getType() != outputType)
return emitError("all outputs must have the same type");
}
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
if (!coreIdsAttr)
return emitError("compute_batch core_id attribute must be a dense i32 array");
if (coreIdsAttr.size() != laneCountSz)
return emitError("compute_batch core_id array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
return emitError("compute_batch core_id values must be positive");
}
Block& block = getBody().front();
if (getInputs().empty()) {
if (block.getNumArguments() != 0)
return emitError("compute_batch body must have no block arguments when there are no inputs");
}
else {
if (block.getNumArguments() != 1)
return emitError("compute_batch body must have exactly one block argument");
if (block.getArgument(0).getType() != getInputs()[0].getType())
return emitError("body block argument type must match input type");
}
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
}
} // namespace spatial
} // namespace onnx_mlir