diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp
index 07bebf3..c6606b9 100644
--- a/src/PIM/Compiler/PimCodeGen.cpp
+++ b/src/PIM/Compiler/PimCodeGen.cpp
@@ -21,7 +21,7 @@
 #include

 #include "Common/PimCommon.hpp"
-#include "Conversion/ONNXToSpatial/Common.hpp"
+#include "Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
 #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
index 994deff..530ad6b 100644
--- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
+++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
@@ -18,7 +18,9 @@ add_pim_library(OMONNXToSpatial
   Patterns/Tensor/Reshape.cpp
   Patterns/Tensor/Split.cpp
   ONNXToSpatialPass.cpp
-  Common.cpp
+  Common/ComputeRegionBuilder.cpp
+  Common/ShapeTilingUtils.cpp
+  Common/WeightMaterialization.cpp

   EXCLUDE_FROM_OM_LIBS
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp
deleted file mode 100644
index 3c49f34..0000000
--- a/src/PIM/Conversion/ONNXToSpatial/Common.hpp
+++ /dev/null
@@ -1,305 +0,0 @@
-#pragma once
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#include <cstddef>
-#include <type_traits>
-#include <utility>
-
-#include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/ADT/STLExtras.h"
-
-#include "src/Accelerators/PIM/Common/PimCommon.hpp"
-#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
-#include "src/Dialect/ONNX/ONNXOps.hpp"
-
-namespace onnx_mlir {
-
-template <typename ShapedType>
-inline auto getImageWidth(const ShapedType& shapedType) {
-  return shapedType.getDimSize(2);
-}
-
-template <typename ShapedType>
-inline auto getImageHeight(const ShapedType& shapedType) {
-  return shapedType.getDimSize(3);
-}
-
-template <typename ShapedType>
-inline auto getImageChannel(const ShapedType& shapedType) {
-  return shapedType.getDimSize(1);
-}
-
-template <typename ShapedType>
-inline auto getImageN(const ShapedType& shapedType) {
-  return shapedType.getDimSize(0);
-}
-
-template <typename ShapedType>
-inline auto getKernelWidth(const ShapedType& shapedType) {
-  return shapedType.getDimSize(2);
-}
-
-template <typename ShapedType>
-inline auto getKernelHeight(const ShapedType& shapedType) {
-  return shapedType.getDimSize(3);
-}
-
-template <typename ShapedType>
-inline auto getFilterCount(const ShapedType& shapedType) {
-  return shapedType.getDimSize(0);
-}
-
-using HSliceId = size_t;
-using CoreId = size_t;
-
-template <typename A, typename B, typename C = std::common_type_t<A, B>>
-constexpr C ceilIntegerDivide(A a, B b) {
-  static_assert(std::is_integral_v<A>, "A must be an integer type");
-  static_assert(std::is_integral_v<B>, "B must be an integer type");
-  C ac = static_cast<C>(a);
-  C bc = static_cast<C>(b);
-  return 1 + (ac - 1) / bc;
-}
-
-template <typename A, typename B, typename C = std::common_type_t<A, B>>
-constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
-  static_assert(std::is_integral_v<A>, "A must be an integer type");
-  static_assert(std::is_integral_v<B>, "B must be an integer type");
-  C ac = static_cast<C>(a);
-  C bc = static_cast<C>(b);
-  return {ceilIntegerDivide(ac, bc), ac % bc};
-}
-
-template <typename T>
-bool isVectorShape(mlir::ArrayRef<T> shape) {
-  return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
-}
-
-template <typename T>
-bool isMatrixShape(mlir::ArrayRef<T> shape) {
-  return shape.size() == 2;
-}
-
-template <typename T>
-bool isHVectorShape(mlir::ArrayRef<T> shape) {
-  return shape.size() == 2 && shape[0] == 1;
-}
-
-template <typename T>
-bool isVVectorShape(mlir::ArrayRef<T> shape) {
-  return shape.size() == 2 && shape[1] == 1;
-}
-
-template <typename T>
-T getVectorLength(mlir::ArrayRef<T> shape) {
-  assert(isVectorShape(shape));
-  return shape[0] != 1 ? shape[0] : shape[1];
-}
-
-inline auto getTensorShape(mlir::Value tensor) {
-  return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
-}
-
-inline bool isWeightLikeComputeOperand(mlir::Value value) {
-  auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
-  if (!rankedType || !isMatrixShape(rankedType.getShape()))
-    return false;
-
-  llvm::SmallPtrSet<mlir::Operation*, 8> visited;
-
-  while (auto* definingOp = value.getDefiningOp()) {
-    if (!visited.insert(definingOp).second)
-      return false;
-    if (hasWeightAlways(definingOp))
-      return true;
-
-    if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
-      value = extractSliceOp.getSource();
-      continue;
-    }
-    if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
-      value = expandShapeOp.getSrc();
-      continue;
-    }
-    if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
-      value = collapseShapeOp.getSrc();
-      continue;
-    }
-    if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
-      value = transposeOp.getData();
-      continue;
-    }
-
-    return false;
-  }
-
-  return false;
-}
-
-namespace detail {
-
-inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
-
-template <typename Fn, size_t... Is>
-decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
-  return std::forward<Fn>(fn)(block->getArgument(Is)...);
-}
-
-template <typename Fn, size_t... Is>
-decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
-  return std::forward<Fn>(fn)(values[Is]...);
-}
-
-template <size_t I>
-using ValueArg = mlir::Value;
-
-template <typename Fn, typename Seq>
-struct InvokeWithBlockArgsResult;
-
-template <typename Fn, size_t... Is>
-struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
-  using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
-};
-
-template <typename Fn, typename Seq>
-using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
-
-template <typename Fn>
-using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
-
-} // namespace detail
-
-template <typename RewriterT>
-inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
-  assert(!inputs.empty() && "spat.concat requires at least one input");
-  if (inputs.size() == 1)
-    return inputs.front();
-
-  auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
-  auto outputShape = llvm::to_vector(firstType.getShape());
-  int64_t concatDimSize = 0;
-  bool concatDimDynamic = false;
-
-  for (mlir::Value input : inputs) {
-    auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
-    assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
-    if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
-      concatDimDynamic = true;
-    else
-      concatDimSize += inputType.getDimSize(axis);
-  }
-
-  outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
-  auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
-  return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
-}
-
-template <size_t NumInputs, typename RewriterT, typename BodyFn>
-auto createSpatCompute(RewriterT& rewriter,
-    mlir::Location loc,
-    mlir::TypeRange resultTypes,
-    mlir::ValueRange weights,
-    mlir::ValueRange inputs,
-    BodyFn&& body) {
-  assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
-  auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
-
-  auto* block = new mlir::Block();
-  for (mlir::Value input : inputs)
-    block->addArgument(input.getType(), loc);
-
-  computeOp.getBody().push_back(block);
-  rewriter.setInsertionPointToStart(block);
-
-  using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
-  if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
-    auto bodyResult =
-        detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
-    if (mlir::failed(bodyResult)) {
-      rewriter.setInsertionPointAfter(computeOp);
-      rewriter.eraseOp(computeOp);
-      return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
-    }
-    rewriter.setInsertionPointAfter(computeOp);
-    return mlir::FailureOr<spatial::SpatCompute>(computeOp);
-  }
-  else {
-    static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
-    detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
-
-    rewriter.setInsertionPointAfter(computeOp);
-    return computeOp;
-  }
-}
-
-template <typename RewriterT, typename BodyFn>
-auto createSpatCompute(RewriterT& rewriter,
-    mlir::Location loc,
-    mlir::TypeRange resultTypes,
-    mlir::ValueRange weights,
-    mlir::ValueRange inputs,
-    BodyFn&& body) {
-  auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
-
-  auto* block = new mlir::Block();
-  for (mlir::Value input : inputs)
-    block->addArgument(input.getType(), loc);
-
-  computeOp.getBody().push_back(block);
-  rewriter.setInsertionPointToStart(block);
-
-  using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
-  if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
-    auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
-    if (mlir::failed(bodyResult)) {
-      rewriter.setInsertionPointAfter(computeOp);
-      rewriter.eraseOp(computeOp);
-      return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
-    }
-    rewriter.setInsertionPointAfter(computeOp);
-    return mlir::FailureOr<spatial::SpatCompute>(computeOp);
-  }
-  else {
-    static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
-    std::forward<BodyFn>(body)(detail::getBlockArgs(block));
-
-    rewriter.setInsertionPointAfter(computeOp);
-    return computeOp;
-  }
-}
-
-llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
-    size_t axis,
-    int64_t sliceSize,
-    mlir::ConversionPatternRewriter& rewriter,
-    mlir::Location loc);
-
-llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
-    int64_t sliceSize,
-    mlir::ConversionPatternRewriter& rewriter,
-    mlir::Location loc);
-
-llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
-    const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
-
-llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
-tileMatrix(mlir::Value& matrixToTile,
-    int64_t hSliceSize,
-    int64_t vSliceSize,
-    mlir::ConversionPatternRewriter& rewriter,
-    mlir::Location& loc);
-
-mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
-    int64_t length,
-    mlir::ConversionPatternRewriter& rewriter,
-    mlir::Location loc);
-
-mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
-
-}; // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp
new file mode 100644
index 0000000..c820073
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#include "src/Accelerators/PIM/Common/PimCommon.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+#include "ComputeRegionBuilder.hpp"
+#include "ShapeTilingUtils.hpp"
+#include "WeightMaterialization.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.cpp
new file mode 100644
index 0000000..f39998a
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.cpp
@@ -0,0 +1,39 @@
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+#include "ComputeRegionBuilder.hpp"
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
+  if (tensors.size() == 1)
+    return tensors[0];
+
+  SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
+  SmallVector<Value> tensors2;
+  tensors2.reserve(tensors.size() / 2);
+
+  auto* currTensors = &tensors1;
+  auto* nextTensors = &tensors2;
+  while (currTensors->size() > 1) {
+    for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
+      Value a = (*currTensors)[i];
+      Value b = (*currTensors)[i + 1];
+      rewriter.setInsertionPointAfterValue(b);
+      auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
+      nextTensors->push_back(addedValue);
+    }
+    if (currTensors->size() % 2 == 1)
+      nextTensors->push_back(currTensors->back());
+    std::swap(currTensors, nextTensors);
+    nextTensors->clear();
+  }
+  assert(currTensors->size() == 1 && "Expected a single input at this point.");
+  return (*currTensors)[0];
+}
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp
new file mode 100644
index 0000000..4ffce3e
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp
@@ -0,0 +1,153 @@
+#pragma once
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <cassert>
+#include <cstddef>
+#include <type_traits>
+#include <utility>
+
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+
+namespace onnx_mlir {
+
+namespace detail {
+
+inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
+
+template <typename Fn, size_t... Is>
+decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
+  return std::forward<Fn>(fn)(block->getArgument(Is)...);
+}
+
+template <typename Fn, size_t... Is>
+decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
+  return std::forward<Fn>(fn)(values[Is]...);
+}
+
+template <size_t I>
+using ValueArg = mlir::Value;
+
+template <typename Fn, typename Seq>
+struct InvokeWithBlockArgsResult;
+
+template <typename Fn, size_t... Is>
+struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
+  using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
+};
+
+template <typename Fn, typename Seq>
+using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
+
+template <typename Fn>
+using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
+
+} // namespace detail
+
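+/// Concatenates `inputs` along `axis` with a single `spat.concat`; a
+/// single-input list folds to the input itself. A minimal usage sketch
+/// (operand names hypothetical, assuming two same-rank f32 tiles built
+/// earlier in the pattern):
+///
+///   mlir::Value halves[] = {leftTile, rightTile};
+///   mlir::Value row = createSpatConcat(rewriter, loc, /*axis=*/1,
+///       mlir::ValueRange(halves));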
+template <typename RewriterT>
+inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
+  assert(!inputs.empty() && "spat.concat requires at least one input");
+  if (inputs.size() == 1)
+    return inputs.front();
+
+  auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
+  auto outputShape = llvm::to_vector(firstType.getShape());
+  int64_t concatDimSize = 0;
+  bool concatDimDynamic = false;
+
+  for (mlir::Value input : inputs) {
+    auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
+    assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
+    if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
+      concatDimDynamic = true;
+    else
+      concatDimSize += inputType.getDimSize(axis);
+  }
+
+  outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
+  auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
+  return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
+}
+
+/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
+/// the body callback reports failure.
+template <size_t NumInputs, typename RewriterT, typename BodyFn>
+auto createSpatCompute(RewriterT& rewriter,
+    mlir::Location loc,
+    mlir::TypeRange resultTypes,
+    mlir::ValueRange weights,
+    mlir::ValueRange inputs,
+    BodyFn&& body) {
+  assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
+  auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
+
+  auto* block = new mlir::Block();
+  for (mlir::Value input : inputs)
+    block->addArgument(input.getType(), loc);
+
+  computeOp.getBody().push_back(block);
+  rewriter.setInsertionPointToStart(block);
+
+  using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
+  if constexpr (std::is_same_v<BodyResult, void>) {
+    detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
+
+    rewriter.setInsertionPointAfter(computeOp);
+    return computeOp;
+  }
+  else {
+    auto bodyResult =
+        detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
+    if (mlir::failed(bodyResult)) {
+      rewriter.setInsertionPointAfter(computeOp);
+      rewriter.eraseOp(computeOp);
+      return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
+    }
+    rewriter.setInsertionPointAfter(computeOp);
+    return mlir::FailureOr<spatial::SpatCompute>(computeOp);
+  }
+}
+
+/// Builds a `spat.compute` whose body consumes the block arguments as a single
+/// `ValueRange`, which is convenient for variadic reductions/concats.
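+///
+/// A hedged usage sketch (`slices` and `resultType` are hypothetical values
+/// from the caller; a body returning LogicalResult makes the helper yield a
+/// FailureOr and erase the op on failure, while a void body returns the op
+/// directly):
+///
+///   auto computeOp = createSpatCompute(rewriter, loc,
+///       mlir::TypeRange{resultType}, /*weights=*/{}, slices,
+///       [&](mlir::ValueRange args) -> mlir::LogicalResult {
+///         if (args.empty())
+///           return mlir::failure();
+///         spatial::SpatYieldOp::create(rewriter, loc,
+///             sumTensors(llvm::to_vector(args), rewriter));
+///         return mlir::success();
+///       });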
+template <typename RewriterT, typename BodyFn>
+auto createSpatCompute(RewriterT& rewriter,
+    mlir::Location loc,
+    mlir::TypeRange resultTypes,
+    mlir::ValueRange weights,
+    mlir::ValueRange inputs,
+    BodyFn&& body) {
+  auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
+
+  auto* block = new mlir::Block();
+  for (mlir::Value input : inputs)
+    block->addArgument(input.getType(), loc);
+
+  computeOp.getBody().push_back(block);
+  rewriter.setInsertionPointToStart(block);
+
+  using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
+  if constexpr (std::is_same_v<BodyResult, void>) {
+    std::forward<BodyFn>(body)(detail::getBlockArgs(block));
+
+    rewriter.setInsertionPointAfter(computeOp);
+    return computeOp;
+  }
+  else {
+    auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
+    if (mlir::failed(bodyResult)) {
+      rewriter.setInsertionPointAfter(computeOp);
+      rewriter.eraseOp(computeOp);
+      return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
+    }
+    rewriter.setInsertionPointAfter(computeOp);
+    return mlir::FailureOr<spatial::SpatCompute>(computeOp);
+  }
+}
+
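+/// Sums `tensors` with a balanced tree of `spat.vadd` ops (log depth rather
+/// than a serial chain). A short sketch (hypothetical values of one type):
+///
+///   llvm::SmallVector<mlir::Value> partials = {p0, p1, p2};
+///   mlir::Value total = sumTensors(partials, rewriter); // (p0 + p1) + p2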
+mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp
similarity index 71%
rename from src/PIM/Conversion/ONNXToSpatial/Common.cpp
rename to src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp
index 066623e..41c629d 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Common.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp
@@ -1,24 +1,10 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Transforms/DialectConversion.h"

 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/Casting.h"

-#include
-#include
-#include
-
-#include "Common.hpp"
+#include "ShapeTilingUtils.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
-#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
-#include "src/Dialect/ONNX/ONNXOps.hpp"

 using namespace mlir;

@@ -107,31 +93,4 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
   return tensor::SplatOp::create(rewriter, loc, type, elementValue);
 }

-Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
-  if (tensors.size() == 1)
-    return tensors[0];
-
-  SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
-  SmallVector<Value> tensors2;
-  tensors2.reserve(tensors.size() / 2);
-
-  auto* currTensors = &tensors1;
-  auto* nextTensors = &tensors2;
-  while (currTensors->size() > 1) {
-    for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
-      Value a = (*currTensors)[i];
-      Value b = (*currTensors)[i + 1];
-      rewriter.setInsertionPointAfterValue(b);
-      auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
-      nextTensors->push_back(addedValue);
-    }
-    if (currTensors->size() % 2 == 1)
-      nextTensors->push_back(currTensors->back());
-    std::swap(currTensors, nextTensors);
-    nextTensors->clear();
-  }
-  assert(currTensors->size() == 1 && "Expected a single input at this point.");
-  return (*currTensors)[0];
-}
-
-}; // namespace onnx_mlir
+} // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp
new file mode 100644
index 0000000..6c433a9
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp
@@ -0,0 +1,143 @@
+#pragma once
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <cassert>
+#include <cstddef>
+#include <type_traits>
+#include <utility>
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace onnx_mlir {
+
+template <typename ShapedType>
+inline auto getImageWidth(const ShapedType& shapedType) {
+  return shapedType.getDimSize(2);
+}
+
+template <typename ShapedType>
+inline auto getImageHeight(const ShapedType& shapedType) {
+  return shapedType.getDimSize(3);
+}
+
+template <typename ShapedType>
+inline auto getImageChannel(const ShapedType& shapedType) {
+  return shapedType.getDimSize(1);
+}
+
+template <typename ShapedType>
+inline auto getImageN(const ShapedType& shapedType) {
+  return shapedType.getDimSize(0);
+}
+
+template <typename ShapedType>
+inline auto getKernelWidth(const ShapedType& shapedType) {
+  return shapedType.getDimSize(2);
+}
+
+template <typename ShapedType>
+inline auto getKernelHeight(const ShapedType& shapedType) {
+  return shapedType.getDimSize(3);
+}
+
+template <typename ShapedType>
+inline auto getFilterCount(const ShapedType& shapedType) {
+  return shapedType.getDimSize(0);
+}
+
+using HSliceId = size_t;
+using CoreId = size_t;
+
+/// Ceiling integer division; the `1 + (a - 1) / b` form assumes `a >= 1`.
+template <typename A, typename B, typename C = std::common_type_t<A, B>>
+constexpr C ceilIntegerDivide(A a, B b) {
+  static_assert(std::is_integral_v<A>, "A must be an integer type");
+  static_assert(std::is_integral_v<B>, "B must be an integer type");
+  C ac = static_cast<C>(a);
+  C bc = static_cast<C>(b);
+  return 1 + (ac - 1) / bc;
+}
+
+template <typename A, typename B, typename C = std::common_type_t<A, B>>
+constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
+  static_assert(std::is_integral_v<A>, "A must be an integer type");
+  static_assert(std::is_integral_v<B>, "B must be an integer type");
+  C ac = static_cast<C>(a);
+  C bc = static_cast<C>(b);
+  return {ceilIntegerDivide(ac, bc), ac % bc};
+}
+
+template <typename T>
+bool isVectorShape(mlir::ArrayRef<T> shape) {
+  return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
+}
+
+template <typename T>
+bool isMatrixShape(mlir::ArrayRef<T> shape) {
+  return shape.size() == 2;
+}
+
+template <typename T>
+bool isHVectorShape(mlir::ArrayRef<T> shape) {
+  return shape.size() == 2 && shape[0] == 1;
+}
+
+template <typename T>
+bool isVVectorShape(mlir::ArrayRef<T> shape) {
+  return shape.size() == 2 && shape[1] == 1;
+}
+
+template <typename T>
+T getVectorLength(mlir::ArrayRef<T> shape) {
+  assert(isVectorShape(shape));
+  return shape[0] != 1 ? shape[0] : shape[1];
+}
+
+inline auto getTensorShape(mlir::Value tensor) {
+  return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
+}
+
+inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
+  auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
+  auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
+  return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
+      && lhsType.getShape() == rhsType.getShape();
+}
+
+/// Slices a statically shaped tensor along one axis into contiguous pieces of
+/// at most `sliceSize` elements.
+llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
+    size_t axis,
+    int64_t sliceSize,
+    mlir::ConversionPatternRewriter& rewriter,
+    mlir::Location loc);
+
+llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
+    int64_t sliceSize,
+    mlir::ConversionPatternRewriter& rewriter,
+    mlir::Location loc);
+
+/// Partitions one logical vector into per-core crossbar-sized slices using the
+/// current PIM target geometry.
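+/// The result maps each core to the ordered slices it owns; a sketch of the
+/// intended consumption (identifiers hypothetical):
+///
+///   auto slicesPerCore = sliceVectorPerCrossbarPerCore(activation, rewriter, loc);
+///   for (auto& [coreId, slices] : slicesPerCore)
+///     buildComputePerCore(coreId, slices); // e.g. one spat.compute per core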
+llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
+    const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
+
+/// Tiles a matrix first across output columns and then across input rows so it
+/// can be assigned to crossbars grouped by core.
+llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
+tileMatrix(mlir::Value& matrixToTile,
+    int64_t hSliceSize,
+    int64_t vSliceSize,
+    mlir::ConversionPatternRewriter& rewriter,
+    mlir::Location& loc);
+
+mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
+    int64_t length,
+    mlir::ConversionPatternRewriter& rewriter,
+    mlir::Location loc);
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp
new file mode 100644
index 0000000..4645aa0
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp
@@ -0,0 +1,114 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include "WeightMaterialization.hpp"
+#include "ShapeTilingUtils.hpp"
+#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+bool isWeightLikeComputeOperand(Value value) {
+  auto rankedType = dyn_cast<RankedTensorType>(value.getType());
+  if (!rankedType || !isMatrixShape(rankedType.getShape()))
+    return false;
+
+  llvm::SmallPtrSet<Operation*, 8> visited;
+
+  while (auto* definingOp = value.getDefiningOp()) {
+    if (!visited.insert(definingOp).second)
+      return false;
+    if (hasWeightAlways(definingOp))
+      return true;
+
+    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
+      value = extractSliceOp.getSource();
+      continue;
+    }
+    if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
+      value = expandShapeOp.getSrc();
+      continue;
+    }
+    if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
+      value = collapseShapeOp.getSrc();
+      continue;
+    }
+    if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
+      value = transposeOp.getData();
+      continue;
+    }
+
+    return false;
+  }
+
+  return false;
+}
+
+FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
+  if (auto mapped = mapper.lookupOrNull(value))
+    return cast<Value>(mapped);
+
+  Operation* definingOp = value.getDefiningOp();
+  if (!definingOp)
+    return failure();
+
+  if (isa<ONNXConstantOp>(definingOp)) {
+    auto tensorType = dyn_cast<RankedTensorType>(value.getType());
+    if (!tensorType || !tensorType.hasStaticShape())
+      return failure();
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes;
+    SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
+    sizes.reserve(tensorType.getRank());
+    for (int64_t dim : tensorType.getShape())
+      sizes.push_back(rewriter.getIndexAttr(dim));
+
+    auto referencedValue =
+        tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
+    mapper.map(value, referencedValue.getResult());
+    return referencedValue.getResult();
+  }
+
+  if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
+    return failure();
+
+  IRMapping localMapper;
+  for (Value operand : definingOp->getOperands()) {
+    if (auto mapped = mapper.lookupOrNull(operand)) {
+      localMapper.map(operand, cast<Value>(mapped));
+      continue;
+    }
+
+    if (isWeightLikeComputeOperand(operand)) {
+      auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
+      if (failed(clonedOperand))
+        return failure();
+      localMapper.map(operand, *clonedOperand);
+      continue;
+    }
+
+    localMapper.map(operand, operand);
+  }
+
+  Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
+  for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
+    mapper.map(oldResult, newResult);
+
+  auto mapped = mapper.lookupOrNull(value);
+  if (!mapped)
+    return failure();
+  return cast<Value>(mapped);
+}
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp
new file mode 100644
index 0000000..15c77a6
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp
@@ -0,0 +1,18 @@
+#pragma once
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+
+namespace onnx_mlir {
+
+/// Returns true when a matrix-valued compute operand is ultimately backed by a
+/// weight-marked constant/view chain and can be promoted into weights.
+bool isWeightLikeComputeOperand(mlir::Value value);
+
+/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
+/// compute body while reusing already-materialized intermediate values.
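+/// Sketch of the expected call pattern (names hypothetical); one shared mapper
+/// lets operands with a common subchain be cloned only once:
+///
+///   mlir::IRMapping mapper;
+///   for (mlir::Value w : weightOperands)
+///     if (auto cloned = materializeWeightLikeValueInBlock(w, rewriter, mapper);
+///         mlir::succeeded(cloned))
+///       newWeights.push_back(*cloned);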
+llvm::FailureOr<mlir::Value>
+materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
index 20fc99e..07ff217 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
@@ -12,14 +12,13 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_os_ostream.h"

 #include
 #include
 #include

-#include "Common.hpp"
+#include "Common/Common.hpp"
 #include "Common/PimCommon.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -32,8 +31,6 @@
 using namespace mlir;

 namespace onnx_mlir {

-bool haveSameStaticShape(Value lhs, Value rhs);
-
 namespace {

 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
@@ -50,7 +47,7 @@ struct ONNXToSpatialPass : PassWrapper
-FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
-  if (auto mapped = mapper.lookupOrNull(value))
-    return cast<Value>(mapped);
-
-  Operation* definingOp = value.getDefiningOp();
-  if (!definingOp)
-    return failure();
-
-  if (isa<ONNXConstantOp>(definingOp)) {
-    auto tensorType = dyn_cast<RankedTensorType>(value.getType());
-    if (!tensorType || !tensorType.hasStaticShape())
-      return failure();
-
-    SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> sizes;
-    SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
-    sizes.reserve(tensorType.getRank());
-    for (int64_t dim : tensorType.getShape())
-      sizes.push_back(rewriter.getIndexAttr(dim));
-
-    auto referencedValue =
-        tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
-    mapper.map(value, referencedValue.getResult());
-    return referencedValue.getResult();
-  }
-
-  if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
-    return failure();
-
-  IRMapping localMapper;
-  for (Value operand : definingOp->getOperands()) {
-    if (auto mapped = mapper.lookupOrNull(operand)) {
-      localMapper.map(operand, cast<Value>(mapped));
-      continue;
-    }
-
-    if (isWeightLikeComputeOperand(operand)) {
-      auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
-      if (failed(clonedOperand))
-        return failure();
-      localMapper.map(operand, *clonedOperand);
-      continue;
-    }
-
-    localMapper.map(operand, operand);
-  }
-
-  Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
-  for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
-    mapper.map(oldResult, newResult);
-
-  auto mapped = mapper.lookupOrNull(value);
-  if (!mapped)
-    return failure();
-  return cast<Value>(mapped);
-}
-
-bool sourceOpernadHasWeightAlways(Operation* op) {
+static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
   if (op == nullptr)
     return false;

@@ -416,30 +359,32 @@
     return res;
   }
   else {
-    op->dump();
-    llvm_unreachable("Global instruction not handle in func");
+    op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
+    return failure();
   }
   } while (source == nullptr);

-  if (hasWeightAlways(source))
-    return true;
-  return false;
+  return hasWeightAlways(source);
 }

 // TODO what we want to keep in global?
-void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
+LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
   Location loc = funcOp.getLoc();
   IRRewriter rewriter(&getContext());

   bool keep = true;
   while (keep) {
     keep = false;
     for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
-      if (isa( instruction)
-          || isa(instruction)
-          || sourceOpernadHasWeightAlways(&instruction))
+      if (isa( instruction)
+          || isa(instruction))
+        continue;
+
+      auto weightBacked = sourceOperandHasWeightAlways(&instruction);
+      if (failed(weightBacked))
+        return failure();
+      if (*weightBacked)
         continue;

       keep |= encapsulateSlice(rewriter, loc, &instruction);
@@ -456,6 +401,7 @@
       keep |= encapsulateConcat(rewriter, loc, &instruction);
     }
   }
+  return success();
 }

 void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
index 70bf871..4c0633e 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
@@ -7,11 +7,10 @@
 #include "llvm/ADT/SmallVector.h"

 #include
-#include

-#include "src/Accelerators/PIM/Common/PimCommon.hpp"
+#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"

@@ -370,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
   auto wType = cast<RankedTensorType>(w.getType());
   auto outType = cast<RankedTensorType>(convOp.getY().getType());

-  assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
-  assert("Only support 2D convolution" && xType.getRank() == 4);
-
-  // We need to understand what is group
-  assert("Only support group=1" && convOp.getGroup() == 1);
+  if (!xType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
+    return failure();
+  }
+  if (!wType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
+    return failure();
+  }
+  if (!outType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
+    return failure();
+  }
+  if (xType.getRank() != 4) {
+    pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
+    return failure();
+  }
+  if (wType.getRank() != 4) {
+    pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
+    return failure();
+  }
+  if (outType.getRank() != 4) {
+    pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
+    return failure();
+  }
+  if (convOp.getGroup() != 1) {
+    convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
+    return failure();
+  }

   const int64_t batchSize = xType.getDimSize(0);
   const int64_t numChannelsIn = xType.getDimSize(1);
@@ -391,6 +413,19 @@
   const auto dilationsAttr = convOp.getDilations();
   const auto padsAttr = convOp.getPads();

+  if (stridesAttr && stridesAttr->size() != 2) {
+    convOp.emitOpError("requires exactly two stride values for Spatial lowering");
+    return failure();
+  }
+  if (dilationsAttr && dilationsAttr->size() != 2) {
+    convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
+    return failure();
+  }
+  if (padsAttr && padsAttr->size() != 4) {
+    convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
+    return failure();
+  }
+
   const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
   const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
   const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
@@ -431,6 +466,10 @@
       padWidthBegin = totalPadW - padWidthEnd;
     }
   }
+  else if (autoPad != "NOTSET" && autoPad != "VALID") {
+    convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
+    return failure();
+  }

   // "NOTSET" or "VALID" -> all pads stay 0
 }
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp
index 35c8e7d..d3ac767 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp
@@ -5,7 +5,8 @@
 #include "llvm/ADT/SmallVector.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
+#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,13 +16,6 @@
 using namespace mlir;

 namespace onnx_mlir {
 namespace {

-static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> strides(shape.size(), 1);
-  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
-    strides[i] = strides[i + 1] * shape[i + 1];
-  return strides;
-}
-
 static DenseElementsAttr getDenseConstantAttr(Value value) {
   if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
     return dyn_cast<DenseElementsAttr>(constantOp.getValue());
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
index 55ff2c5..52dd163 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
@@ -8,10 +8,9 @@
 #include "llvm/ADT/SmallVector.h"

-#include
-
 #include "src/Accelerators/PIM/Common/PimCommon.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"

@@ -136,13 +135,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
   Value b = gemmOpAdaptor.getB();
   Value c = gemmOpAdaptor.getC();

-  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
+  if (gemmOpAdaptor.getTransA()) {
+    gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
+    return failure();
+  }

   bool hasC = !isa(c.getDefiningOp());

   auto aType = cast<RankedTensorType>(a.getType());
   auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
-  assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
+  if (!aType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
+    return failure();
+  }
+  if (!outType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
+    return failure();
+  }

   const int64_t numOutRows = aType.getDimSize(0);

@@ -175,7 +184,14 @@
     });
     cType = expandedType;
   }
-  assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
+  if (!cType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
+    return failure();
+  }
+  if (cType.getRank() != 2) {
+    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
"Gemm bias", cType.getRank(), {1, 2}); + return failure(); + } cHasNumOutRows = cType.getDimSize(0) == numOutRows; } @@ -199,8 +215,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult(); } - else - assert("C should be a vector" && isVectorShape(getTensorShape(c))); + else if (!isVectorShape(getTensorShape(c))) { + gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows"); + return failure(); + } } auto gemvOp = ONNXGemmOp::create(rewriter, @@ -258,11 +276,28 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, }); cType = expandedType; } - assert("Only support rank 2 tensor for C" && cType.getRank() == 2); + if (!cType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias"); + return failure(); + } + if (cType.getRank() != 2) { + pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2}); + return failure(); + } } - assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() - && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); + if (!aType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A"); + return failure(); + } + if (!bType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B"); + return failure(); + } + if (!outType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result"); + return failure(); + } if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape()))) // Not a gemv @@ -341,19 +376,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, weights.push_back(bTiles[outSliceId][coreId][aSliceId]); auto computeOp = createSpatCompute( - rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { + rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult { SmallVector vmmOutputs; vmmOutputs.reserve(aHSlicesArgs.size()); for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) vmmOutputs.push_back( spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); - assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); + if (vmmOutputs.empty()) { + gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); + return failure(); + } Value partialVmmSum = sumTensors(vmmOutputs, rewriter); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); + return success(); }); + if (failed(computeOp)) + return failure(); - partialResults.push_back(computeOp.getResult(0)); + partialResults.push_back(computeOp->getResult(0)); } if (hasC) { @@ -388,14 +429,28 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, Value b = gemmOpAdaptor.getB(); Value c = gemmOpAdaptor.getC(); - assert("A should have been transposed already" && !gemmOpAdaptor.getTransA()); + if (gemmOpAdaptor.getTransA()) { + gemmOp.emitOpError("requires transA=false before batch Gemm lowering"); + return failure(); + } bool hasC = !isa(c.getDefiningOp()); auto aType = cast(a.getType()); auto bType = cast(b.getType()); auto outType = cast(gemmOp.getY().getType()); - assert("Only support static shapes" && aType.hasStaticShape() 
-      && bType.hasStaticShape() && outType.hasStaticShape());
+  if (!aType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
+    return failure();
+  }
+  if (!bType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
+    return failure();
+  }
+  if (!outType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
+    return failure();
+  }

   const int64_t numOutRows = aType.getDimSize(0);
   if (numOutRows <= 1)
@@ -438,7 +493,14 @@
     });
     cType = cast<RankedTensorType>(c.getType());
   }
-  assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
+  if (!cType.hasStaticShape()) {
+    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
+    return failure();
+  }
+  if (cType.getRank() != 2) {
+    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
+    return failure();
+  }

   // Row-specific bias can't share a single template body; fall through to GemmToManyGemv
   if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
     return failure();
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
index 0cbe033..a82adcc 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
@@ -5,7 +5,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
index d236593..252ee9d 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
@@ -5,7 +5,7 @@
 #include

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
index 56df1f9..f74b7a0 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
@@ -6,13 +6,12 @@
 #include "llvm/ADT/SmallVector.h"

 #include
-#include
 #include
 #include

 #include "src/Accelerators/PIM/Common/PimCommon.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"

@@ -31,8 +30,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttr> arrayAttr, size_t index,
   return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
 }

-static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
-  assert(!values.empty() && "Expected at least one value to concatenate.");
+template <typename PoolOp>
+static FailureOr<Value>
+concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
+  if (values.empty()) {
+    poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
+    return failure();
+  }

   return createSpatConcat(rewriter, loc, axis, values);
 }

@@ -51,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
 }

 template
-static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
-  assert(!windowValues.empty() && "Expected at least one pool window value.");
+static FailureOr<Value>
+reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
+  if (windowValues.empty()) {
+    op->emitOpError("pool window resolved to zero valid elements");
+    return failure();
+  }

   Value reduced = windowValues.front();
   for (Value value : windowValues.drop_front())
@@ -60,9 +68,12 @@
   return reduced;
 }

-static Value
-scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
-  assert(divisor > 0 && "AveragePool divisor must be positive.");
+static FailureOr<Value>
+scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
+  if (divisor <= 0) {
+    op->emitOpError("AveragePool divisor must be positive");
+    return failure();
+  }

   if (divisor == 1)
     return reducedWindow;

@@ -70,7 +81,7 @@
   double scale = 1.0 / static_cast<double>(divisor);
   auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
   Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
-  return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
+  return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
 }

 template <typename PoolOp>
 struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
@@ -209,28 +220,45 @@
         if (windowValues.empty())
           return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");

-        Value reducedWindow = reduceWindowValues(rewriter, loc, windowValues);
+        auto reducedWindow = reduceWindowValues(rewriter, loc, poolOp, windowValues);
+        if (failed(reducedWindow))
+          return failure();
+        Value reducedWindowValue = *reducedWindow;

         if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
           const bool countIncludePad = poolOp.getCountIncludePad() == 1;
           const int64_t divisor =
               countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
-          reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
+          auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
+          if (failed(scaledWindow))
+            return failure();
+          reducedWindowValue = *scaledWindow;
         }

-        outputChannelTiles.push_back(reducedWindow);
+        outputChannelTiles.push_back(reducedWindowValue);
       }

-      rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
+      auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
+      if (failed(rowPixel))
+        return failure();
+      rowPixels.push_back(*rowPixel);
     }

-    rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
+    auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
+    if (failed(row))
+      return failure();
+    rows.push_back(*row);
   }

-  batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
+  auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
+  if (failed(batchResult))
+    return failure();
+  batchResults.push_back(*batchResult);
 }

-  Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
-  spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
+  auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
+  if (failed(pooledOutput))
+    return failure();
+  spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);

   return success();
 });
 if (failed(computeOp))
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
index b922581..9f256b7 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
@@ -1,6 +1,6 @@
 #include "mlir/Transforms/DialectConversion.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Sigmoid.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Sigmoid.cpp
index 1fc13e8..a56cc86 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Sigmoid.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Sigmoid.cpp
@@ -1,6 +1,6 @@
 #include "mlir/Transforms/DialectConversion.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp
index 0f43e99..1e58276 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp
@@ -1,7 +1,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp
index 87f8a21..64f8805 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp
@@ -1,7 +1,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp
index 2e59cf1..e388b83 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp
@@ -5,7 +5,7 @@
 #include "llvm/ADT/SmallVector.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp
index 2e6dbcf..39eebe2 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp
@@ -5,7 +5,7 @@
 #include

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp
index a9ba74a..1e6c93b 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp
@@ -1,7 +1,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"

-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
index cc56f01..86f3bd5 100644
--- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
+++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
@@ -29,7 +29,7 @@
 #include
 #include

-#include "Conversion/ONNXToSpatial/Common.hpp"
+#include "Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "Patterns.hpp"
 #include "src/Accelerators/PIM/Common/PimCommon.hpp"
 #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp
index 66d2fef..a1cb7aa 100644
--- a/src/PIM/Dialect/Spatial/SpatialOps.cpp
+++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp
@@ -22,7 +22,7 @@
 #include

 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
"src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"