From f492400edaf711ded34fda16f7a86a911d4993e7 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 29 Jun 2026 14:00:10 +0200 Subject: [PATCH] refactor --- src/PIM/Common/CMakeLists.txt | 1 + src/PIM/Common/IR/BatchCoreUtils.cpp | 60 ++ src/PIM/Common/IR/BatchCoreUtils.hpp | 14 + .../Common => Common/IR}/IndexingUtils.cpp | 2 +- .../Common => Common/IR}/IndexingUtils.hpp | 0 src/PIM/Common/IR/ShapeUtils.cpp | 71 ++ src/PIM/Common/IR/ShapeUtils.hpp | 71 ++ src/PIM/Common/LabeledList.hpp | 315 ------ src/PIM/Common/PimCommon.hpp | 1 + .../Conversion/ONNXToSpatial/CMakeLists.txt | 3 +- .../ONNXToSpatial/Common/Common.hpp | 2 +- .../Common/MatrixProductLowering.cpp | 48 + .../Common/MatrixProductLowering.hpp | 20 + .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 70 -- .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 76 +- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 142 +-- .../Patterns/Math/ConvGeometry.cpp | 77 ++ .../Patterns/Math/ConvGeometry.hpp | 86 ++ .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 38 - .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 38 +- .../BatchCoreLoweringPatterns.cpp | 99 +- .../SpatialToPim/CoreLoweringPatterns.cpp | 14 +- .../SpatialToPim/SpatialToPimPass.cpp | 157 +-- .../SpatialToPim/SpatialToPimPass.hpp | 1 - src/PIM/Dialect/Spatial/CMakeLists.txt | 2 + .../HostOutputFinalization.cpp | 134 +++ .../HostOutputFinalization.hpp | 11 + .../MaterializeMergeSchedule.cpp | 913 +++--------------- .../MaterializedClassState.hpp | 252 +++++ .../MergeComputeNodesPass.cpp | 25 +- .../MergeComputeNodes/MergeMessages.hpp | 67 ++ .../MergeComputeNodes/MergeScheduleKeys.hpp | 134 +++ .../MergeComputeNodes/ProjectedFragments.cpp | 104 ++ .../MergeComputeNodes/ProjectedFragments.hpp | 87 ++ .../MergeComputeNodes/Scheduling/Utils.hpp | 1 - test/PIM/CMakeLists.txt | 7 - test/PIM/LabeledListTest.cpp | 162 ---- 37 files changed, 1407 insertions(+), 1898 deletions(-) rename src/PIM/{Conversion/ONNXToSpatial/Common => Common/IR}/IndexingUtils.cpp (96%) rename src/PIM/{Conversion/ONNXToSpatial/Common => Common/IR}/IndexingUtils.hpp (100%) delete mode 100644 src/PIM/Common/LabeledList.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Common/MatrixProductLowering.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Common/MatrixProductLowering.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.cpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.hpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializedClassState.hpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeMessages.hpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeScheduleKeys.hpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.cpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.hpp delete mode 100644 test/PIM/LabeledListTest.cpp diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 3ef3168..46f609b 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -5,6 +5,7 @@ add_pim_library(OMPimCommon IR/ConstantUtils.cpp IR/CoreBlockUtils.cpp IR/EntryPointUtils.cpp + IR/IndexingUtils.cpp IR/LoopUtils.cpp IR/ShapeUtils.cpp IR/SubviewUtils.cpp diff --git a/src/PIM/Common/IR/BatchCoreUtils.cpp b/src/PIM/Common/IR/BatchCoreUtils.cpp index 562b8d3..957f170 100644 --- a/src/PIM/Common/IR/BatchCoreUtils.cpp +++ b/src/PIM/Common/IR/BatchCoreUtils.cpp @@ -1,5 +1,6 @@ #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" namespace onnx_mlir { @@ -9,6 +10,65 @@ llvm::SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { return llvm::SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); } +mlir::FailureOr> +getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) { + auto coreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName); + if (!coreIdAttr) + return std::optional {}; + if (coreIdAttr.getInt() < 0) { + computeOp.emitOpError() << fieldName << " must be non-negative"; + return mlir::failure(); + } + auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), computeOp, fieldName); + if (mlir::failed(checkedCoreId)) + return mlir::failure(); + return std::optional {*checkedCoreId}; +} + +mlir::FailureOr getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) { + auto coreId = getOptionalScheduledCoreId(computeOp, fieldName); + if (mlir::failed(coreId)) + return mlir::failure(); + if (!*coreId) { + computeOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdAttrName; + return mlir::failure(); + } + return **coreId; +} + +mlir::FailureOr>> +getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) { + auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); + if (!coreIdsAttr) + return std::optional> {}; + + llvm::SmallVector coreIds; + coreIds.reserve(coreIdsAttr.size()); + for (int32_t coreId : coreIdsAttr.asArrayRef()) { + if (coreId < 0) { + computeBatchOp.emitOpError() << fieldName << " values must be non-negative"; + return mlir::failure(); + } + auto checkedCoreId = pim::checkedI32(static_cast(coreId), computeBatchOp, fieldName); + if (mlir::failed(checkedCoreId)) + return mlir::failure(); + coreIds.push_back(*checkedCoreId); + } + return std::optional> {std::move(coreIds)}; +} + +mlir::FailureOr> +getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) { + auto coreIds = getOptionalScheduledBatchCoreIds(computeBatchOp, fieldName); + if (mlir::failed(coreIds)) + return mlir::failure(); + if (!*coreIds) { + computeBatchOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdsAttrName; + return mlir::failure(); + } + return std::move(**coreIds); +} + llvm::SmallVector getLaneChunkCoreIds(llvm::ArrayRef coreIds, size_t laneCount, unsigned lane) { llvm::SmallVector laneCoreIds; laneCoreIds.reserve(coreIds.size() / laneCount); diff --git a/src/PIM/Common/IR/BatchCoreUtils.hpp b/src/PIM/Common/IR/BatchCoreUtils.hpp index 58eb57b..959b78a 100644 --- a/src/PIM/Common/IR/BatchCoreUtils.hpp +++ b/src/PIM/Common/IR/BatchCoreUtils.hpp @@ -3,12 +3,26 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include + #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { llvm::SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp); +mlir::FailureOr> +getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName); + +mlir::FailureOr getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName); + +mlir::FailureOr>> +getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName); + +mlir::FailureOr> +getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName); + llvm::SmallVector getLaneChunkCoreIds(llvm::ArrayRef coreIds, size_t laneCount, unsigned lane); bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp b/src/PIM/Common/IR/IndexingUtils.cpp similarity index 96% rename from src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp rename to src/PIM/Common/IR/IndexingUtils.cpp index 0033b72..3635cd9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp +++ b/src/PIM/Common/IR/IndexingUtils.cpp @@ -1,6 +1,6 @@ #include -#include "IndexingUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/IndexingUtils.hpp" using namespace mlir; diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp b/src/PIM/Common/IR/IndexingUtils.hpp similarity index 100% rename from src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp rename to src/PIM/Common/IR/IndexingUtils.hpp diff --git a/src/PIM/Common/IR/ShapeUtils.cpp b/src/PIM/Common/IR/ShapeUtils.cpp index 112b8aa..3ae9b0a 100644 --- a/src/PIM/Common/IR/ShapeUtils.cpp +++ b/src/PIM/Common/IR/ShapeUtils.cpp @@ -1,6 +1,9 @@ #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/ErrorHandling.h" +#include + #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" namespace onnx_mlir { @@ -163,4 +166,72 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef sourceShape, return true; } +bool hasStaticPositiveShape(llvm::ArrayRef shape) { + return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); +} + +bool hasStaticPositiveShape(mlir::RankedTensorType type) { + return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); +} + +int64_t getStaticShapeElementCount(llvm::ArrayRef shape) { + return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies {}); +} + +llvm::SmallVector permuteShape(llvm::ArrayRef shape, llvm::ArrayRef permutation) { + llvm::SmallVector permutedShape; + permutedShape.reserve(permutation.size()); + for (int64_t axis : permutation) + permutedShape.push_back(shape[axis]); + return permutedShape; +} + +llvm::SmallVector invertPermutation(llvm::ArrayRef permutation) { + llvm::SmallVector inversePermutation(permutation.size()); + for (auto [newIndex, oldIndex] : llvm::enumerate(permutation)) + inversePermutation[oldIndex] = static_cast(newIndex); + return inversePermutation; +} + +mlir::FailureOr> +getTransposePermutationChecked(std::optional permAttr, int64_t rank) { + llvm::SmallVector permutation; + if (!permAttr) { + permutation.reserve(rank); + for (int64_t dim = rank - 1; dim >= 0; --dim) + permutation.push_back(dim); + return permutation; + } + + if (static_cast(permAttr->size()) != rank) + return mlir::failure(); + + permutation.reserve(permAttr->size()); + llvm::SmallVector seen(rank, false); + for (mlir::IntegerAttr attr : permAttr->getAsRange()) { + int64_t axis = attr.getInt(); + if (axis < 0 || axis >= rank || seen[axis]) + return mlir::failure(); + seen[axis] = true; + permutation.push_back(axis); + } + return permutation; +} + +llvm::SmallVector getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank) { + return llvm::SmallVector(rank, rewriter.getIndexAttr(1)); +} + +llvm::SmallVector getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank) { + return llvm::SmallVector(rank, rewriter.getIndexAttr(0)); +} + +llvm::SmallVector getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef shape) { + llvm::SmallVector sizes; + sizes.reserve(shape.size()); + for (int64_t dim : shape) + sizes.push_back(rewriter.getIndexAttr(dim)); + return sizes; +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ShapeUtils.hpp b/src/PIM/Common/IR/ShapeUtils.hpp index 4aa08be..b2d567e 100644 --- a/src/PIM/Common/IR/ShapeUtils.hpp +++ b/src/PIM/Common/IR/ShapeUtils.hpp @@ -2,15 +2,23 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include +#include +#include +#include namespace onnx_mlir { +using HSliceId = size_t; +using CoreId = size_t; + llvm::SmallVector computeRowMajorStrides(llvm::ArrayRef shape); llvm::SmallVector @@ -36,4 +44,67 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef sourceShape, llvm::ArrayRef staticSizes, llvm::ArrayRef staticStrides); +template > +constexpr C ceilIntegerDivide(A a, B b) { + static_assert(std::is_integral_v, "A must be an integer type"); + static_assert(std::is_integral_v, "B must be an integer type"); + C ac = static_cast(a); + C bc = static_cast(b); + return 1 + (ac - 1) / bc; +} + +template > +constexpr std::pair ceilIntegerDivideWithRemainder(A a, B b) { + static_assert(std::is_integral_v, "A must be an integer type"); + static_assert(std::is_integral_v, "B must be an integer type"); + C ac = static_cast(a); + C bc = static_cast(b); + return {ceilIntegerDivide(ac, bc), ac % bc}; +} + +template +bool isVectorShape(mlir::ArrayRef shape) { + return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); +} + +template +bool isMatrixShape(mlir::ArrayRef shape) { + return shape.size() == 2; +} + +template +bool isHVectorShape(mlir::ArrayRef shape) { + return shape.size() == 2 && shape[0] == 1; +} + +inline auto getTensorShape(mlir::Value tensor) { + return mlir::cast(tensor.getType()).getShape(); +} + +inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) { + auto lhsType = mlir::dyn_cast(lhs.getType()); + auto rhsType = mlir::dyn_cast(rhs.getType()); + return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() + && lhsType.getShape() == rhsType.getShape(); +} + +bool hasStaticPositiveShape(mlir::ArrayRef shape); + +bool hasStaticPositiveShape(mlir::RankedTensorType type); + +int64_t getStaticShapeElementCount(mlir::ArrayRef shape); + +llvm::SmallVector permuteShape(mlir::ArrayRef shape, mlir::ArrayRef permutation); + +llvm::SmallVector invertPermutation(mlir::ArrayRef permutation); + +mlir::FailureOr> getTransposePermutationChecked(std::optional permAttr, + int64_t rank); + +llvm::SmallVector getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank); + +llvm::SmallVector getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank); + +llvm::SmallVector getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef shape); + } // namespace onnx_mlir diff --git a/src/PIM/Common/LabeledList.hpp b/src/PIM/Common/LabeledList.hpp deleted file mode 100644 index b3787e9..0000000 --- a/src/PIM/Common/LabeledList.hpp +++ /dev/null @@ -1,315 +0,0 @@ -#pragma once - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ilist_node.h" -#include "llvm/ADT/simple_ilist.h" - -#include -#include -#include -#include - -namespace onnx_mlir { - -template -class LabeledList; - -template -class LabeledListNode : public llvm::ilist_node { - friend class LabeledList; - -public: - using Label = uint64_t; - - LabeledListNode() = default; - LabeledListNode(const LabeledListNode&) = delete; - LabeledListNode(LabeledListNode&&) = default; - LabeledListNode& operator=(LabeledListNode&&) = delete; - - ~LabeledListNode() { assert(owner_ == nullptr && "destroying a linked LabeledListNode"); } - - bool isLinked() const { return owner_ != nullptr; } - Label getOrderLabel() const { return label; } - - friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt) { return lft.label < rgt.label; } - -private: - const void* owner_ = nullptr; - Label label = 0; -}; - -template -class LabeledList { - - using Label = typename NodeT::Label; - - static constexpr Label kLowerSentinel = 0; - static constexpr Label kUpperSentinel = std::numeric_limits, "A must be an integer type"); - static_assert(std::is_integral_v, "B must be an integer type"); - C ac = static_cast(a); - C bc = static_cast(b); - return 1 + (ac - 1) / bc; -} - -template > -constexpr std::pair ceilIntegerDivideWithRemainder(A a, B b) { - static_assert(std::is_integral_v, "A must be an integer type"); - static_assert(std::is_integral_v, "B must be an integer type"); - C ac = static_cast(a); - C bc = static_cast(b); - return {ceilIntegerDivide(ac, bc), ac % bc}; -} - -template -bool isVectorShape(mlir::ArrayRef shape) { - return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); -} - -template -bool isMatrixShape(mlir::ArrayRef shape) { - return shape.size() == 2; -} - -template -bool isHVectorShape(mlir::ArrayRef shape) { - return shape.size() == 2 && shape[0] == 1; -} - -inline auto getTensorShape(mlir::Value tensor) { - return mlir::cast(tensor.getType()).getShape(); -} - -inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) { - auto lhsType = mlir::dyn_cast(lhs.getType()); - auto rhsType = mlir::dyn_cast(rhs.getType()); - return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() - && lhsType.getShape() == rhsType.getShape(); -} - -bool hasStaticPositiveShape(mlir::ArrayRef shape); - -bool hasStaticPositiveShape(mlir::RankedTensorType type); - -int64_t getStaticShapeElementCount(mlir::ArrayRef shape); - -llvm::SmallVector permuteShape(mlir::ArrayRef shape, mlir::ArrayRef permutation); - -llvm::SmallVector invertPermutation(mlir::ArrayRef permutation); - -mlir::FailureOr> getTransposePermutationChecked(std::optional permAttr, - int64_t rank); - -llvm::SmallVector getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank); - -llvm::SmallVector getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank); - -llvm::SmallVector getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef shape); - /// Slices a statically shaped tensor along one axis into contiguous pieces of /// at most `sliceSize` elements. llvm::SmallVector sliceTensor(const mlir::Value& tensorToSlice, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index a76fd6d..7b7d992 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -26,6 +26,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -42,59 +43,6 @@ struct ConvToGemm : OpConversionPattern { ConversionPatternRewriter& rewriter) const override; }; -struct ConvLoweringState { - Value x; - Value w; - Value b; - RankedTensorType xType; - RankedTensorType wType; - RankedTensorType outType; - int64_t batchSize; - int64_t numChannelsIn; - int64_t xHeight; - int64_t xWidth; - int64_t numChannelsOut; - int64_t wHeight; - int64_t wWidth; - int64_t outHeight; - int64_t outWidth; - int64_t group; - int64_t numChannelsInPerGroup; - int64_t numChannelsOutPerGroup; - int64_t padHeightBegin; - int64_t padHeightEnd; - int64_t padWidthBegin; - int64_t padWidthEnd; - int64_t strideHeight; - int64_t strideWidth; - int64_t dilationHeight; - int64_t dilationWidth; - bool hasBias; -}; - -struct ConvGeometry { - int64_t batchSize; - int64_t numChannelsIn; - int64_t xHeight; - int64_t xWidth; - int64_t numChannelsOut; - int64_t wHeight; - int64_t wWidth; - int64_t outHeight; - int64_t outWidth; - int64_t group; - int64_t numChannelsInPerGroup; - int64_t numChannelsOutPerGroup; - int64_t k; - int64_t c; - int64_t p; - int64_t xbarSize; - int64_t pack; - uint64_t im2colElements; - bool hasBias; - bool isDepthwise; -}; - struct ConvLoweringDecision { PimConvLoweringType strategy; std::string reason; @@ -108,19 +56,6 @@ struct PreparedConvInput { RankedTensorType type; }; -struct RowInterval { - int64_t begin = 0; - int64_t end = 0; -}; - -struct ConvRowDemand { - RowInterval outputRows; - RowInterval neededInputRows; - RowInterval acquiredInputRows; - int64_t topHaloRows = 0; - int64_t bottomHaloRows = 0; -}; - struct ConvStrategyEstimate { uint64_t estimatedMvmCount = 0; uint64_t estimatedReductionVAddCount = 0; @@ -291,9 +226,6 @@ static FailureOr createRowStripPackedRows(Value rows, PatternRewriter& rewriter, Location loc); -static bool -isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup); -static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor); static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b); static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) { @@ -391,34 +323,6 @@ static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo, return estimate; } -static ConvGeometry buildConvGeometry(const ConvLoweringState& state) { - ConvGeometry geo { - state.batchSize, - state.numChannelsIn, - state.xHeight, - state.xWidth, - state.numChannelsOut, - state.wHeight, - state.wWidth, - state.outHeight, - state.outWidth, - state.group, - state.numChannelsInPerGroup, - state.numChannelsOutPerGroup, - state.numChannelsInPerGroup * state.wHeight * state.wWidth, - state.numChannelsOutPerGroup, - state.batchSize * state.outHeight * state.outWidth, - static_cast(crossbarSize.getValue()), - 1, - 0, - state.hasBias, - isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup), - }; - geo.pack = std::max(1, geo.xbarSize / std::max(geo.k, geo.c)); - geo.im2colElements = static_cast(std::max(0, geo.p)) * static_cast(std::max(0, geo.k)); - return geo; -} - static std::string formatShape(ArrayRef dims) { std::string text; llvm::raw_string_ostream os(text); @@ -563,36 +467,10 @@ classifyDistributedBinaryConsumer(Operation* user, return std::nullopt; } -static RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, - int64_t inputHeight, - int64_t kernelH, - int64_t strideH, - int64_t dilationH, - int64_t padTop) { - const int64_t rawBegin = outputRows.begin * strideH - padTop; - const int64_t rawEnd = (outputRows.end - 1) * strideH - padTop + dilationH * (kernelH - 1) + 1; - return {std::max(0, rawBegin), std::min(inputHeight, rawEnd)}; -} - static bool covers(RowInterval acquired, RowInterval needed) { return acquired.begin <= needed.begin && acquired.end >= needed.end; } -static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) { - const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin; - const int64_t rawEnd = - (outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1; - RowInterval neededInputRows = computeConvInputRowsForOutputRows( - outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin); - ConvRowDemand demand; - demand.outputRows = outputRows; - demand.neededInputRows = neededInputRows; - demand.acquiredInputRows = neededInputRows; - demand.topHaloRows = std::max(0, -rawBegin); - demand.bottomHaloRows = std::max(0, rawEnd - state.xHeight); - return demand; -} - static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) { if (state.batchSize != 1) { failureReason = "unsupported_batch"; @@ -1250,19 +1128,6 @@ static void reportConvLoweringDecision(ONNXConvOp convOp, rewriteConvLoweringReport(reportEntries); } -static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) { - const uint64_t patchElements = static_cast(std::max(1, geo.k)); - uint64_t chunkPositions = std::max(1, pimConvIm2colMaxElements / patchElements); - chunkPositions = std::min(chunkPositions, static_cast(std::max(1, geo.p))); - chunkPositions = std::min(chunkPositions, std::max(1, pimConvStreamChunkPositions)); - - if (packFactor > 1 && chunkPositions > static_cast(packFactor)) { - chunkPositions -= chunkPositions % static_cast(packFactor); - chunkPositions = std::max(chunkPositions, static_cast(packFactor)); - } - return std::max(1, chunkPositions); -} - static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) { auto biasType = cast(bias.getType()); if (biasType.getRank() != 1) @@ -1278,11 +1143,6 @@ static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location }); } -static bool -isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) { - return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0; -} - static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) { assert(value > 0 && "expected positive value"); limit = std::min(value, limit); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.cpp new file mode 100644 index 0000000..4b1fa71 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.cpp @@ -0,0 +1,77 @@ +#include "ConvGeometry.hpp" + +#include + +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" + +namespace onnx_mlir { + +bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) { + return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0; +} + +ConvGeometry buildConvGeometry(const ConvLoweringState& state) { + ConvGeometry geo { + state.batchSize, + state.numChannelsIn, + state.xHeight, + state.xWidth, + state.numChannelsOut, + state.wHeight, + state.wWidth, + state.outHeight, + state.outWidth, + state.group, + state.numChannelsInPerGroup, + state.numChannelsOutPerGroup, + state.numChannelsInPerGroup * state.wHeight * state.wWidth, + state.numChannelsOutPerGroup, + state.batchSize * state.outHeight * state.outWidth, + static_cast(crossbarSize.getValue()), + 1, + 0, + state.hasBias, + isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup), + }; + geo.pack = std::max(1, geo.xbarSize / std::max(geo.k, geo.c)); + geo.im2colElements = static_cast(std::max(0, geo.p)) * static_cast(std::max(0, geo.k)); + return geo; +} + +uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) { + const uint64_t patchElements = static_cast(std::max(1, geo.k)); + uint64_t chunkPositions = std::max(1, pimConvIm2colMaxElements / patchElements); + chunkPositions = std::min(chunkPositions, static_cast(std::max(1, geo.p))); + chunkPositions = std::min(chunkPositions, std::max(1, pimConvStreamChunkPositions)); + + if (packFactor > 1 && chunkPositions > static_cast(packFactor)) { + chunkPositions -= chunkPositions % static_cast(packFactor); + chunkPositions = std::max(chunkPositions, static_cast(packFactor)); + } + return std::max(1, chunkPositions); +} + +RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state) { + const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin; + const int64_t rawEnd = + (outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1; + return {std::max(0, rawBegin), std::min(state.xHeight, rawEnd)}; +} + +ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) { + ConvRowDemand demand; + demand.outputRows = outputRows; + demand.neededInputRows = computeConvInputRowsForOutputRows(outputRows, state); + demand.acquiredInputRows = demand.neededInputRows; + + const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin; + const int64_t rawEnd = + (outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1; + demand.topHaloRows = std::max(0, -rawBegin); + demand.bottomHaloRows = std::max(0, rawEnd - state.xHeight); + demand.acquiredInputRows = demand.neededInputRows; + return demand; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp new file mode 100644 index 0000000..60564c6 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp @@ -0,0 +1,86 @@ +#pragma once + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" + +#include + +namespace onnx_mlir { + +struct ConvLoweringState { + mlir::Value x; + mlir::Value w; + mlir::Value b; + mlir::RankedTensorType xType; + mlir::RankedTensorType wType; + mlir::RankedTensorType outType; + int64_t batchSize; + int64_t numChannelsIn; + int64_t xHeight; + int64_t xWidth; + int64_t numChannelsOut; + int64_t wHeight; + int64_t wWidth; + int64_t outHeight; + int64_t outWidth; + int64_t group; + int64_t numChannelsInPerGroup; + int64_t numChannelsOutPerGroup; + int64_t padHeightBegin; + int64_t padHeightEnd; + int64_t padWidthBegin; + int64_t padWidthEnd; + int64_t strideHeight; + int64_t strideWidth; + int64_t dilationHeight; + int64_t dilationWidth; + bool hasBias; +}; + +struct ConvGeometry { + int64_t batchSize; + int64_t numChannelsIn; + int64_t xHeight; + int64_t xWidth; + int64_t numChannelsOut; + int64_t wHeight; + int64_t wWidth; + int64_t outHeight; + int64_t outWidth; + int64_t group; + int64_t numChannelsInPerGroup; + int64_t numChannelsOutPerGroup; + int64_t k; + int64_t c; + int64_t p; + int64_t xbarSize; + int64_t pack; + uint64_t im2colElements; + bool hasBias; + bool isDepthwise; +}; + +struct RowInterval { + int64_t begin = 0; + int64_t end = 0; +}; + +struct ConvRowDemand { + RowInterval outputRows; + RowInterval neededInputRows; + RowInterval acquiredInputRows; + int64_t topHaloRows = 0; + int64_t bottomHaloRows = 0; +}; + +bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup); + +ConvGeometry buildConvGeometry(const ConvLoweringState& state); + +uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor); + +RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state); + +ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 5b8a7c4..00d5b2b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -87,28 +87,6 @@ static Value createGemmBatchHOffset(Value lane, rewriter.getInsertionBlock()->getParentOp()); } -static Value -createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) { - auto sourceType = cast(value.getType()); - SmallVector lowPads(sourceType.getRank(), rewriter.getIndexAttr(0)); - SmallVector highPads; - highPads.reserve(sourceType.getRank()); - for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape())) - highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim)); - - auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads); - auto* padBlock = new Block(); - for (int64_t i = 0; i < sourceType.getRank(); ++i) - padBlock->addArgument(rewriter.getIndexType(), loc); - padOp.getRegion().push_back(padBlock); - rewriter.setInsertionPointToStart(padBlock); - auto zero = getOrCreateConstant( - rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType()); - tensor::YieldOp::create(rewriter, loc, zero); - rewriter.setInsertionPointAfter(padOp); - return padOp.getResult(); -} - static FailureOr materializePaddedConstantMatrix(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, @@ -232,22 +210,6 @@ static Value extractATile( return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult(); } -static Value createPaddedInputCompute(Value input, - RankedTensorType paddedInputType, - ConversionPatternRewriter& rewriter, - Location loc) { - auto inputType = cast(input.getType()); - if (inputType == paddedInputType) - return input; - - auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) { - Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc); - spatial::SpatYieldOp::create(rewriter, loc, paddedInput); - }); - - return computeOp.getResult(0); -} - static FailureOr createVmmBatch(Value a, Value b, RankedTensorType aType, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index e74700d..b120b9e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -255,42 +255,6 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati return createONNXTranspose(resultType, {0, 2, 1}); } -static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) { - auto sourceType = cast(value.getType()); - SmallVector lowPads(sourceType.getRank(), rewriter.getIndexAttr(0)); - SmallVector highPads; - highPads.reserve(sourceType.getRank()); - for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape())) - highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim)); - - auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads); - auto* padBlock = new Block(); - for (int64_t i = 0; i < sourceType.getRank(); ++i) - padBlock->addArgument(rewriter.getIndexType(), loc); - padOp.getRegion().push_back(padBlock); - rewriter.setInsertionPointToStart(padBlock); - auto zero = getOrCreateConstant( - rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType()); - tensor::YieldOp::create(rewriter, loc, zero); - rewriter.setInsertionPointAfter(padOp); - return padOp.getResult(); -} - -static Value createPaddedBatchedInputCompute(Value input, - RankedTensorType paddedInputType, - PatternRewriter& rewriter, - Location loc) { - auto inputType = cast(input.getType()); - if (inputType == paddedInputType) - return input; - - auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) { - Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc); - spatial::SpatYieldOp::create(rewriter, loc, paddedInput); - }); - return computeOp.getResult(0); -} - static FailureOr materializePaddedBatchedWeight(Value value, ArrayRef sourceBatchShape, ArrayRef targetBatchShape, @@ -1055,7 +1019,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { auto paddedRhs = materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter); if (succeeded(paddedRhs)) { - Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc); + Value paddedLhs = createPaddedInputCompute(plan.lhs, paddedLhsType, rewriter, loc); const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices; auto partialPiecesType = RankedTensorType::get({laneCount, static_cast(crossbarSize.getValue())}, shapeInfo->outType.getElementType()); diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 8f32848..037f0a3 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -29,100 +29,6 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) { }); } -static bool isMaterializableExternalTensorOp(Operation* op) { - return isa(op); -} - -//TODO REMOVE THIS UGLY FIX -//TODO: Remove this helper once compute_batch external tensor captures are -// fixed at the producer side. -// -// This function is a temporary SpatialToPim repair path. It clones selected -// external tensor producers, such as channel_receive and tensor view/slice ops, -// into the new pim.core_batch body when the old spat.compute_batch body refers -// to tensor values defined outside the batch. -// -// The real invariant should be stronger: -// -// A spat.compute_batch body must not capture external tensor values. -// Every tensor used inside the body must be either: -// - a compute_batch block argument, -// - defined inside the compute_batch body, -// - or a legal constant-like value. -// -// If this invariant is violated, the responsible producer, most likely merge -// schedule materialization, should emit verifier-clean Spatial IR instead of -// relying on SpatialToPim to clone external producer chains later. -// -// After that producer-side fix: -// 1. remove isMaterializableExternalTensorOp, -// 2. remove materializeExternalTensorValue, -// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external -// tensor operand, -// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected -// before SpatialToPim. -// -// Be careful not to replace every external tensor capture with a normal -// compute_batch input blindly: host-backed tensors and explicit inter-core -// communication have different semantics. In particular, channel_receive-like -// values should be materialized through the communication model, not silently -// treated as host inputs. -static FailureOr materializeExternalTensorValue(IRRewriter& rewriter, - Location loc, - Block& oldBlock, - Value value, - IRMapping& mapper) { - if (mapper.contains(value)) - return mapper.lookup(value); - - if (!isa(value.getType())) - return value; - - Operation* definingOp = value.getDefiningOp(); - if (!definingOp || definingOp->hasTrait()) - return failure(); - - if (definingOp->getBlock() == &oldBlock) - return failure(); - - if (!isMaterializableExternalTensorOp(definingOp)) - return failure(); - - for (Value operand : definingOp->getOperands()) { - FailureOr materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper); - if (succeeded(materializedOperand)) - mapper.map(operand, *materializedOperand); - } - - Operation* cloned = rewriter.clone(*definingOp, mapper); - for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults())) - mapper.map(originalResult, clonedResult); - - return mapper.lookup(value); -} - -static FailureOr> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, - size_t& fallbackCoreId) { - if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) - return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); - - SmallVector coreIds; - coreIds.reserve(static_cast(computeBatchOp.getLaneCount())); - for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) { - auto checkedCoreId = - pim::checkedI32(static_cast(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id"); - if (failed(checkedCoreId)) - return failure(); - coreIds.push_back(*checkedCoreId); - ++fallbackCoreId; - } - return coreIds; -} - static FailureOr getDirectReturnOperandIndex(OpResult result) { if (!result.hasOneUse()) return failure(); @@ -386,7 +292,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul "resultful compute_batch lowering currently requires a spat.in_parallel terminator"); } - auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); + auto coreIds = getRequiredScheduledBatchCoreIds(computeBatchOp, "spatial compute_batch core id"); if (failed(coreIds)) return failure(); SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); @@ -638,9 +544,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul if (definingOp && definingOp->hasTrait()) continue; - if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper))) - continue; - InFlightDiagnostic diagnostic = computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering"); diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex; diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index e4300ea..fba107e 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -9,6 +9,7 @@ #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" @@ -141,17 +142,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite } } -static FailureOr getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) { - if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id"); - auto checkedCoreId = - pim::checkedI32(static_cast(fallbackCoreId), computeOp, "fallback spatial compute core id"); - if (failed(checkedCoreId)) - return failure(); - ++fallbackCoreId; - return *checkedCoreId; -} - static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp, SmallVectorImpl& helperChain, bool requireReturnUse = true) { @@ -311,7 +301,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCom if (!computeOp.getWeights().empty()) computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); rewriter.setInsertionPointAfter(computeOp); - auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId); + auto checkedCoreId = getRequiredScheduledCoreId(computeOp, "spatial compute core id"); if (failed(checkedCoreId)) return failure(); auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast(*checkedCoreId), "pim core id"); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index d87c459..5549eb8 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -44,121 +44,29 @@ using namespace pim; namespace onnx_mlir { -static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { - auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType(); - auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType())); - - for (auto globalOp : moduleOp.getOps()) { - if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue()) - continue; - if (dyn_cast(*globalOp.getInitialValue()) == zeroAttr) - return globalOp; - } - - std::string nameStem; - llvm::raw_string_ostream nameStream(nameStem); - nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements(); - nameStream.flush(); - - std::string symbolName = nameStem; - unsigned suffix = 0; - while (SymbolTable::lookupSymbolIn(moduleOp, symbolName)) - symbolName = (nameStem + "_" + Twine(suffix++)).str(); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - return memref::GlobalOp::create(rewriter, - loc, - rewriter.getStringAttr(symbolName), - rewriter.getStringAttr("private"), - TypeAttr::get(memRefType), - zeroAttr, - rewriter.getUnitAttr(), - IntegerAttr {}); -} - -static FailureOr createZeroedDeviceHVector(IRRewriter& rewriter, - Location loc, - RankedTensorType tensorType, - OperationFolder& constantFolder) { - auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); - auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); - auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); - auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0); - auto byteSize = - pim::getCheckedShapedTypeSizeInBytes(tensorType, outputBuffer.getOperation(), "host-to-device zero copy byte size"); - if (failed(byteSize)) - return failure(); - auto sizeAttr = - pim::getCheckedI32Attr(rewriter, outputBuffer.getOperation(), *byteSize, "host-to-device zero copy byte size"); - if (failed(sizeAttr)) - return failure(); - return PimMemCopyHostToDevOp::create( - rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, *sizeAttr) - .getOutput(); -} - -static bool isHostBackedMemRefValue(Value value) { - while (Operation* definingOp = value.getDefiningOp()) { - if (auto subviewOp = dyn_cast(definingOp)) { - value = subviewOp.getSource(); - continue; - } - if (auto castOp = dyn_cast(definingOp)) { - value = castOp.getSource(); - continue; - } - if (auto collapseOp = dyn_cast(definingOp)) { - value = collapseOp.getSrc(); - continue; - } - if (auto expandOp = dyn_cast(definingOp)) { - value = expandOp.getSrc(); - continue; - } - return isa(definingOp); - } - return false; -} - -static bool isHostBackedTensorValue(Value value) { - while (Operation* definingOp = value.getDefiningOp()) { - if (auto extractSliceOp = dyn_cast(definingOp)) { - auto sourceType = dyn_cast(extractSliceOp.getSource().getType()); - auto resultType = dyn_cast(extractSliceOp.getResult().getType()); - if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) - return false; - if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(), - extractSliceOp.getMixedOffsets(), - extractSliceOp.getStaticSizes(), - extractSliceOp.getStaticStrides())) { - return false; - } - value = extractSliceOp.getSource(); - continue; - } - if (auto collapseOp = dyn_cast(definingOp)) { - value = collapseOp.getSrc(); - continue; - } - if (auto expandOp = dyn_cast(definingOp)) { - value = expandOp.getSrc(); - continue; - } - if (auto castOp = dyn_cast(definingOp)) { - value = castOp.getSource(); - continue; - } - if (auto toTensorOp = dyn_cast(definingOp)) - return isHostBackedMemRefValue(toTensorOp.getBuffer()); - return false; - } - return false; -} - static FailureOr -padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) { +createZeroPaddedTensor(IRRewriter& rewriter, Location loc, Value value, RankedTensorType resultType) { + auto sourceType = cast(value.getType()); + SmallVector lowPads(sourceType.getRank(), rewriter.getIndexAttr(0)); + SmallVector highPads; + highPads.reserve(sourceType.getRank()); + for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape())) + highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim)); + + auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads); + auto* padBlock = new Block(); + for (int64_t i = 0; i < sourceType.getRank(); ++i) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + auto zero = getOrCreateConstant( + rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType()); + tensor::YieldOp::create(rewriter, loc, zero); + rewriter.setInsertionPointAfter(padOp); + return padOp.getResult(); +} + +static FailureOr padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) { auto vectorType = cast(vector.getType()); ArrayRef shape = vectorType.getShape(); assert(isHVectorShape(shape) && "expected a horizontal vector"); @@ -169,26 +77,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, auto paddedType = RankedTensorType::get( {shape[0], static_cast(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); - auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder); - if (failed(zeroed)) - return failure(); - Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0); - auto byteSize = - pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size"); - if (failed(byteSize)) - return failure(); - auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size"); - if (failed(sizeAttr)) - return failure(); - if (isHostBackedTensorValue(vector)) { - return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr) - .getOutput(); - } - return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput(); + return createZeroPaddedTensor(rewriter, loc, vector, paddedType); } void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { - coreId = 0; outputTensors.clear(); operationsToRemove.clear(); ModuleOp moduleOp = getOperation(); @@ -362,7 +254,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { } LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { - OperationFolder constantFolder(funcOp.getContext()); bool hasFailure = false; funcOp.walk([&](PimVMMOp vmmOp) { auto outputType = cast(vmmOp.getOutput().getType()); @@ -371,7 +262,7 @@ LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func: assert(outputShape[1] <= static_cast(crossbarSize) && "output width must fit in one crossbar"); rewriter.setInsertionPoint(vmmOp); - auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder); + auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput()); if (failed(paddedInput)) { hasFailure = true; return WalkResult::interrupt(); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp index 3c82c60..702bb7c 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp @@ -36,7 +36,6 @@ private: using OutputTensorFactory = std::function; llvm::SmallVector outputTensors; - size_t coreId = 0; llvm::SmallVector operationsToRemove; mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 9669447..01625f1 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -8,7 +8,9 @@ add_pim_library(SpatialOps SpatialOpsCanonicalization.cpp ${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp + Transforms/MergeComputeNodes/HostOutputFinalization.cpp Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp + Transforms/MergeComputeNodes/ProjectedFragments.cpp Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.cpp new file mode 100644 index 0000000..ba683c2 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.cpp @@ -0,0 +1,134 @@ +#include "HostOutputFinalization.hpp" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" + +#include "MaterializedClassState.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir::spatial { + +LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { + if (state.pendingProjectedHostOutputFragments.empty()) + return success(); + + DenseMap> byOutput; + for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) + byOutput[fragment.originalOutput].push_back(&fragment); + + SmallVector outputs; + outputs.reserve(byOutput.size()); + + auto returnOp = dyn_cast(state.func.getBody().front().getTerminator()); + if (!returnOp) + return state.func.emitError("expected func.return terminator while finalizing projected host output fragments"); + + DenseSet seenOutputs; + for (Value returned : returnOp.getOperands()) { + if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second) + continue; + outputs.push_back(returned); + } + if (outputs.size() != byOutput.size()) + return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs"); + + for (Value originalOutput : outputs) { + if (isa_and_present(originalOutput.getDefiningOp())) { + return state.func.emitError( + "projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result"); + } + + auto resultType = dyn_cast(originalOutput.getType()); + if (!resultType || !resultType.hasStaticShape()) + return state.func.emitError("projected host output must have static ranked tensor type"); + + SmallVector& fragments = byOutput[originalOutput]; + llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, + const PendingProjectedHostOutputFragment* rhs) { + if (lhs->sourceClass != rhs->sourceClass) + return lhs->sourceClass < rhs->sourceClass; + if (lhs->publicationResultIndex != rhs->publicationResultIndex) + return lhs->publicationResultIndex < rhs->publicationResultIndex; + if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal) + return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal; + return std::lexicographical_compare(lhs->offsets.begin(), + lhs->offsets.end(), + rhs->offsets.begin(), + rhs->offsets.end()); + }); + + state.rewriter.setInsertionPoint(returnOp); + Location loc = fragments.front()->loc; + SmallVector blueprintOperands; + SmallVector fragmentOperandIndices; + SmallVector fragmentSourceOffsets; + SmallVector flatOffsets; + SmallVector flatSizes; + SmallVector flatStrides; + DenseMap operandIndicesByValue; + + for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { + if (fragmentRecord->sourceClass >= state.classes.size()) + return state.func.emitError("projected host output fragment references an invalid source class"); + + MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; + if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) { + return sourceClass.op->emitError("projected host output fragment references an invalid publication result") + << " sourceClass=" << sourceClass.id + << " resultIndex=" << fragmentRecord->publicationResultIndex + << " resultCount=" << sourceClass.op->getNumResults(); + } + + Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex); + auto [operandIt, inserted] = + operandIndicesByValue.try_emplace(operand, static_cast(blueprintOperands.size())); + if (inserted) + blueprintOperands.push_back(operand); + fragmentOperandIndices.push_back(operandIt->second); + fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset); + llvm::append_range(flatOffsets, fragmentRecord->offsets); + llvm::append_range(flatSizes, fragmentRecord->sizes); + llvm::append_range(flatStrides, fragmentRecord->strides); + + auto operandType = dyn_cast(operand.getType()); + if (!operandType || !operandType.hasStaticShape()) + return state.func.emitError("projected host output assembly requires static ranked tensor operands"); + } + + if (blueprintOperands.empty()) + return state.func.emitError("missing projected host output fragments"); + + Value input = blueprintOperands.front(); + ValueRange extraFragments = ValueRange(blueprintOperands).drop_front(); + auto blueprint = SpatBlueprintOp::create( + state.rewriter, + loc, + resultType, + input, + extraFragments, + state.rewriter.getStringAttr("nchw"), + state.rewriter.getStringAttr("fragmented"), + state.rewriter.getDenseI64ArrayAttr(flatOffsets), + state.rewriter.getDenseI64ArrayAttr(flatSizes), + state.rewriter.getStringAttr("identity"), + state.rewriter.getStringAttr("fragment_assembly"), + state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices), + state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets), + state.rewriter.getDenseI64ArrayAttr(flatStrides), + state.rewriter.getStringAttr("disjoint"), + state.rewriter.getStringAttr("complete")); + + state.hostReplacements[originalOutput] = blueprint.getOutput(); + } + + return success(); +} + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.hpp new file mode 100644 index 0000000..29a6fa2 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/HostOutputFinalization.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "mlir/Support/LogicalResult.h" + +namespace onnx_mlir::spatial { + +struct MaterializerState; + +mlir::LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state); + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 31219e5..e3b6c71 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -24,6 +24,11 @@ #include #include "MaterializeMergeSchedule.hpp" +#include "MaterializedClassState.hpp" +#include "MergeMessages.hpp" +#include "MergeScheduleKeys.hpp" +#include "HostOutputFinalization.hpp" +#include "ProjectedFragments.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" @@ -39,348 +44,6 @@ namespace onnx_mlir { namespace spatial { namespace { -using CpuId = size_t; -using ClassId = size_t; -using SlotId = size_t; - -static FailureOr getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) { - return pim::checkedI32(static_cast(cpu), anchor, fieldName); -} - -static FailureOr> -getCheckedCoreIds(Operation* anchor, ArrayRef cpus, StringRef fieldName) { - SmallVector coreIds; - coreIds.reserve(cpus.size()); - for (CpuId cpu : cpus) { - auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName); - if (failed(checkedCoreId)) - return failure(); - coreIds.push_back(*checkedCoreId); - } - return coreIds; -} - -struct MessageVector { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - size_t size() const { return channelIds.size(); } - bool empty() const { return channelIds.empty(); } - - LogicalResult verify(Operation* anchor) const { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) - return anchor->emitError("message metadata is inconsistent"); - return success(); - } - - void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) { - channelIds.push_back(channelId); - sourceCoreIds.push_back(sourceCoreId); - targetCoreIds.push_back(targetCoreId); - } - - void append(ArrayRef channels, ArrayRef sources, ArrayRef targets) { - assert(channels.size() == sources.size() && "channel/source count mismatch"); - assert(channels.size() == targets.size() && "channel/target count mismatch"); - llvm::append_range(channelIds, channels); - llvm::append_range(sourceCoreIds, sources); - llvm::append_range(targetCoreIds, targets); - } - - MessageVector slice(size_t offset, size_t count) const { - MessageVector result; - result.append(ArrayRef(channelIds).slice(offset, count), - ArrayRef(sourceCoreIds).slice(offset, count), - ArrayRef(targetCoreIds).slice(offset, count)); - return result; - } -}; - -struct ProducerKey { - ComputeInstance instance; - size_t resultIndex = 0; - - bool operator==(const ProducerKey& other) const { - return instance == other.instance && resultIndex == other.resultIndex; - } -}; - -struct ProducerKeyInfo { - static ProducerKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; - } - - static ProducerKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; - } - - static unsigned getHashValue(const ProducerKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.instance), key.resultIndex); - } - - static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } -}; - -struct SameClassConsumerLookupKey { - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - ClassId classId = 0; - - bool operator==(const SameClassConsumerLookupKey& other) const { - return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; - } -}; - -struct SameClassConsumerLookupKeyInfo { - static SameClassConsumerLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static SameClassConsumerLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static unsigned getHashValue(const SameClassConsumerLookupKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); - } - - static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) { - return lhs == rhs; - } -}; - -struct WholeBatchAssemblyLookupKey { - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - ClassId classId = 0; - - bool operator==(const WholeBatchAssemblyLookupKey& other) const { - return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; - } -}; - -struct WholeBatchAssemblyLookupKeyInfo { - static WholeBatchAssemblyLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static WholeBatchAssemblyLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); - } - - static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) { - return lhs == rhs; - } -}; - -using ClassSlotKey = std::pair; - -struct MaterializedClass { - ClassId id = 0; - SmallVector cpus; - Operation* op = nullptr; - Block* body = nullptr; - bool isBatch = false; - - DenseMap cpuToLane; - SmallVector weights; - SmallVector inputs; - SmallVector hostOutputs; - DenseMap publicationOutputToResultIndex; - DenseMap weightArgs; - DenseMap inputArgs; - DenseMap hostOutputToResultIndex; -}; - -struct PackedScalarRunSlot { - SmallVector keys; -}; - -enum class PackedScalarRunKind { - Materialized, - DeferredReceive, - DeferredLocalCompute -}; - -struct PackedScalarRunValue { - ClassId targetClass = 0; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - PackedScalarRunKind kind = PackedScalarRunKind::Materialized; - - Value packed; - - RankedTensorType fragmentType; - SmallVector slots; - MessageVector messages; -}; - -struct IndexedBatchRunValue { - ClassId targetClass = 0; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - Value packed; - RankedTensorType fragmentType; - SmallVector slots; - MessageVector messages; -}; - -struct LogicalSlotRange { - SlotId start = 0; - SlotId count = 0; -}; - -struct MaterializationRunSlot { - SmallVector peers; -}; - -using MaterializationRun = SmallVector; - -struct OutputDestinationGroup { - SmallVector resultIndices; - SmallVector destinationClasses; -}; - -struct BatchRunSendPlan { - size_t resultIndex = 0; - ClassId destinationClass = 0; - MessageVector messages; -}; - -struct ProjectedBatchInputKey { - Operation* consumerOp = nullptr; - unsigned inputIndex = 0; - - bool operator==(const ProjectedBatchInputKey& other) const { - return consumerOp == other.consumerOp && inputIndex == other.inputIndex; - } -}; - -struct ProjectedBatchInputKeyInfo { - static ProjectedBatchInputKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; - } - - static ProjectedBatchInputKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; - } - - static unsigned getHashValue(const ProjectedBatchInputKey& key) { - return llvm::hash_combine(key.consumerOp, key.inputIndex); - } - - static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } -}; - -struct ProjectedFragmentLayout { - RankedTensorType fragmentType; - SmallVector fragmentShape; - unsigned fragmentsPerLogicalSlot = 1; - unsigned payloadFragmentCount = 1; - SmallVector loopLowerBounds; - SmallVector loopSteps; - SmallVector loopTripCounts; -}; - -struct StaticProjectedLoopInfo { - BlockArgument iv; - int64_t lowerBound = 0; - int64_t step = 1; - int64_t tripCount = 1; -}; - -struct ProjectedTransferDescriptor { - ProjectedBatchInputKey inputKey; - Operation* extractOp = nullptr; - ProjectedFragmentLayout layout; - RankedTensorType payloadType; - SmallVector, 16> fragmentOffsets; - SmallVector, 4> fragmentOffsetsByDim; -}; - -struct ProjectedExtractReplacement { - Value payload; - ProjectedFragmentLayout layout; -}; - -struct PendingProjectedHostOutputFragment { - Value originalOutput; - ClassId sourceClass = 0; - ProducerKey producerKey; - unsigned publicationResultIndex = 0; - int64_t sourceFragmentOrdinal = 0; - int64_t sourceElementOffset = 0; - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - uint32_t sourceLane = 0; - Location loc; -}; - -enum class TensorDemandActionKind { - DestinationFanout, - SameClassIndexedFragment, - TerminalBlueprintPublication, - WholeTensorBarrier -}; - -enum class WholeTensorBarrierReason { - FunctionReturnWithoutBlueprint, - DenseLogicalConsumer -}; - -struct TensorDemandAction { - TensorDemandActionKind kind = TensorDemandActionKind::DestinationFanout; - std::optional destinationClass; - std::optional barrierReason; -}; - -struct RunOutputDemand { - size_t resultIndex = 0; - Value originalOutput; - RankedTensorType fragmentType; - SmallVector actions; -}; - -struct CompactRunPlan { - SmallVector outputs; -}; - -enum class BatchInputDemandKind { - LaneFragment, - ProjectedFragment, - WholeTensorBarrier -}; - -struct BatchInputDemand { - BatchInputDemandKind kind = BatchInputDemandKind::LaneFragment; - std::optional wholeTensorProducer; -}; - -struct AffineProjectedInputSliceMatch { - tensor::ExtractSliceOp extract; - RankedTensorType sourceType; - RankedTensorType fragmentType; - SmallVector fragmentShape; - SmallVector offsets; - SmallVector loops; -}; - -struct CloneIndexingContext { - std::optional runSlotIndex; - std::optional projectionSlotIndex; -}; - -struct MaterializerState; FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, @@ -445,112 +108,6 @@ FailureOr materializeProjectedWholeBatchExtractReplacement(MaterializerSt ProducerKey producer, IRMapping* mapper = nullptr); -class AvailableValueStore { -public: - struct ExactBatchFragmentRecord { - ProducerKey key; - Value value; - }; - - void record(ProducerKey key, ClassId classId, Value value) { - exactValues[key][classId] = value; - - auto batch = dyn_cast_or_null(key.instance.op); - if (!batch || key.instance.laneCount == 0) - return; - - WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId}; - SmallVector& bucket = exactBatchFragmentsByProducerResultClass[lookupKey]; - for (ExactBatchFragmentRecord& record : bucket) { - if (!(record.key == key)) - continue; - record.value = value; - return; - } - bucket.push_back({key, value}); - } - - void recordPackedRun(PackedScalarRunValue run) { - size_t runIndex = packedScalarRuns.size(); - packedScalarRuns.push_back(std::move(run)); - const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex]; - WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass}; - packedRunsByProducerResultClass[lookupKey].push_back(runIndex); - } - void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } - - std::optional lookupExact(ProducerKey key, ClassId classId) const; - - std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); - IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); - - ArrayRef getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const { - auto it = packedRunsByProducerResultClass.find(key); - if (it == packedRunsByProducerResultClass.end()) - return {}; - return it->second; - } - - ArrayRef getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const { - auto it = exactBatchFragmentsByProducerResultClass.find(key); - if (it == exactBatchFragmentsByProducerResultClass.end()) - return {}; - return it->second; - } - - PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; } - -private: - std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); - - DenseMap, ProducerKeyInfo> exactValues; - SmallVector packedScalarRuns; - SmallVector indexedBatchRuns; - DenseMap, WholeBatchAssemblyLookupKeyInfo> - exactBatchFragmentsByProducerResultClass; - DenseMap, WholeBatchAssemblyLookupKeyInfo> - packedRunsByProducerResultClass; -}; - -struct MaterializerState { - func::FuncOp func; - const MergeScheduleResult& schedule; - IRRewriter rewriter; - OperationFolder constantFolder; - int64_t& nextChannelId; - SmallVector classes; - DenseMap cpuToClass; - DenseMap> logicalInstancesByCpu; - DenseMap scheduledInstanceToLogicalSlots; - DenseMap logicalInstanceToScheduledChunk; - DenseSet materializedLogicalSlots; - - DenseMap, ProducerKeyInfo> producerDestClasses; - DenseMap, SameClassConsumerLookupKeyInfo> - sameClassConsumerIndex; - DenseMap projectedInputMatches; - DenseSet nonProjectedInputs; - DenseMap liveExternalUseCache; - DenseMap> batchOutputFragmentTypesCache; - DenseMap, llvm::DenseMapInfo> computeInstanceOutputsCache; - DenseMap, ProducerKeyInfo> projectedTransfers; - DenseMap> projectedExtractReplacements; - AvailableValueStore availableValues; - DenseMap hostReplacements; - DenseMap hostOutputOwners; - SmallVector pendingProjectedHostOutputFragments; - DenseSet oldComputeOps; - - MaterializerState(func::FuncOp func, - const MergeScheduleResult& schedule, - int64_t& nextChannelId) - : func(func), - schedule(schedule), - rewriter(func.getContext()), - constantFolder(func.getContext()), - nextChannelId(nextChannelId) {} -}; - bool isConstantLike(Value value) { Operation* definingOp = value.getDefiningOp(); return definingOp && definingOp->hasTrait(); @@ -1264,17 +821,6 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) { return success(); } -void setInsertionPointForNewMaterializedOp(MaterializerState& state) { - Block& funcBlock = state.func.getBody().front(); - for (Operation& op : funcBlock) { - if (state.oldComputeOps.contains(&op)) { - state.rewriter.setInsertionPoint(&op); - return; - } - } - state.rewriter.setInsertionPointToEnd(&funcBlock); -} - BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { auto it = materializedClass.weightArgs.find(weight); if (it != materializedClass.weightArgs.end()) @@ -1408,19 +954,6 @@ FailureOr appendBatchPublicationResult(MaterializerState& state, // Materialized-class value localization helpers. // ----------------------------------------------------------------------------- -Region* getParentRegion(Value value) { - if (auto blockArg = dyn_cast(value)) - return blockArg.getOwner()->getParent(); - if (Operation* definingOp = value.getDefiningOp()) - return definingOp->getParentRegion(); - return nullptr; -} - -bool isDefinedInsideRegion(Value value, Region& region) { - Region* parentRegion = getParentRegion(value); - return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); -} - Operation* getEnclosingSpatialComputeLikeOp(Value value) { Block* block = nullptr; if (auto blockArg = dyn_cast(value)) @@ -2156,14 +1689,6 @@ FailureOr> buildProjectedFragmentOffsetsInClass(Mat return fragmentOffsets; } -SmallVector getStaticIndexAttrs(Builder& builder, ArrayRef values) { - SmallVector attrs; - attrs.reserve(values.size()); - for (int64_t value : values) - attrs.push_back(builder.getIndexAttr(value)); - return attrs; -} - Value createDim0InsertSlice( MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { auto fragmentType = cast(fragment.getType()); @@ -2284,6 +1809,8 @@ std::optional extractPackedProducerSlice(MaterializerState& state, return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); } +} // namespace + std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { auto producerIt = exactValues.find(key); if (producerIt == exactValues.end()) @@ -2296,6 +1823,32 @@ std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId c return valueIt->second; } +namespace { + +using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; +using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; + +SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run); +FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + int64_t itemCount, + IndexedFragmentBuilder buildFragment, + IndexedInsertOffsetBuilder buildOffset, + Location loc); +FailureOr> cloneBatchBodyForLane(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + Value laneValue, + ArrayRef resultIndices, + CloneIndexingContext indexing); +Value createIndexedChannelId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc); +Value createIndexedSourceCoreId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc); +Value createIndexedTargetCoreId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc); + Value getPackedSliceForRunIndex(MaterializerState& state, Operation* anchor, Value packed, @@ -2317,21 +1870,9 @@ Value getPackedSliceForDynamicRunIndex(MaterializerState& state, return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } -FailureOr createReceiveConcatLoop(MaterializerState& state, - MaterializedClass& targetClass, - RankedTensorType concatType, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc); - using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; -FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc); - bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { return run.kind == PackedScalarRunKind::DeferredLocalCompute; } @@ -2371,8 +1912,69 @@ FailureOr materializePackedScalarRunValue(MaterializerState& state, if (run.kind == PackedScalarRunKind::Materialized) return targetClass.op->emitError("materialized packed scalar run has no packed value"); - if (isDeferredLocalPackedScalarRun(run)) - return materializeDeferredLocalPackedScalarRunValue(state, targetClass, run, loc); + if (isDeferredLocalPackedScalarRun(run)) { + SmallVector keys = flattenPackedScalarRunKeys(run); + if (keys.empty()) + return failure(); + FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); + if (failed(packedType)) + return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); + + SmallVector sourceLanes; + sourceLanes.reserve(keys.size()); + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return failure(); + sourceLanes.push_back(key.instance.laneStart); + } + + SmallVector resultIndices {run.resultIndex}; + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + keys.front().instance, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); + if (failed(produced) || produced->size() != 1) + return failure(); + + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + FailureOr next = + createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); + return success(); + }); + if (failed(loop)) + return failure(); + run.packed = loop->results.front(); + return run.packed; + } if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) return failure(); @@ -2382,13 +1984,34 @@ FailureOr materializePackedScalarRunValue(MaterializerState& state, if (failed(fullPackedType)) return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); - auto packed = createReceiveConcatLoop(state, targetClass, *fullPackedType, run.fragmentType, run.messages, loc); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, fullPackedType->getShape(), fullPackedType->getElementType()) + .getResult(); + auto packed = emitIndexedFragmentInsertLoop( + state, + targetClass, + init, + static_cast(run.messages.size()), + [&](Value index) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, index, loc); + return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + [&](Value index) -> FailureOr { + return scaleIndexByDim0SizeInClass(state, targetClass, index, run.fragmentType.getDimSize(0), loc); + }, + loc); if (failed(packed)) return failure(); run.packed = *packed; return run.packed; } +} // namespace + std::optional AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { for (PackedScalarRunValue& run : packedScalarRuns) { if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) @@ -2488,6 +2111,8 @@ std::optional AvailableValueStore::lookup(MaterializerState& state, Produ return std::nullopt; } +namespace { + Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { SmallVector elements; elements.reserve(values.size()); @@ -2987,89 +2612,6 @@ isProjectedOffsetValue(Value value, Value laneArg, ArrayRef getConstantIndex(OpFoldResult value); -static unsigned getProjectedFragmentsPerLogicalSlot(ArrayRef loopTripCounts) { - unsigned fragmentsPerLogicalSlot = 1; - for (int64_t tripCount : loopTripCounts) { - assert(tripCount > 0 && "projected loop trip counts must be positive"); - fragmentsPerLogicalSlot *= static_cast(tripCount); - } - return fragmentsPerLogicalSlot; -} - -LogicalResult verifyProjectedFragmentLayout(Operation* anchor, const ProjectedFragmentLayout& layout) { - if (!layout.fragmentType || layout.fragmentShape.empty()) - return anchor->emitError("projected fragment layout is missing fragment type metadata"); - if (layout.fragmentShape.size() != static_cast(layout.fragmentType.getRank())) - return anchor->emitError("projected fragment layout rank does not match fragment type"); - if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0) - return anchor->emitError("projected fragment layout has an invalid fragment count"); - if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0) - return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots"); - return success(); -} - -FailureOr -getProjectedPayloadType(Operation* anchor, RankedTensorType fragmentType, unsigned payloadFragmentCount) { - if (failed( - verifyPackableFragmentType(anchor, fragmentType, payloadFragmentCount, "cannot create projected payload type"))) - return failure(); - return getPackedBatchTensorType(fragmentType, payloadFragmentCount); -} - -SmallVector, 4> -buildProjectedFragmentOffsetsByDim(ArrayRef> fragmentOffsets, size_t rank) { - SmallVector, 4> fragmentOffsetsByDim(rank); - for (ArrayRef offsets : fragmentOffsets) { - assert(offsets.size() == rank && "projected offset rank mismatch"); - for (size_t dim = 0; dim < rank; ++dim) - fragmentOffsetsByDim[dim].push_back(offsets[dim]); - } - return fragmentOffsetsByDim; -} - -LogicalResult verifyProjectedTransferDescriptor(Operation* anchor, const ProjectedTransferDescriptor& descriptor) { - if (failed(verifyProjectedFragmentLayout(anchor, descriptor.layout))) - return failure(); - if (!descriptor.payloadType) - return anchor->emitError("projected transfer descriptor is missing payload type"); - if (descriptor.fragmentOffsets.empty()) - return anchor->emitError("projected transfer descriptor expected at least one fragment offset"); - if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size()) - return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); - for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) - if (dimOffsets.size() != descriptor.fragmentOffsets.size()) - return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); - for (ArrayRef offsets : descriptor.fragmentOffsets) - if (offsets.size() != descriptor.layout.fragmentShape.size()) - return anchor->emitError("projected transfer offset rank does not match fragment rank"); - return success(); -} - -LogicalResult verifyProjectedSendDescriptor(Operation* anchor, - const ProjectedTransferDescriptor& descriptor, - const MessageVector& messages) { - if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) - return failure(); - if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size()) - return anchor->emitError("projected send descriptor metadata is inconsistent"); - return success(); -} - -LogicalResult finalizeProjectedTransferDescriptor(Operation* anchor, ProjectedTransferDescriptor& descriptor) { - descriptor.fragmentOffsetsByDim = - buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size()); - - FailureOr payloadType = - getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount); - if (failed(payloadType)) - return failure(); - if (descriptor.payloadType && descriptor.payloadType != *payloadType) - return anchor->emitError("projected transfer descriptor payload type does not match projected layout"); - descriptor.payloadType = *payloadType; - - return verifyProjectedTransferDescriptor(anchor, descriptor); -} - static FailureOr evaluateProjectedOffsetValue(OpFoldResult value, Value laneArg, uint32_t lane, @@ -4833,73 +4375,9 @@ FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, Value runSlotIndex, Location loc); -FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc) { - assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); +} // namespace - SmallVector keys = flattenPackedScalarRunKeys(run); - if (keys.empty()) - return failure(); - FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); - if (failed(packedType)) - return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); - - SmallVector sourceLanes; - sourceLanes.reserve(keys.size()); - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return failure(); - sourceLanes.push_back(key.instance.laneStart); - } - - SmallVector resultIndices {run.resultIndex}; - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value init = - tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {init}, - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - keys.front().instance, - sourceLane, - resultIndices, - CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); - if (failed(produced) || produced->size() != 1) - return failure(); - - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); - if (failed(next)) - return failure(); - yielded.push_back(*next); - return success(); - }); - if (failed(loop)) - return failure(); - run.packed = loop->results.front(); - return run.packed; -} +namespace { LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, MaterializedClass& targetClass, @@ -5970,123 +5448,6 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt return true; } -LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { - if (state.pendingProjectedHostOutputFragments.empty()) - return success(); - - DenseMap> byOutput; - for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) - byOutput[fragment.originalOutput].push_back(&fragment); - - SmallVector outputs; - outputs.reserve(byOutput.size()); - - auto returnOp = dyn_cast(state.func.getBody().front().getTerminator()); - if (!returnOp) - return state.func.emitError("expected func.return terminator while finalizing projected host output fragments"); - - DenseSet seenOutputs; - for (Value returned : returnOp.getOperands()) { - if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second) - continue; - outputs.push_back(returned); - } - if (outputs.size() != byOutput.size()) - return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs"); - - for (Value originalOutput : outputs) { - if (isa_and_present(originalOutput.getDefiningOp())) { - return state.func.emitError( - "projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result"); - } - - auto resultType = dyn_cast(originalOutput.getType()); - if (!resultType || !resultType.hasStaticShape()) - return state.func.emitError("projected host output must have static ranked tensor type"); - - SmallVector& fragments = byOutput[originalOutput]; - llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, - const PendingProjectedHostOutputFragment* rhs) { - if (lhs->sourceClass != rhs->sourceClass) - return lhs->sourceClass < rhs->sourceClass; - if (lhs->publicationResultIndex != rhs->publicationResultIndex) - return lhs->publicationResultIndex < rhs->publicationResultIndex; - if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal) - return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal; - return std::lexicographical_compare(lhs->offsets.begin(), - lhs->offsets.end(), - rhs->offsets.begin(), - rhs->offsets.end()); - }); - - state.rewriter.setInsertionPoint(returnOp); - Location loc = fragments.front()->loc; - SmallVector blueprintOperands; - SmallVector fragmentOperandIndices; - SmallVector fragmentSourceOffsets; - SmallVector flatOffsets; - SmallVector flatSizes; - SmallVector flatStrides; - DenseMap operandIndicesByValue; - - for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { - if (fragmentRecord->sourceClass >= state.classes.size()) - return state.func.emitError("projected host output fragment references an invalid source class"); - - MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; - if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) { - return sourceClass.op->emitError("projected host output fragment references an invalid publication result") - << " sourceClass=" << sourceClass.id - << " resultIndex=" << fragmentRecord->publicationResultIndex - << " resultCount=" << sourceClass.op->getNumResults(); - } - - Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex); - - auto [operandIt, inserted] = - operandIndicesByValue.try_emplace(operand, static_cast(blueprintOperands.size())); - if (inserted) - blueprintOperands.push_back(operand); - fragmentOperandIndices.push_back(operandIt->second); - fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset); - llvm::append_range(flatOffsets, fragmentRecord->offsets); - llvm::append_range(flatSizes, fragmentRecord->sizes); - llvm::append_range(flatStrides, fragmentRecord->strides); - - auto operandType = dyn_cast(operand.getType()); - if (!operandType || !operandType.hasStaticShape()) - return state.func.emitError("projected host output assembly requires static ranked tensor operands"); - } - - if (blueprintOperands.empty()) - return state.func.emitError("missing projected host output fragments"); - - Value input = blueprintOperands.front(); - ValueRange extraFragments = ValueRange(blueprintOperands).drop_front(); - auto blueprint = spatial::SpatBlueprintOp::create( - state.rewriter, - loc, - resultType, - input, - extraFragments, - state.rewriter.getStringAttr("nchw"), - state.rewriter.getStringAttr("fragmented"), - state.rewriter.getDenseI64ArrayAttr(flatOffsets), - state.rewriter.getDenseI64ArrayAttr(flatSizes), - state.rewriter.getStringAttr("identity"), - state.rewriter.getStringAttr("fragment_assembly"), - state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices), - state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets), - state.rewriter.getDenseI64ArrayAttr(flatStrides), - state.rewriter.getStringAttr("disjoint"), - state.rewriter.getStringAttr("complete")); - - state.hostReplacements[originalOutput] = blueprint.getOutput(); - } - - return success(); -} - FailureOr resolveInputValue(MaterializerState& state, MaterializedClass& targetClass, Value input, @@ -8102,36 +7463,6 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, return success(); } -FailureOr createReceiveConcatLoop(MaterializerState& state, - MaterializedClass& targetClass, - RankedTensorType concatType, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); - assert(!messages.empty() && "expected at least one receive"); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value init = - tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); - return emitIndexedFragmentInsertLoop( - state, - targetClass, - init, - static_cast(messages.size()), - [&](Value index) -> FailureOr { - Value channelId = createIndexedChannelId(state, targetClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, messages, index, loc); - return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - }, - [&](Value index) -> FailureOr { - return scaleIndexByDim0SizeInClass(state, targetClass, index, fragmentType.getDimSize(0), loc); - }, - loc); -} - bool valueMayEvaluateToCore(Value value, int64_t coreId) { if (std::optional constant = getConstantIndexValue(value)) return *constant == coreId; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializedClassState.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializedClassState.hpp new file mode 100644 index 0000000..3c243d7 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializedClassState.hpp @@ -0,0 +1,252 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/FoldUtils.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" + +#include + +#include "MaterializeMergeSchedule.hpp" +#include "MergeMessages.hpp" +#include "MergeScheduleKeys.hpp" +#include "ProjectedFragments.hpp" +#include "Scheduling/ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +namespace onnx_mlir::spatial { + +struct MaterializedClass { + ClassId id = 0; + llvm::SmallVector cpus; + mlir::Operation* op = nullptr; + mlir::Block* body = nullptr; + bool isBatch = false; + + llvm::DenseMap cpuToLane; + llvm::SmallVector weights; + llvm::SmallVector inputs; + llvm::SmallVector hostOutputs; + llvm::DenseMap publicationOutputToResultIndex; + llvm::DenseMap weightArgs; + llvm::DenseMap inputArgs; + llvm::DenseMap hostOutputToResultIndex; +}; + +struct PackedScalarRunSlot { + llvm::SmallVector keys; +}; + +enum class PackedScalarRunKind { + Materialized, + DeferredReceive, + DeferredLocalCompute +}; + +struct PackedScalarRunValue { + ClassId targetClass = 0; + mlir::Operation* sourceOp = nullptr; + size_t resultIndex = 0; + PackedScalarRunKind kind = PackedScalarRunKind::Materialized; + + mlir::Value packed; + + mlir::RankedTensorType fragmentType; + llvm::SmallVector slots; + MessageVector messages; +}; + +struct IndexedBatchRunValue { + ClassId targetClass = 0; + mlir::Operation* sourceOp = nullptr; + size_t resultIndex = 0; + mlir::Value packed; + mlir::RankedTensorType fragmentType; + llvm::SmallVector slots; + MessageVector messages; +}; + +struct LogicalSlotRange { + SlotId start = 0; + SlotId count = 0; +}; + +struct MaterializationRunSlot { + llvm::SmallVector peers; +}; + +using MaterializationRun = llvm::SmallVector; + +struct OutputDestinationGroup { + llvm::SmallVector resultIndices; + llvm::SmallVector destinationClasses; +}; + +struct BatchRunSendPlan { + size_t resultIndex = 0; + ClassId destinationClass = 0; + MessageVector messages; +}; + +enum class TensorDemandActionKind { + DestinationFanout, + SameClassIndexedFragment, + TerminalBlueprintPublication, + WholeTensorBarrier +}; + +enum class WholeTensorBarrierReason { + FunctionReturnWithoutBlueprint, + DenseLogicalConsumer +}; + +struct TensorDemandAction { + TensorDemandActionKind kind = TensorDemandActionKind::DestinationFanout; + std::optional destinationClass; + std::optional barrierReason; +}; + +struct RunOutputDemand { + size_t resultIndex = 0; + mlir::Value originalOutput; + mlir::RankedTensorType fragmentType; + llvm::SmallVector actions; +}; + +struct CompactRunPlan { + llvm::SmallVector outputs; +}; + +enum class BatchInputDemandKind { + LaneFragment, + ProjectedFragment, + WholeTensorBarrier +}; + +struct BatchInputDemand { + BatchInputDemandKind kind = BatchInputDemandKind::LaneFragment; + std::optional wholeTensorProducer; +}; + +struct CloneIndexingContext { + std::optional runSlotIndex; + std::optional projectionSlotIndex; +}; + +struct MaterializerState; + +class AvailableValueStore { +public: + struct ExactBatchFragmentRecord { + ProducerKey key; + mlir::Value value; + }; + + void record(ProducerKey key, ClassId classId, mlir::Value value) { + exactValues[key][classId] = value; + + auto batch = mlir::dyn_cast_or_null(key.instance.op); + if (!batch || key.instance.laneCount == 0) + return; + + WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId}; + llvm::SmallVector& bucket = exactBatchFragmentsByProducerResultClass[lookupKey]; + for (ExactBatchFragmentRecord& record : bucket) { + if (!(record.key == key)) + continue; + record.value = value; + return; + } + bucket.push_back({key, value}); + } + + void recordPackedRun(PackedScalarRunValue run) { + size_t runIndex = packedScalarRuns.size(); + packedScalarRuns.push_back(std::move(run)); + const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex]; + WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass}; + packedRunsByProducerResultClass[lookupKey].push_back(runIndex); + } + + void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } + + std::optional lookupExact(ProducerKey key, ClassId classId) const; + std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); + IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); + + llvm::ArrayRef getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = packedRunsByProducerResultClass.find(key); + if (it == packedRunsByProducerResultClass.end()) + return {}; + return it->second; + } + + llvm::ArrayRef getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = exactBatchFragmentsByProducerResultClass.find(key); + if (it == exactBatchFragmentsByProducerResultClass.end()) + return {}; + return it->second; + } + + PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; } + +private: + std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); + + llvm::DenseMap, ProducerKeyInfo> exactValues; + llvm::SmallVector packedScalarRuns; + llvm::SmallVector indexedBatchRuns; + llvm::DenseMap, + WholeBatchAssemblyLookupKeyInfo> + exactBatchFragmentsByProducerResultClass; + llvm::DenseMap, WholeBatchAssemblyLookupKeyInfo> + packedRunsByProducerResultClass; +}; + +struct MaterializerState { + mlir::func::FuncOp func; + const MergeScheduleResult& schedule; + mlir::IRRewriter rewriter; + mlir::OperationFolder constantFolder; + int64_t& nextChannelId; + llvm::SmallVector classes; + llvm::DenseMap cpuToClass; + llvm::DenseMap> logicalInstancesByCpu; + llvm::DenseMap scheduledInstanceToLogicalSlots; + llvm::DenseMap logicalInstanceToScheduledChunk; + llvm::DenseSet materializedLogicalSlots; + + llvm::DenseMap, ProducerKeyInfo> producerDestClasses; + llvm::DenseMap, SameClassConsumerLookupKeyInfo> + sameClassConsumerIndex; + llvm::DenseMap + projectedInputMatches; + llvm::DenseSet nonProjectedInputs; + llvm::DenseMap liveExternalUseCache; + llvm::DenseMap> batchOutputFragmentTypesCache; + llvm::DenseMap, llvm::DenseMapInfo> + computeInstanceOutputsCache; + llvm::DenseMap, ProducerKeyInfo> + projectedTransfers; + llvm::DenseMap> + projectedExtractReplacements; + AvailableValueStore availableValues; + llvm::DenseMap hostReplacements; + llvm::DenseMap hostOutputOwners; + llvm::SmallVector pendingProjectedHostOutputFragments; + llvm::DenseSet oldComputeOps; + + MaterializerState(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) + : func(func), + schedule(schedule), + rewriter(func.getContext()), + constantFolder(func.getContext()), + nextChannelId(nextChannelId) {} +}; + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 6793437..4f50638 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -28,6 +28,7 @@ #include "Scheduling/ComputeGraph.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/MergeSchedulingAnalysis.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" @@ -43,16 +44,6 @@ using namespace onnx_mlir::compact_asm; using SpatCompute = spatial::SpatGraphCompute; using SpatComputeBatch = spatial::SpatGraphComputeBatch; -static std::optional getComputeCoreId(spatial::SpatScheduledCompute compute) { - if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { - auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id"); - if (failed(checkedCoreId)) - return std::nullopt; - return *checkedCoreId; - } - return std::nullopt; -} - bool isTrivialSerialMergeCandidate(SpatCompute compute) { if (!compute->hasOneUse()) return false; @@ -213,8 +204,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody()); uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation()); SmallVector coreIds; - if (auto coreId = getComputeCoreId(spatCompute)) - coreIds.push_back(*coreId); + auto coreId = getOptionalScheduledCoreId(spatCompute, "merge compute core id"); + if (failed(coreId)) + return; + if (*coreId) + coreIds.push_back(**coreId); uint64_t computeId = totalComputeOps++; collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds}); uint64_t maxConcatOperands = 0; @@ -234,8 +228,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu uint64_t logicalCount = static_cast(batch.getLaneCount()); uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation()); SmallVector coreIds; - if (auto coreIdsAttr = batch->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) - llvm::append_range(coreIds, coreIdsAttr.asArrayRef()); + auto optionalCoreIds = getOptionalScheduledBatchCoreIds(batch, "merge compute_batch core id"); + if (failed(optionalCoreIds)) + return; + if (*optionalCoreIds) + coreIds = std::move(**optionalCoreIds); collectedData.push_back( {nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds}); totalComputeOps += 1; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeMessages.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeMessages.hpp new file mode 100644 index 0000000..b992723 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeMessages.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" + +namespace onnx_mlir::spatial { + +using CpuId = size_t; + +inline mlir::FailureOr getCheckedCoreId(mlir::Operation* anchor, CpuId cpu, llvm::StringRef fieldName) { + return pim::checkedI32(static_cast(cpu), anchor, fieldName); +} + +inline mlir::FailureOr> +getCheckedCoreIds(mlir::Operation* anchor, llvm::ArrayRef cpus, llvm::StringRef fieldName) { + llvm::SmallVector coreIds; + coreIds.reserve(cpus.size()); + for (CpuId cpu : cpus) { + auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName); + if (mlir::failed(checkedCoreId)) + return mlir::failure(); + coreIds.push_back(*checkedCoreId); + } + return coreIds; +} + +struct MessageVector { + llvm::SmallVector channelIds; + llvm::SmallVector sourceCoreIds; + llvm::SmallVector targetCoreIds; + + size_t size() const { return channelIds.size(); } + bool empty() const { return channelIds.empty(); } + + mlir::LogicalResult verify(mlir::Operation* anchor) const { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return anchor->emitError("message metadata is inconsistent"); + return mlir::success(); + } + + void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) { + channelIds.push_back(channelId); + sourceCoreIds.push_back(sourceCoreId); + targetCoreIds.push_back(targetCoreId); + } + + void append(llvm::ArrayRef channels, llvm::ArrayRef sources, llvm::ArrayRef targets) { + assert(channels.size() == sources.size() && "channel/source count mismatch"); + assert(channels.size() == targets.size() && "channel/target count mismatch"); + llvm::append_range(channelIds, channels); + llvm::append_range(sourceCoreIds, sources); + llvm::append_range(targetCoreIds, targets); + } + + MessageVector slice(size_t offset, size_t count) const { + MessageVector result; + result.append(llvm::ArrayRef(channelIds).slice(offset, count), + llvm::ArrayRef(sourceCoreIds).slice(offset, count), + llvm::ArrayRef(targetCoreIds).slice(offset, count)); + return result; + } +}; + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeScheduleKeys.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeScheduleKeys.hpp new file mode 100644 index 0000000..2fd1f2d --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeScheduleKeys.hpp @@ -0,0 +1,134 @@ +#pragma once + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include +#include +#include + +#include "Scheduling/ComputeInstanceUtils.hpp" + +namespace onnx_mlir::spatial { + +using ClassId = size_t; +using SlotId = size_t; + +struct ProducerKey { + ComputeInstance instance; + size_t resultIndex = 0; + + bool operator==(const ProducerKey& other) const { + return instance == other.instance && resultIndex == other.resultIndex; + } +}; + +struct ProducerKeyInfo { + static ProducerKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProducerKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ProducerKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.instance), key.resultIndex); + } + + static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } +}; + +struct SameClassConsumerLookupKey { + mlir::Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const SameClassConsumerLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct SameClassConsumerLookupKeyInfo { + static SameClassConsumerLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static SameClassConsumerLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const SameClassConsumerLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), + key.resultIndex, + key.classId); + } + + static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) { + return lhs == rhs; + } +}; + +struct WholeBatchAssemblyLookupKey { + mlir::Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const WholeBatchAssemblyLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct WholeBatchAssemblyLookupKeyInfo { + static WholeBatchAssemblyLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static WholeBatchAssemblyLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), + key.resultIndex, + key.classId); + } + + static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) { + return lhs == rhs; + } +}; + +using ClassSlotKey = std::pair; + +struct ProjectedBatchInputKey { + mlir::Operation* consumerOp = nullptr; + unsigned inputIndex = 0; + + bool operator==(const ProjectedBatchInputKey& other) const { + return consumerOp == other.consumerOp && inputIndex == other.inputIndex; + } +}; + +struct ProjectedBatchInputKeyInfo { + static ProjectedBatchInputKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProjectedBatchInputKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ProjectedBatchInputKey& key) { + return llvm::hash_combine(key.consumerOp, key.inputIndex); + } + + static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } +}; + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.cpp new file mode 100644 index 0000000..1aa5610 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.cpp @@ -0,0 +1,104 @@ +#include "ProjectedFragments.hpp" + +#include "mlir/IR/BuiltinTypes.h" + +namespace onnx_mlir::spatial { + +static mlir::FailureOr getPackedBatchTensorType(mlir::Type laneType, size_t laneCount) { + auto tensorType = mlir::dyn_cast(laneType); + if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) + return mlir::failure(); + + llvm::SmallVector shape(tensorType.getShape()); + shape[0] *= static_cast(laneCount); + return mlir::RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); +} + +unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef loopTripCounts) { + unsigned fragmentsPerLogicalSlot = 1; + for (int64_t tripCount : loopTripCounts) { + assert(tripCount > 0 && "projected loop trip counts must be positive"); + fragmentsPerLogicalSlot *= static_cast(tripCount); + } + return fragmentsPerLogicalSlot; +} + +mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout) { + if (!layout.fragmentType || layout.fragmentShape.empty()) + return anchor->emitError("projected fragment layout is missing fragment type metadata"); + if (layout.fragmentShape.size() != static_cast(layout.fragmentType.getRank())) + return anchor->emitError("projected fragment layout rank does not match fragment type"); + if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0) + return anchor->emitError("projected fragment layout has an invalid fragment count"); + if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0) + return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots"); + return mlir::success(); +} + +mlir::FailureOr +getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount) { + auto packedType = getPackedBatchTensorType(fragmentType, payloadFragmentCount); + if (mlir::failed(packedType)) { + anchor->emitError("cannot create projected payload type"); + return mlir::failure(); + } + return *packedType; +} + +llvm::SmallVector, 4> +buildProjectedFragmentOffsetsByDim(llvm::ArrayRef> fragmentOffsets, size_t rank) { + llvm::SmallVector, 4> fragmentOffsetsByDim(rank); + for (llvm::ArrayRef offsets : fragmentOffsets) { + assert(offsets.size() == rank && "projected offset rank mismatch"); + for (size_t dim = 0; dim < rank; ++dim) + fragmentOffsetsByDim[dim].push_back(offsets[dim]); + } + return fragmentOffsetsByDim; +} + +mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor, + const ProjectedTransferDescriptor& descriptor) { + if (mlir::failed(verifyProjectedFragmentLayout(anchor, descriptor.layout))) + return mlir::failure(); + if (!descriptor.payloadType) + return anchor->emitError("projected transfer descriptor is missing payload type"); + if (descriptor.fragmentOffsets.empty()) + return anchor->emitError("projected transfer descriptor expected at least one fragment offset"); + if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size()) + return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); + for (llvm::ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) + if (dimOffsets.size() != descriptor.fragmentOffsets.size()) + return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); + for (llvm::ArrayRef offsets : descriptor.fragmentOffsets) + if (offsets.size() != descriptor.layout.fragmentShape.size()) + return anchor->emitError("projected transfer offset rank does not match fragment rank"); + return mlir::success(); +} + +mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor, + const ProjectedTransferDescriptor& descriptor, + const MessageVector& messages) { + if (mlir::failed(verifyProjectedTransferDescriptor(anchor, descriptor))) + return mlir::failure(); + if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size()) + return anchor->emitError("projected send descriptor metadata is inconsistent"); + return mlir::success(); +} + +mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor, + ProjectedTransferDescriptor& descriptor) { + descriptor.fragmentOffsetsByDim = + buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size()); + + auto payloadType = + getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount); + if (mlir::failed(payloadType)) + return mlir::failure(); + if (descriptor.payloadType && descriptor.payloadType != *payloadType) + return anchor->emitError("projected transfer descriptor payload type does not match projected layout"); + descriptor.payloadType = *payloadType; + + return verifyProjectedTransferDescriptor(anchor, descriptor); +} + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.hpp new file mode 100644 index 0000000..0a655b9 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/ProjectedFragments.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +#include + +#include "MergeMessages.hpp" +#include "MergeScheduleKeys.hpp" + +namespace onnx_mlir::spatial { + +struct ProjectedFragmentLayout { + mlir::RankedTensorType fragmentType; + llvm::SmallVector fragmentShape; + unsigned fragmentsPerLogicalSlot = 1; + unsigned payloadFragmentCount = 1; + llvm::SmallVector loopLowerBounds; + llvm::SmallVector loopSteps; + llvm::SmallVector loopTripCounts; +}; + +struct StaticProjectedLoopInfo { + mlir::BlockArgument iv; + int64_t lowerBound = 0; + int64_t step = 1; + int64_t tripCount = 1; +}; + +struct ProjectedTransferDescriptor { + ProjectedBatchInputKey inputKey; + mlir::Operation* extractOp = nullptr; + ProjectedFragmentLayout layout; + mlir::RankedTensorType payloadType; + llvm::SmallVector, 16> fragmentOffsets; + llvm::SmallVector, 4> fragmentOffsetsByDim; +}; + +struct ProjectedExtractReplacement { + mlir::Value payload; + ProjectedFragmentLayout layout; +}; + +struct PendingProjectedHostOutputFragment { + mlir::Value originalOutput; + ClassId sourceClass = 0; + ProducerKey producerKey; + unsigned publicationResultIndex = 0; + int64_t sourceFragmentOrdinal = 0; + int64_t sourceElementOffset = 0; + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; + uint32_t sourceLane = 0; + mlir::Location loc; +}; + +struct AffineProjectedInputSliceMatch { + mlir::tensor::ExtractSliceOp extract; + mlir::RankedTensorType sourceType; + mlir::RankedTensorType fragmentType; + llvm::SmallVector fragmentShape; + llvm::SmallVector offsets; + llvm::SmallVector loops; +}; + +unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef loopTripCounts); +mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout); +mlir::FailureOr +getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount); +llvm::SmallVector, 4> +buildProjectedFragmentOffsetsByDim(llvm::ArrayRef> fragmentOffsets, size_t rank); +mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor, + const ProjectedTransferDescriptor& descriptor); +mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor, + const ProjectedTransferDescriptor& descriptor, + const MessageVector& messages); +mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor, + ProjectedTransferDescriptor& descriptor); + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/Utils.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/Utils.hpp index 3737a44..917d31b 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/Utils.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/Utils.hpp @@ -12,7 +12,6 @@ #include #include -#include "src/Accelerators/PIM/Common/LabeledList.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using CPU = int; diff --git a/test/PIM/CMakeLists.txt b/test/PIM/CMakeLists.txt index 4af9680..12e6288 100644 --- a/test/PIM/CMakeLists.txt +++ b/test/PIM/CMakeLists.txt @@ -23,13 +23,6 @@ function(add_pim_unittest test_name) set_tests_properties(${test_name} PROPERTIES LABELS pim-unittest) endfunction() -add_pim_unittest(LabeledListTest - LabeledListTest.cpp - - LINK_LIBS PRIVATE - OMPimCommon -) - add_pim_unittest(PimMemoryLivenessPlannerTest PimMemoryLivenessPlannerTest.cpp diff --git a/test/PIM/LabeledListTest.cpp b/test/PIM/LabeledListTest.cpp deleted file mode 100644 index 6d9ae99..0000000 --- a/test/PIM/LabeledListTest.cpp +++ /dev/null @@ -1,162 +0,0 @@ -#include -#include -#include -#include -#include - -#include "src/Accelerators/PIM/Common/LabeledList.hpp" - -using onnx_mlir::LabeledList; -using onnx_mlir::LabeledListNode; - -namespace { - -struct TestNode : public LabeledListNode { - explicit TestNode(int id) - : id(id) {} - - int id; -}; - -void assertOrder(LabeledList& list, std::initializer_list expectedOrder) { - auto expectedIt = expectedOrder.begin(); - for (auto& node : list) { - assert(expectedIt != expectedOrder.end()); - assert(node.id == *expectedIt); - ++expectedIt; - } - assert(expectedIt == expectedOrder.end()); -} - -void assertStrictlyIncreasingLabels(LabeledList& list) { - auto it = list.begin(); - if (it == list.end()) - return; - - auto previousLabel = it->getOrderLabel(); - ++it; - for (; it != list.end(); ++it) { - assert(previousLabel < it->getOrderLabel()); - previousLabel = it->getOrderLabel(); - } -} - -int testLabeledListBasicMutation() { - std::cout << "testLabeledListBasicMutation:" << std::endl; - - LabeledList list; - TestNode n1(1); - TestNode n2(2); - TestNode n3(3); - TestNode n4(4); - TestNode n5(5); - - assert(list.empty()); - assert(list.front() == nullptr); - assert(list.back() == nullptr); - assert(!list.contains(&n1)); - assert(LabeledList::previous(&n1) == nullptr); - assert(LabeledList::next(&n1) == nullptr); - - list.pushBack(&n1); - list.pushBack(&n3); - list.insertAfter(&n1, &n2); - list.pushFront(&n4); - list.insertBefore(nullptr, &n5); - - assert(list.size() == 5); - assert(list.front() == &n4); - assert(list.back() == &n5); - assert(list.contains(&n2)); - assertOrder(list, {4, 1, 2, 3, 5}); - assert(LabeledList::next(&n4) == &n1); - assert(LabeledList::previous(&n1) == &n4); - assert(LabeledList::next(&n5) == nullptr); - assert(list.comesBefore(&n1, &n3)); - assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3)); - - list.moveBefore(&n5, &n2); - assertOrder(list, {4, 1, 5, 2, 3}); - - list.moveAfter(&n4, &n3); - assertOrder(list, {1, 5, 2, 3, 4}); - - list.remove(&n2); - assert(!n2.isLinked()); - assert(!list.contains(&n2)); - assertOrder(list, {1, 5, 3, 4}); - - list.clear(); - assert(list.empty()); - assert(list.size() == 0); - assert(list.front() == nullptr); - assert(list.back() == nullptr); - assert(!n1.isLinked()); - assert(!n3.isLinked()); - assert(!n4.isLinked()); - assert(!n5.isLinked()); - - return 0; -} - -int testLabeledListRelabelingAndNoopMoves() { - std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl; - - constexpr int insertedNodeCount = 80; - LabeledList list; - TestNode head(0); - TestNode tail(999); - std::vector insertedNodes; - insertedNodes.reserve(insertedNodeCount); - for (int i = 0; i < insertedNodeCount; ++i) - insertedNodes.emplace_back(i + 1); - - list.pushBack(&head); - list.pushBack(&tail); - for (auto& node : insertedNodes) - list.insertAfter(&head, &node); - - assert(list.size() == insertedNodeCount + 2); - assert(list.front() == &head); - assert(list.back() == &tail); - assert(LabeledList::previous(&head) == nullptr); - assert(LabeledList::next(&tail) == nullptr); - assertStrictlyIncreasingLabels(list); - - auto* firstInserted = LabeledList::next(&head); - auto* secondInserted = LabeledList::next(firstInserted); - list.moveBefore(firstInserted, secondInserted); - list.moveAfter(&head, nullptr); - list.moveAfter(&tail, LabeledList::previous(&tail)); - - assert(list.front() == &head); - assert(list.back() == &tail); - assert(firstInserted == &insertedNodes.back()); - assert(secondInserted == &insertedNodes[insertedNodeCount - 2]); - assertStrictlyIncreasingLabels(list); - - int expectedId = insertedNodeCount; - auto it = std::next(list.begin()); - for (; it != list.end() && &*it != &tail; ++it, --expectedId) - assert(it->id == expectedId); - assert(expectedId == 0); - list.clear(); - - return 0; -} - -} // namespace - -int main(int argc, char* argv[]) { - (void) argc; - (void) argv; - - int failures = 0; - failures += testLabeledListBasicMutation(); - failures += testLabeledListRelabelingAndNoopMoves(); - if (failures != 0) { - std::cerr << failures << " test failures\n"; - return EXIT_FAILURE; - } - return EXIT_SUCCESS; -}