refactor
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-06-29 14:00:10 +02:00
parent e8f09fd67f
commit f492400eda
37 changed files with 1407 additions and 1898 deletions
+1
View File
@@ -5,6 +5,7 @@ add_pim_library(OMPimCommon
IR/ConstantUtils.cpp IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp IR/EntryPointUtils.cpp
IR/IndexingUtils.cpp
IR/LoopUtils.cpp IR/LoopUtils.cpp
IR/ShapeUtils.cpp IR/ShapeUtils.cpp
IR/SubviewUtils.cpp IR/SubviewUtils.cpp
+60
View File
@@ -1,5 +1,6 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -9,6 +10,65 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
} }
mlir::FailureOr<std::optional<int32_t>>
getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) {
auto coreIdAttr = computeOp->getAttrOfType<mlir::IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
return std::optional<int32_t> {};
if (coreIdAttr.getInt() < 0) {
computeOp.emitOpError() << fieldName << " must be non-negative";
return mlir::failure();
}
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), computeOp, fieldName);
if (mlir::failed(checkedCoreId))
return mlir::failure();
return std::optional<int32_t> {*checkedCoreId};
}
mlir::FailureOr<int32_t> getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) {
auto coreId = getOptionalScheduledCoreId(computeOp, fieldName);
if (mlir::failed(coreId))
return mlir::failure();
if (!*coreId) {
computeOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdAttrName;
return mlir::failure();
}
return **coreId;
}
mlir::FailureOr<std::optional<llvm::SmallVector<int32_t>>>
getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) {
auto coreIdsAttr = computeBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return std::optional<llvm::SmallVector<int32_t>> {};
llvm::SmallVector<int32_t> coreIds;
coreIds.reserve(coreIdsAttr.size());
for (int32_t coreId : coreIdsAttr.asArrayRef()) {
if (coreId < 0) {
computeBatchOp.emitOpError() << fieldName << " values must be non-negative";
return mlir::failure();
}
auto checkedCoreId = pim::checkedI32(static_cast<int64_t>(coreId), computeBatchOp, fieldName);
if (mlir::failed(checkedCoreId))
return mlir::failure();
coreIds.push_back(*checkedCoreId);
}
return std::optional<llvm::SmallVector<int32_t>> {std::move(coreIds)};
}
mlir::FailureOr<llvm::SmallVector<int32_t>>
getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) {
auto coreIds = getOptionalScheduledBatchCoreIds(computeBatchOp, fieldName);
if (mlir::failed(coreIds))
return mlir::failure();
if (!*coreIds) {
computeBatchOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdsAttrName;
return mlir::failure();
}
return std::move(**coreIds);
}
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) { llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
llvm::SmallVector<int32_t> laneCoreIds; llvm::SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount); laneCoreIds.reserve(coreIds.size() / laneCount);
+14
View File
@@ -3,12 +3,26 @@
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <optional>
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp); llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
mlir::FailureOr<std::optional<int32_t>>
getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName);
mlir::FailureOr<int32_t> getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName);
mlir::FailureOr<std::optional<llvm::SmallVector<int32_t>>>
getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName);
mlir::FailureOr<llvm::SmallVector<int32_t>>
getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName);
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane); llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex); bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
@@ -1,6 +1,6 @@
#include <algorithm> #include <algorithm>
#include "IndexingUtils.hpp" #include "src/Accelerators/PIM/Common/IR/IndexingUtils.hpp"
using namespace mlir; using namespace mlir;
+71
View File
@@ -1,6 +1,9 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include <functional>
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -163,4 +166,72 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
return true; return true;
} }
bool hasStaticPositiveShape(llvm::ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(mlir::RankedTensorType type) {
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}
int64_t getStaticShapeElementCount(llvm::ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
llvm::SmallVector<int64_t> permuteShape(llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> permutation) {
llvm::SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
llvm::SmallVector<int64_t> invertPermutation(llvm::ArrayRef<int64_t> permutation) {
llvm::SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
mlir::FailureOr<llvm::SmallVector<int64_t>>
getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr, int64_t rank) {
llvm::SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return mlir::failure();
permutation.reserve(permAttr->size());
llvm::SmallVector<bool> seen(rank, false);
for (mlir::IntegerAttr attr : permAttr->getAsRange<mlir::IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return mlir::failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank) {
return llvm::SmallVector<mlir::OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank) {
return llvm::SmallVector<mlir::OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef<int64_t> shape) {
llvm::SmallVector<mlir::OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
} // namespace onnx_mlir } // namespace onnx_mlir
+71
View File
@@ -2,15 +2,23 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cstddef> #include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
namespace onnx_mlir { namespace onnx_mlir {
using HSliceId = size_t;
using CoreId = size_t;
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape); llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t> llvm::SmallVector<int64_t>
@@ -36,4 +44,67 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<int64_t> staticSizes, llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides); llvm::ArrayRef<int64_t> staticStrides);
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
&& lhsType.getShape() == rhsType.getShape();
}
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef<int64_t> shape);
} // namespace onnx_mlir } // namespace onnx_mlir
-315
View File
@@ -1,315 +0,0 @@
#pragma once
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ilist_node.h"
#include "llvm/ADT/simple_ilist.h"
#include <cassert>
#include <iterator>
#include <limits>
#include <type_traits>
namespace onnx_mlir {
template <typename NodeT>
class LabeledList;
template <typename NodeT>
class LabeledListNode : public llvm::ilist_node<NodeT> {
friend class LabeledList<NodeT>;
public:
using Label = uint64_t;
LabeledListNode() = default;
LabeledListNode(const LabeledListNode&) = delete;
LabeledListNode(LabeledListNode&&) = default;
LabeledListNode& operator=(LabeledListNode&&) = delete;
~LabeledListNode() { assert(owner_ == nullptr && "destroying a linked LabeledListNode"); }
bool isLinked() const { return owner_ != nullptr; }
Label getOrderLabel() const { return label; }
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt) { return lft.label < rgt.label; }
private:
const void* owner_ = nullptr;
Label label = 0;
};
template <typename NodeT>
class LabeledList {
using Label = typename NodeT::Label;
static constexpr Label kLowerSentinel = 0;
static constexpr Label kUpperSentinel = std::numeric_limits<Label>::max();
static constexpr Label kRelabelGap = 2;
public:
using List = llvm::simple_ilist<NodeT>;
using Iterator = typename List::iterator;
using RIterator = typename List::reverse_iterator;
using ConstIterator = typename List::const_iterator;
LabeledList() = default;
LabeledList(const LabeledList&) = delete;
LabeledList& operator=(const LabeledList&) = delete;
LabeledList(LabeledList&&) = delete;
LabeledList& operator=(LabeledList&&) = delete;
~LabeledList() { clear(); }
bool empty() const { return size_ == 0; }
size_t size() const { return size_; }
NodeT* front() { return empty() ? nullptr : &nodes_.front(); }
const NodeT* front() const { return empty() ? nullptr : &nodes_.front(); }
NodeT* back() { return empty() ? nullptr : &nodes_.back(); }
const NodeT* back() const { return empty() ? nullptr : &nodes_.back(); }
static NodeT* previous(NodeT* node) {
if (!node || !owner(node))
return nullptr;
auto* list = owner(node);
auto it = node->getIterator();
if (it == list->nodes_.begin())
return nullptr;
return &*std::prev(it);
}
static const NodeT* previous(const NodeT* node) {
if (!node || !owner(node))
return nullptr;
const auto* list = owner(node);
auto it = const_cast<NodeT*>(node)->getIterator();
if (it == list->nodes_.begin())
return nullptr;
return &*std::prev(it);
}
static NodeT* next(NodeT* node) {
if (!node || !owner(node))
return nullptr;
auto* list = owner(node);
auto it = std::next(node->getIterator());
if (it == list->nodes_.end())
return nullptr;
return &*it;
}
static const NodeT* next(const NodeT* node) {
if (!node || !owner(node))
return nullptr;
const auto* list = owner(node);
auto it = std::next(const_cast<NodeT*>(node)->getIterator());
if (it == list->nodes_.end())
return nullptr;
return &*it;
}
bool contains(const NodeT* node) const { return node && node->owner_ == this; }
Label getOrderLabel(const NodeT* node) const {
assert(contains(node) && "node must belong to this list");
return node->label;
}
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
return lhs->label < rhs->label;
}
void pushFront(NodeT* node) { insertBefore(front(), node); }
void pushBack(NodeT* node) { insertBefore(nullptr, node); }
void insertBefore(NodeT* nextNode, NodeT* node) {
assert(node && "cannot insert a null node");
assert(!node->owner_ && "node is already linked");
assert(nextNode == nullptr || contains(nextNode));
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
nodes_.insert(nextIt, *node);
node->owner_ = this;
++size_;
assignLabel(getIteratorFor(node));
}
void insertAfter(NodeT* prevNode, NodeT* node) {
assert(prevNode == nullptr || contains(prevNode));
if (prevNode == nullptr)
insertBefore(front(), node);
else
insertBefore(next(prevNode), node);
}
void remove(NodeT* node) {
assert(contains(node) && "node must belong to this list");
nodes_.remove(*node);
node->owner_ = nullptr;
node->label = 0;
--size_;
}
void moveBefore(NodeT* node, NodeT* nextNode) {
assert(contains(node) && "node must belong to this list");
assert(nextNode == nullptr || contains(nextNode));
Iterator nodeIt = getIteratorFor(node);
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
if (nodeIt == nextIt || std::next(nodeIt) == nextIt)
return;
nodes_.splice(nextIt, nodes_, nodeIt);
assignLabel(getIteratorFor(node));
}
void moveAfter(NodeT* node, NodeT* prevNode) {
assert(contains(node) && "node must belong to this list");
assert(prevNode == nullptr || contains(prevNode));
Iterator nextIt = prevNode ? std::next(getIteratorFor(prevNode)) : nodes_.begin();
if (getIteratorFor(node) == nextIt)
return;
moveBefore(node, nextIt == nodes_.end() ? nullptr : &*nextIt);
}
void clear() {
while (!nodes_.empty()) {
NodeT* node = &nodes_.front();
node->owner_ = nullptr;
node->label = 0;
nodes_.remove(*node);
}
size_ = 0;
}
Iterator begin() { return nodes_.begin(); }
Iterator end() { return nodes_.end(); }
RIterator rbegin() { return nodes_.rbegin(); }
RIterator rend() { return nodes_.rend(); }
private:
static const LabeledList* owner(const NodeT* node) { return static_cast<const LabeledList*>(node->owner_); }
static LabeledList* owner(NodeT* node) { return static_cast<LabeledList*>(const_cast<void*>(node->owner_)); }
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }
static Label labelGap(Label lower, Label upper) {
assert(lower < upper && "labels must be strictly ordered");
return upper - lower;
}
static bool hasMidpoint(Label lower, Label upper) { return labelGap(lower, upper) > 1; }
static bool hasRelabelSlack(Label lower, Label upper, size_t nodeCount) {
Label gap = labelGap(lower, upper);
return gap / static_cast<Label>(nodeCount + 1) >= kRelabelGap;
}
Iterator getIteratorFor(NodeT* node) { return node->getIterator(); }
ConstIterator getiteratorFor(const NodeT* node) const { return node->getIterator(); }
NodeT* previousNode(Iterator it) {
if (it == nodes_.begin())
return nullptr;
return &*std::prev(it);
}
const NodeT* previousNode(ConstIterator it) const {
if (it == nodes_.begin())
return nullptr;
return &*std::prev(it);
}
NodeT* nextNode(Iterator it) {
++it;
if (it == nodes_.end())
return nullptr;
return &*it;
}
const NodeT* nextNode(ConstIterator it) const {
++it;
if (it == nodes_.end())
return nullptr;
return &*it;
}
void assignLabel(Iterator it) {
Label lower = lowerLabel(previousNode(it));
Label upper = upperLabel(nextNode(it));
if (hasMidpoint(lower, upper)) {
(*it).label = lower + static_cast<Label>(labelGap(lower, upper) / 2);
return;
}
relabelAround(it);
}
void relabelAround(Iterator center) {
size_t targetCount = 1;
while (true) {
Iterator left = center;
Iterator right = center;
size_t actualCount = 1;
expandWindow(center, targetCount, left, right, actualCount);
Label lower = lowerLabel(previousNode(left));
Label upper = upperLabel(nextNode(right));
if (hasRelabelSlack(lower, upper, actualCount)) {
relabelWindow(left, actualCount, lower, upper);
return;
}
if (left == nodes_.begin() && nextNode(right) == nullptr) {
assert(hasRelabelSlack(lower, upper, actualCount) && "label space exhausted");
relabelWindow(left, actualCount, lower, upper);
return;
}
targetCount *= 2;
}
}
void expandWindow(Iterator center, size_t targetCount, Iterator& left, Iterator& right, size_t& actualCount) {
left = center;
right = center;
actualCount = 1;
while (actualCount < targetCount && (left != nodes_.begin() || nextNode(right) != nullptr)) {
if (left != nodes_.begin()) {
--left;
++actualCount;
if (actualCount == targetCount)
break;
}
if (nextNode(right) != nullptr) {
++right;
++actualCount;
}
}
}
void relabelWindow(Iterator left, size_t nodeCount, Label lower, Label upper) {
assert(nodeCount > 0 && "relabel window must not be empty");
Label step = labelGap(lower, upper) / static_cast<Label>(nodeCount + 1);
assert(step >= 1 && "relabel step must be positive");
Iterator it = left;
for (size_t index = 1; index <= nodeCount; ++index) {
(*it).label = lower + step * index;
++it;
}
}
List nodes_;
size_t size_ = 0;
};
} // namespace onnx_mlir
+1
View File
@@ -15,6 +15,7 @@
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp" #include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/IndexingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" #include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
@@ -10,6 +10,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Post.cpp Patterns/Post.cpp
Patterns/GeneratedConversion.cpp Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp Patterns/Math/Conv.cpp
Patterns/Math/ConvGeometry.cpp
Patterns/Math/Elementwise.cpp Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp Patterns/Math/Gemm.cpp
Patterns/Math/MatMul.cpp Patterns/Math/MatMul.cpp
@@ -30,7 +31,7 @@ add_pim_library(OMONNXToSpatial
LowerSpatialPlansPass.cpp LowerSpatialPlansPass.cpp
Common/AttributeUtils.cpp Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp Common/MatrixProductLowering.cpp
Common/ShapeTilingUtils.cpp Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp Common/WeightMaterialization.cpp
@@ -2,7 +2,7 @@
#include "AttributeUtils.hpp" #include "AttributeUtils.hpp"
#include "ComputeRegionBuilder.hpp" #include "ComputeRegionBuilder.hpp"
#include "IndexingUtils.hpp" #include "MatrixProductLowering.hpp"
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp" #include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -0,0 +1,48 @@
#include "MatrixProductLowering.hpp"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
Value createPaddedInputCompute(Value input,
RankedTensorType paddedInputType,
PatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (inputType == paddedInputType)
return input;
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
});
return computeOp.getResult(0);
}
} // namespace onnx_mlir
@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
mlir::Value createZeroPaddedTensor(mlir::Value value,
mlir::RankedTensorType resultType,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
mlir::Value createPaddedInputCompute(mlir::Value input,
mlir::RankedTensorType paddedInputType,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -3,9 +3,6 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <functional>
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -15,73 +12,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(RankedTensorType type) {
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return failure();
permutation.reserve(permAttr->size());
SmallVector<bool> seen(rank, false);
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
SmallVector<Value> sliceTensor( SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) { const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice); ArrayRef<long> shape = getTensorShape(tensorToSlice);
@@ -1,89 +1,15 @@
#pragma once #pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cassert> #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
namespace onnx_mlir { namespace onnx_mlir {
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
&& lhsType.getShape() == rhsType.getShape();
}
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
/// Slices a statically shaped tensor along one axis into contiguous pieces of /// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements. /// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice, llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
@@ -26,6 +26,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -42,59 +43,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override; ConversionPatternRewriter& rewriter) const override;
}; };
struct ConvLoweringState {
Value x;
Value w;
Value b;
RankedTensorType xType;
RankedTensorType wType;
RankedTensorType outType;
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t padHeightBegin;
int64_t padHeightEnd;
int64_t padWidthBegin;
int64_t padWidthEnd;
int64_t strideHeight;
int64_t strideWidth;
int64_t dilationHeight;
int64_t dilationWidth;
bool hasBias;
};
struct ConvGeometry {
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t k;
int64_t c;
int64_t p;
int64_t xbarSize;
int64_t pack;
uint64_t im2colElements;
bool hasBias;
bool isDepthwise;
};
struct ConvLoweringDecision { struct ConvLoweringDecision {
PimConvLoweringType strategy; PimConvLoweringType strategy;
std::string reason; std::string reason;
@@ -108,19 +56,6 @@ struct PreparedConvInput {
RankedTensorType type; RankedTensorType type;
}; };
struct RowInterval {
int64_t begin = 0;
int64_t end = 0;
};
struct ConvRowDemand {
RowInterval outputRows;
RowInterval neededInputRows;
RowInterval acquiredInputRows;
int64_t topHaloRows = 0;
int64_t bottomHaloRows = 0;
};
struct ConvStrategyEstimate { struct ConvStrategyEstimate {
uint64_t estimatedMvmCount = 0; uint64_t estimatedMvmCount = 0;
uint64_t estimatedReductionVAddCount = 0; uint64_t estimatedReductionVAddCount = 0;
@@ -291,9 +226,6 @@ static FailureOr<Value> createRowStripPackedRows(Value rows,
PatternRewriter& rewriter, PatternRewriter& rewriter,
Location loc); Location loc);
static bool
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b); static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) { static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
@@ -391,34 +323,6 @@ static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
return estimate; return estimate;
} }
static ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
ConvGeometry geo {
state.batchSize,
state.numChannelsIn,
state.xHeight,
state.xWidth,
state.numChannelsOut,
state.wHeight,
state.wWidth,
state.outHeight,
state.outWidth,
state.group,
state.numChannelsInPerGroup,
state.numChannelsOutPerGroup,
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
state.numChannelsOutPerGroup,
state.batchSize * state.outHeight * state.outWidth,
static_cast<int64_t>(crossbarSize.getValue()),
1,
0,
state.hasBias,
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
};
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
return geo;
}
static std::string formatShape(ArrayRef<int64_t> dims) { static std::string formatShape(ArrayRef<int64_t> dims) {
std::string text; std::string text;
llvm::raw_string_ostream os(text); llvm::raw_string_ostream os(text);
@@ -563,36 +467,10 @@ classifyDistributedBinaryConsumer(Operation* user,
return std::nullopt; return std::nullopt;
} }
static RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows,
int64_t inputHeight,
int64_t kernelH,
int64_t strideH,
int64_t dilationH,
int64_t padTop) {
const int64_t rawBegin = outputRows.begin * strideH - padTop;
const int64_t rawEnd = (outputRows.end - 1) * strideH - padTop + dilationH * (kernelH - 1) + 1;
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(inputHeight, rawEnd)};
}
static bool covers(RowInterval acquired, RowInterval needed) { static bool covers(RowInterval acquired, RowInterval needed) {
return acquired.begin <= needed.begin && acquired.end >= needed.end; return acquired.begin <= needed.begin && acquired.end >= needed.end;
} }
static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
const int64_t rawEnd =
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
RowInterval neededInputRows = computeConvInputRowsForOutputRows(
outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin);
ConvRowDemand demand;
demand.outputRows = outputRows;
demand.neededInputRows = neededInputRows;
demand.acquiredInputRows = neededInputRows;
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
return demand;
}
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) { static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
if (state.batchSize != 1) { if (state.batchSize != 1) {
failureReason = "unsupported_batch"; failureReason = "unsupported_batch";
@@ -1250,19 +1128,6 @@ static void reportConvLoweringDecision(ONNXConvOp convOp,
rewriteConvLoweringReport(reportEntries); rewriteConvLoweringReport(reportEntries);
} }
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
}
return std::max<uint64_t>(1, chunkPositions);
}
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) { static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
auto biasType = cast<RankedTensorType>(bias.getType()); auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1) if (biasType.getRank() != 1)
@@ -1278,11 +1143,6 @@ static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location
}); });
} }
static bool
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
}
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) { static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
assert(value > 0 && "expected positive value"); assert(value > 0 && "expected positive value");
limit = std::min(value, limit); limit = std::min(value, limit);
@@ -0,0 +1,77 @@
#include "ConvGeometry.hpp"
#include <algorithm>
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
}
ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
ConvGeometry geo {
state.batchSize,
state.numChannelsIn,
state.xHeight,
state.xWidth,
state.numChannelsOut,
state.wHeight,
state.wWidth,
state.outHeight,
state.outWidth,
state.group,
state.numChannelsInPerGroup,
state.numChannelsOutPerGroup,
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
state.numChannelsOutPerGroup,
state.batchSize * state.outHeight * state.outWidth,
static_cast<int64_t>(crossbarSize.getValue()),
1,
0,
state.hasBias,
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
};
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
return geo;
}
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
}
return std::max<uint64_t>(1, chunkPositions);
}
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state) {
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
const int64_t rawEnd =
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(state.xHeight, rawEnd)};
}
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
ConvRowDemand demand;
demand.outputRows = outputRows;
demand.neededInputRows = computeConvInputRowsForOutputRows(outputRows, state);
demand.acquiredInputRows = demand.neededInputRows;
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
const int64_t rawEnd =
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
demand.acquiredInputRows = demand.neededInputRows;
return demand;
}
} // namespace onnx_mlir
@@ -0,0 +1,86 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include <cstdint>
namespace onnx_mlir {
struct ConvLoweringState {
mlir::Value x;
mlir::Value w;
mlir::Value b;
mlir::RankedTensorType xType;
mlir::RankedTensorType wType;
mlir::RankedTensorType outType;
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t padHeightBegin;
int64_t padHeightEnd;
int64_t padWidthBegin;
int64_t padWidthEnd;
int64_t strideHeight;
int64_t strideWidth;
int64_t dilationHeight;
int64_t dilationWidth;
bool hasBias;
};
struct ConvGeometry {
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t k;
int64_t c;
int64_t p;
int64_t xbarSize;
int64_t pack;
uint64_t im2colElements;
bool hasBias;
bool isDepthwise;
};
struct RowInterval {
int64_t begin = 0;
int64_t end = 0;
};
struct ConvRowDemand {
RowInterval outputRows;
RowInterval neededInputRows;
RowInterval acquiredInputRows;
int64_t topHaloRows = 0;
int64_t bottomHaloRows = 0;
};
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
ConvGeometry buildConvGeometry(const ConvLoweringState& state);
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state);
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state);
} // namespace onnx_mlir
@@ -87,28 +87,6 @@ static Value createGemmBatchHOffset(Value lane,
rewriter.getInsertionBlock()->getParentOp()); rewriter.getInsertionBlock()->getParentOp());
} }
static Value
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> materializePaddedConstantMatrix(Value value, static FailureOr<Value> materializePaddedConstantMatrix(Value value,
RankedTensorType resultType, RankedTensorType resultType,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
@@ -232,22 +210,6 @@ static Value extractATile(
return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult(); return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult();
} }
static Value createPaddedInputCompute(Value input,
RankedTensorType paddedInputType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (inputType == paddedInputType)
return input;
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
});
return computeOp.getResult(0);
}
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a, static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
Value b, Value b,
RankedTensorType aType, RankedTensorType aType,
@@ -255,42 +255,6 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
return createONNXTranspose(resultType, {0, 2, 1}); return createONNXTranspose(resultType, {0, 2, 1});
} }
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static Value createPaddedBatchedInputCompute(Value input,
RankedTensorType paddedInputType,
PatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (inputType == paddedInputType)
return input;
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
});
return computeOp.getResult(0);
}
static FailureOr<Value> materializePaddedBatchedWeight(Value value, static FailureOr<Value> materializePaddedBatchedWeight(Value value,
ArrayRef<int64_t> sourceBatchShape, ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> targetBatchShape, ArrayRef<int64_t> targetBatchShape,
@@ -1055,7 +1019,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
auto paddedRhs = auto paddedRhs =
materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter); materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
if (succeeded(paddedRhs)) { if (succeeded(paddedRhs)) {
Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc); Value paddedLhs = createPaddedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices; const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())}, auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
shapeInfo->outType.getElementType()); shapeInfo->outType.getElementType());
@@ -29,100 +29,6 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
}); });
} }
static bool isMaterializableExternalTensorOp(Operation* op) {
return isa<spatial::SpatChannelReceiveOp,
spatial::SpatExtractRowsOp,
tensor::ExtractSliceOp,
tensor::ExpandShapeOp,
tensor::CollapseShapeOp>(op);
}
//TODO REMOVE THIS UGLY FIX
//TODO: Remove this helper once compute_batch external tensor captures are
// fixed at the producer side.
//
// This function is a temporary SpatialToPim repair path. It clones selected
// external tensor producers, such as channel_receive and tensor view/slice ops,
// into the new pim.core_batch body when the old spat.compute_batch body refers
// to tensor values defined outside the batch.
//
// The real invariant should be stronger:
//
// A spat.compute_batch body must not capture external tensor values.
// Every tensor used inside the body must be either:
// - a compute_batch block argument,
// - defined inside the compute_batch body,
// - or a legal constant-like value.
//
// If this invariant is violated, the responsible producer, most likely merge
// schedule materialization, should emit verifier-clean Spatial IR instead of
// relying on SpatialToPim to clone external producer chains later.
//
// After that producer-side fix:
// 1. remove isMaterializableExternalTensorOp,
// 2. remove materializeExternalTensorValue,
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
// tensor operand,
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
// before SpatialToPim.
//
// Be careful not to replace every external tensor capture with a normal
// compute_batch input blindly: host-backed tensors and explicit inter-core
// communication have different semantics. In particular, channel_receive-like
// values should be materialized through the communication model, not silently
// treated as host inputs.
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
Location loc,
Block& oldBlock,
Value value,
IRMapping& mapper) {
if (mapper.contains(value))
return mapper.lookup(value);
if (!isa<TensorType>(value.getType()))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
return failure();
if (definingOp->getBlock() == &oldBlock)
return failure();
if (!isMaterializableExternalTensorOp(definingOp))
return failure();
for (Value operand : definingOp->getOperands()) {
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
if (succeeded(materializedOperand))
mapper.map(operand, *materializedOperand);
}
Operation* cloned = rewriter.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) {
auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id");
if (failed(checkedCoreId))
return failure();
coreIds.push_back(*checkedCoreId);
++fallbackCoreId;
}
return coreIds;
}
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) { static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse()) if (!result.hasOneUse())
return failure(); return failure();
@@ -386,7 +292,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
"resultful compute_batch lowering currently requires a spat.in_parallel terminator"); "resultful compute_batch lowering currently requires a spat.in_parallel terminator");
} }
auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); auto coreIds = getRequiredScheduledBatchCoreIds(computeBatchOp, "spatial compute_batch core id");
if (failed(coreIds)) if (failed(coreIds))
return failure(); return failure();
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
@@ -638,9 +544,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>()) if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue; continue;
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
continue;
InFlightDiagnostic diagnostic = InFlightDiagnostic diagnostic =
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering"); computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex; diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
@@ -9,6 +9,7 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
@@ -141,17 +142,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
} }
} }
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeOp, "fallback spatial compute core id");
if (failed(checkedCoreId))
return failure();
++fallbackCoreId;
return *checkedCoreId;
}
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp, static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain, SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) { bool requireReturnUse = true) {
@@ -311,7 +301,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCom
if (!computeOp.getWeights().empty()) if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId); auto checkedCoreId = getRequiredScheduledCoreId(computeOp, "spatial compute core id");
if (failed(checkedCoreId)) if (failed(checkedCoreId))
return failure(); return failure();
auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast<int64_t>(*checkedCoreId), "pim core id"); auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast<int64_t>(*checkedCoreId), "pim core id");
@@ -44,121 +44,29 @@ using namespace pim;
namespace onnx_mlir { namespace onnx_mlir {
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue())
continue;
if (dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()) == zeroAttr)
return globalOp;
}
std::string nameStem;
llvm::raw_string_ostream nameStream(nameStem);
nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements();
nameStream.flush();
std::string symbolName = nameStem;
unsigned suffix = 0;
while (SymbolTable::lookupSymbolIn(moduleOp, symbolName))
symbolName = (nameStem + "_" + Twine(suffix++)).str();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(symbolName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
zeroAttr,
rewriter.getUnitAttr(),
IntegerAttr {});
}
static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
OperationFolder& constantFolder) {
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(tensorType, outputBuffer.getOperation(), "host-to-device zero copy byte size");
if (failed(byteSize))
return failure();
auto sizeAttr =
pim::getCheckedI32Attr(rewriter, outputBuffer.getOperation(), *byteSize, "host-to-device zero copy byte size");
if (failed(sizeAttr))
return failure();
return PimMemCopyHostToDevOp::create(
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, *sizeAttr)
.getOutput();
}
static bool isHostBackedMemRefValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
return isa<memref::GetGlobalOp>(definingOp);
}
return false;
}
static bool isHostBackedTensorValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return false;
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
extractSliceOp.getMixedOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
return false;
}
value = extractSliceOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
return isHostBackedMemRefValue(toTensorOp.getBuffer());
return false;
}
return false;
}
static FailureOr<Value> static FailureOr<Value>
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) { createZeroPaddedTensor(IRRewriter& rewriter, Location loc, Value value, RankedTensorType resultType) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
auto vectorType = cast<RankedTensorType>(vector.getType()); auto vectorType = cast<RankedTensorType>(vector.getType());
ArrayRef<int64_t> shape = vectorType.getShape(); ArrayRef<int64_t> shape = vectorType.getShape();
assert(isHVectorShape(shape) && "expected a horizontal vector"); assert(isHVectorShape(shape) && "expected a horizontal vector");
@@ -169,26 +77,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
auto paddedType = RankedTensorType::get( auto paddedType = RankedTensorType::get(
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); {shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder); return createZeroPaddedTensor(rewriter, loc, vector, paddedType);
if (failed(zeroed))
return failure();
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0);
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size");
if (failed(byteSize))
return failure();
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
if (failed(sizeAttr))
return failure();
if (isHostBackedTensorValue(vector)) {
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
} }
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
coreId = 0;
outputTensors.clear(); outputTensors.clear();
operationsToRemove.clear(); operationsToRemove.clear();
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
@@ -362,7 +254,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
} }
LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext());
bool hasFailure = false; bool hasFailure = false;
funcOp.walk([&](PimVMMOp vmmOp) { funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType()); auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
@@ -371,7 +262,7 @@ LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func:
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar"); assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
rewriter.setInsertionPoint(vmmOp); rewriter.setInsertionPoint(vmmOp);
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder); auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
if (failed(paddedInput)) { if (failed(paddedInput)) {
hasFailure = true; hasFailure = true;
return WalkResult::interrupt(); return WalkResult::interrupt();
@@ -36,7 +36,6 @@ private:
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>; using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
llvm::SmallVector<OutputTensorFactory> outputTensors; llvm::SmallVector<OutputTensorFactory> outputTensors;
size_t coreId = 0;
llvm::SmallVector<mlir::Operation*> operationsToRemove; llvm::SmallVector<mlir::Operation*> operationsToRemove;
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
+2
View File
@@ -8,7 +8,9 @@ add_pim_library(SpatialOps
SpatialOpsCanonicalization.cpp SpatialOpsCanonicalization.cpp
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp ${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/HostOutputFinalization.cpp
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
Transforms/MergeComputeNodes/ProjectedFragments.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
@@ -0,0 +1,134 @@
#include "HostOutputFinalization.hpp"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "MaterializedClassState.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir::spatial {
LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
if (state.pendingProjectedHostOutputFragments.empty())
return success();
DenseMap<Value, SmallVector<PendingProjectedHostOutputFragment*, 16>> byOutput;
for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments)
byOutput[fragment.originalOutput].push_back(&fragment);
SmallVector<Value, 8> outputs;
outputs.reserve(byOutput.size());
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
if (!returnOp)
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
DenseSet<Value> seenOutputs;
for (Value returned : returnOp.getOperands()) {
if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second)
continue;
outputs.push_back(returned);
}
if (outputs.size() != byOutput.size())
return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs");
for (Value originalOutput : outputs) {
if (isa_and_present<SpatScheduledCompute, SpatScheduledComputeBatch>(originalOutput.getDefiningOp())) {
return state.func.emitError(
"projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result");
}
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
if (!resultType || !resultType.hasStaticShape())
return state.func.emitError("projected host output must have static ranked tensor type");
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs,
const PendingProjectedHostOutputFragment* rhs) {
if (lhs->sourceClass != rhs->sourceClass)
return lhs->sourceClass < rhs->sourceClass;
if (lhs->publicationResultIndex != rhs->publicationResultIndex)
return lhs->publicationResultIndex < rhs->publicationResultIndex;
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
return std::lexicographical_compare(lhs->offsets.begin(),
lhs->offsets.end(),
rhs->offsets.begin(),
rhs->offsets.end());
});
state.rewriter.setInsertionPoint(returnOp);
Location loc = fragments.front()->loc;
SmallVector<Value, 16> blueprintOperands;
SmallVector<int64_t, 16> fragmentOperandIndices;
SmallVector<int64_t, 16> fragmentSourceOffsets;
SmallVector<int64_t, 64> flatOffsets;
SmallVector<int64_t, 64> flatSizes;
SmallVector<int64_t, 64> flatStrides;
DenseMap<Value, int64_t> operandIndicesByValue;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
if (fragmentRecord->sourceClass >= state.classes.size())
return state.func.emitError("projected host output fragment references an invalid source class");
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) {
return sourceClass.op->emitError("projected host output fragment references an invalid publication result")
<< " sourceClass=" << sourceClass.id
<< " resultIndex=" << fragmentRecord->publicationResultIndex
<< " resultCount=" << sourceClass.op->getNumResults();
}
Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex);
auto [operandIt, inserted] =
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(blueprintOperands.size()));
if (inserted)
blueprintOperands.push_back(operand);
fragmentOperandIndices.push_back(operandIt->second);
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
llvm::append_range(flatOffsets, fragmentRecord->offsets);
llvm::append_range(flatSizes, fragmentRecord->sizes);
llvm::append_range(flatStrides, fragmentRecord->strides);
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
if (!operandType || !operandType.hasStaticShape())
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
}
if (blueprintOperands.empty())
return state.func.emitError("missing projected host output fragments");
Value input = blueprintOperands.front();
ValueRange extraFragments = ValueRange(blueprintOperands).drop_front();
auto blueprint = SpatBlueprintOp::create(
state.rewriter,
loc,
resultType,
input,
extraFragments,
state.rewriter.getStringAttr("nchw"),
state.rewriter.getStringAttr("fragmented"),
state.rewriter.getDenseI64ArrayAttr(flatOffsets),
state.rewriter.getDenseI64ArrayAttr(flatSizes),
state.rewriter.getStringAttr("identity"),
state.rewriter.getStringAttr("fragment_assembly"),
state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices),
state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets),
state.rewriter.getDenseI64ArrayAttr(flatStrides),
state.rewriter.getStringAttr("disjoint"),
state.rewriter.getStringAttr("complete"));
state.hostReplacements[originalOutput] = blueprint.getOutput();
}
return success();
}
} // namespace onnx_mlir::spatial
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir::spatial {
struct MaterializerState;
mlir::LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state);
} // namespace onnx_mlir::spatial
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,252 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
#include "MaterializeMergeSchedule.hpp"
#include "MergeMessages.hpp"
#include "MergeScheduleKeys.hpp"
#include "ProjectedFragments.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir::spatial {
struct MaterializedClass {
ClassId id = 0;
llvm::SmallVector<CpuId, 8> cpus;
mlir::Operation* op = nullptr;
mlir::Block* body = nullptr;
bool isBatch = false;
llvm::DenseMap<CpuId, unsigned> cpuToLane;
llvm::SmallVector<mlir::Value, 8> weights;
llvm::SmallVector<mlir::Value, 8> inputs;
llvm::SmallVector<mlir::Value, 4> hostOutputs;
llvm::DenseMap<mlir::Value, unsigned> publicationOutputToResultIndex;
llvm::DenseMap<mlir::Value, mlir::BlockArgument> weightArgs;
llvm::DenseMap<mlir::Value, mlir::BlockArgument> inputArgs;
llvm::DenseMap<mlir::Value, unsigned> hostOutputToResultIndex;
};
struct PackedScalarRunSlot {
llvm::SmallVector<ProducerKey, 8> keys;
};
enum class PackedScalarRunKind {
Materialized,
DeferredReceive,
DeferredLocalCompute
};
struct PackedScalarRunValue {
ClassId targetClass = 0;
mlir::Operation* sourceOp = nullptr;
size_t resultIndex = 0;
PackedScalarRunKind kind = PackedScalarRunKind::Materialized;
mlir::Value packed;
mlir::RankedTensorType fragmentType;
llvm::SmallVector<PackedScalarRunSlot, 8> slots;
MessageVector messages;
};
struct IndexedBatchRunValue {
ClassId targetClass = 0;
mlir::Operation* sourceOp = nullptr;
size_t resultIndex = 0;
mlir::Value packed;
mlir::RankedTensorType fragmentType;
llvm::SmallVector<PackedScalarRunSlot, 8> slots;
MessageVector messages;
};
struct LogicalSlotRange {
SlotId start = 0;
SlotId count = 0;
};
struct MaterializationRunSlot {
llvm::SmallVector<ComputeInstance, 8> peers;
};
using MaterializationRun = llvm::SmallVector<MaterializationRunSlot, 8>;
struct OutputDestinationGroup {
llvm::SmallVector<size_t, 4> resultIndices;
llvm::SmallVector<ClassId, 4> destinationClasses;
};
struct BatchRunSendPlan {
size_t resultIndex = 0;
ClassId destinationClass = 0;
MessageVector messages;
};
enum class TensorDemandActionKind {
DestinationFanout,
SameClassIndexedFragment,
TerminalBlueprintPublication,
WholeTensorBarrier
};
enum class WholeTensorBarrierReason {
FunctionReturnWithoutBlueprint,
DenseLogicalConsumer
};
struct TensorDemandAction {
TensorDemandActionKind kind = TensorDemandActionKind::DestinationFanout;
std::optional<ClassId> destinationClass;
std::optional<WholeTensorBarrierReason> barrierReason;
};
struct RunOutputDemand {
size_t resultIndex = 0;
mlir::Value originalOutput;
mlir::RankedTensorType fragmentType;
llvm::SmallVector<TensorDemandAction, 4> actions;
};
struct CompactRunPlan {
llvm::SmallVector<RunOutputDemand, 4> outputs;
};
enum class BatchInputDemandKind {
LaneFragment,
ProjectedFragment,
WholeTensorBarrier
};
struct BatchInputDemand {
BatchInputDemandKind kind = BatchInputDemandKind::LaneFragment;
std::optional<ProducerKey> wholeTensorProducer;
};
struct CloneIndexingContext {
std::optional<mlir::Value> runSlotIndex;
std::optional<mlir::Value> projectionSlotIndex;
};
struct MaterializerState;
class AvailableValueStore {
public:
struct ExactBatchFragmentRecord {
ProducerKey key;
mlir::Value value;
};
void record(ProducerKey key, ClassId classId, mlir::Value value) {
exactValues[key][classId] = value;
auto batch = mlir::dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
if (!batch || key.instance.laneCount == 0)
return;
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
llvm::SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
for (ExactBatchFragmentRecord& record : bucket) {
if (!(record.key == key))
continue;
record.value = value;
return;
}
bucket.push_back({key, value});
}
void recordPackedRun(PackedScalarRunValue run) {
size_t runIndex = packedScalarRuns.size();
packedScalarRuns.push_back(std::move(run));
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
}
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
std::optional<mlir::Value> lookupExact(ProducerKey key, ClassId classId) const;
std::optional<mlir::Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
llvm::ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
auto it = packedRunsByProducerResultClass.find(key);
if (it == packedRunsByProducerResultClass.end())
return {};
return it->second;
}
llvm::ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
auto it = exactBatchFragmentsByProducerResultClass.find(key);
if (it == exactBatchFragmentsByProducerResultClass.end())
return {};
return it->second;
}
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
private:
std::optional<mlir::Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
llvm::DenseMap<ProducerKey, llvm::DenseMap<ClassId, mlir::Value>, ProducerKeyInfo> exactValues;
llvm::SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
llvm::SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
llvm::DenseMap<WholeBatchAssemblyLookupKey,
llvm::SmallVector<ExactBatchFragmentRecord, 16>,
WholeBatchAssemblyLookupKeyInfo>
exactBatchFragmentsByProducerResultClass;
llvm::DenseMap<WholeBatchAssemblyLookupKey, llvm::SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
packedRunsByProducerResultClass;
};
struct MaterializerState {
mlir::func::FuncOp func;
const MergeScheduleResult& schedule;
mlir::IRRewriter rewriter;
mlir::OperationFolder constantFolder;
int64_t& nextChannelId;
llvm::SmallVector<MaterializedClass, 8> classes;
llvm::DenseMap<CpuId, ClassId> cpuToClass;
llvm::DenseMap<CpuId, llvm::SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
llvm::DenseMap<ComputeInstance, LogicalSlotRange> scheduledInstanceToLogicalSlots;
llvm::DenseMap<ComputeInstance, ComputeInstance> logicalInstanceToScheduledChunk;
llvm::DenseSet<ClassSlotKey> materializedLogicalSlots;
llvm::DenseMap<ProducerKey, llvm::SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
llvm::DenseMap<SameClassConsumerLookupKey, llvm::SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
sameClassConsumerIndex;
llvm::DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo>
projectedInputMatches;
llvm::DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
llvm::DenseMap<mlir::Value, bool> liveExternalUseCache;
llvm::DenseMap<mlir::Operation*, llvm::SmallVector<mlir::Type, 4>> batchOutputFragmentTypesCache;
llvm::DenseMap<ComputeInstance, llvm::SmallVector<mlir::Value, 4>, llvm::DenseMapInfo<ComputeInstance>>
computeInstanceOutputsCache;
llvm::DenseMap<ProducerKey, llvm::DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo>
projectedTransfers;
llvm::DenseMap<mlir::Operation*, llvm::DenseMap<ClassId, ProjectedExtractReplacement>>
projectedExtractReplacements;
AvailableValueStore availableValues;
llvm::DenseMap<mlir::Value, mlir::Value> hostReplacements;
llvm::DenseMap<mlir::Value, ClassId> hostOutputOwners;
llvm::SmallVector<PendingProjectedHostOutputFragment, 32> pendingProjectedHostOutputFragments;
llvm::DenseSet<mlir::Operation*> oldComputeOps;
MaterializerState(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
: func(func),
schedule(schedule),
rewriter(func.getContext()),
constantFolder(func.getContext()),
nextChannelId(nextChannelId) {}
};
} // namespace onnx_mlir::spatial
@@ -28,6 +28,7 @@
#include "Scheduling/ComputeGraph.hpp" #include "Scheduling/ComputeGraph.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/ComputeInstanceUtils.hpp"
#include "Scheduling/MergeSchedulingAnalysis.hpp" #include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
@@ -43,16 +44,6 @@ using namespace onnx_mlir::compact_asm;
using SpatCompute = spatial::SpatGraphCompute; using SpatCompute = spatial::SpatGraphCompute;
using SpatComputeBatch = spatial::SpatGraphComputeBatch; using SpatComputeBatch = spatial::SpatGraphComputeBatch;
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
if (failed(checkedCoreId))
return std::nullopt;
return *checkedCoreId;
}
return std::nullopt;
}
bool isTrivialSerialMergeCandidate(SpatCompute compute) { bool isTrivialSerialMergeCandidate(SpatCompute compute) {
if (!compute->hasOneUse()) if (!compute->hasOneUse())
return false; return false;
@@ -213,8 +204,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody()); uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation()); uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
if (auto coreId = getComputeCoreId(spatCompute)) auto coreId = getOptionalScheduledCoreId(spatCompute, "merge compute core id");
coreIds.push_back(*coreId); if (failed(coreId))
return;
if (*coreId)
coreIds.push_back(**coreId);
uint64_t computeId = totalComputeOps++; uint64_t computeId = totalComputeOps++;
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds}); collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
uint64_t maxConcatOperands = 0; uint64_t maxConcatOperands = 0;
@@ -234,8 +228,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount()); uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation()); uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) auto optionalCoreIds = getOptionalScheduledBatchCoreIds(batch, "merge compute_batch core id");
llvm::append_range(coreIds, coreIdsAttr.asArrayRef()); if (failed(optionalCoreIds))
return;
if (*optionalCoreIds)
coreIds = std::move(**optionalCoreIds);
collectedData.push_back( collectedData.push_back(
{nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds}); {nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
totalComputeOps += 1; totalComputeOps += 1;
@@ -0,0 +1,67 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
namespace onnx_mlir::spatial {
using CpuId = size_t;
inline mlir::FailureOr<int32_t> getCheckedCoreId(mlir::Operation* anchor, CpuId cpu, llvm::StringRef fieldName) {
return pim::checkedI32(static_cast<uint64_t>(cpu), anchor, fieldName);
}
inline mlir::FailureOr<llvm::SmallVector<int32_t, 8>>
getCheckedCoreIds(mlir::Operation* anchor, llvm::ArrayRef<CpuId> cpus, llvm::StringRef fieldName) {
llvm::SmallVector<int32_t, 8> coreIds;
coreIds.reserve(cpus.size());
for (CpuId cpu : cpus) {
auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName);
if (mlir::failed(checkedCoreId))
return mlir::failure();
coreIds.push_back(*checkedCoreId);
}
return coreIds;
}
struct MessageVector {
llvm::SmallVector<int64_t, 16> channelIds;
llvm::SmallVector<int32_t, 16> sourceCoreIds;
llvm::SmallVector<int32_t, 16> targetCoreIds;
size_t size() const { return channelIds.size(); }
bool empty() const { return channelIds.empty(); }
mlir::LogicalResult verify(mlir::Operation* anchor) const {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return anchor->emitError("message metadata is inconsistent");
return mlir::success();
}
void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) {
channelIds.push_back(channelId);
sourceCoreIds.push_back(sourceCoreId);
targetCoreIds.push_back(targetCoreId);
}
void append(llvm::ArrayRef<int64_t> channels, llvm::ArrayRef<int32_t> sources, llvm::ArrayRef<int32_t> targets) {
assert(channels.size() == sources.size() && "channel/source count mismatch");
assert(channels.size() == targets.size() && "channel/target count mismatch");
llvm::append_range(channelIds, channels);
llvm::append_range(sourceCoreIds, sources);
llvm::append_range(targetCoreIds, targets);
}
MessageVector slice(size_t offset, size_t count) const {
MessageVector result;
result.append(llvm::ArrayRef<int64_t>(channelIds).slice(offset, count),
llvm::ArrayRef<int32_t>(sourceCoreIds).slice(offset, count),
llvm::ArrayRef<int32_t>(targetCoreIds).slice(offset, count));
return result;
}
};
} // namespace onnx_mlir::spatial
@@ -0,0 +1,134 @@
#pragma once
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include "Scheduling/ComputeInstanceUtils.hpp"
namespace onnx_mlir::spatial {
using ClassId = size_t;
using SlotId = size_t;
struct ProducerKey {
ComputeInstance instance;
size_t resultIndex = 0;
bool operator==(const ProducerKey& other) const {
return instance == other.instance && resultIndex == other.resultIndex;
}
};
struct ProducerKeyInfo {
static ProducerKey getEmptyKey() {
return {llvm::DenseMapInfo<ComputeInstance>::getEmptyKey(), std::numeric_limits<size_t>::max()};
}
static ProducerKey getTombstoneKey() {
return {llvm::DenseMapInfo<ComputeInstance>::getTombstoneKey(), std::numeric_limits<size_t>::max()};
}
static unsigned getHashValue(const ProducerKey& key) {
return llvm::hash_combine(llvm::DenseMapInfo<ComputeInstance>::getHashValue(key.instance), key.resultIndex);
}
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
};
struct SameClassConsumerLookupKey {
mlir::Operation* sourceOp = nullptr;
size_t resultIndex = 0;
ClassId classId = 0;
bool operator==(const SameClassConsumerLookupKey& other) const {
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
}
};
struct SameClassConsumerLookupKeyInfo {
static SameClassConsumerLookupKey getEmptyKey() {
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
std::numeric_limits<ClassId>::max()};
}
static SameClassConsumerLookupKey getTombstoneKey() {
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
std::numeric_limits<ClassId>::max()};
}
static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
return llvm::hash_combine(llvm::DenseMapInfo<mlir::Operation*>::getHashValue(key.sourceOp),
key.resultIndex,
key.classId);
}
static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
return lhs == rhs;
}
};
struct WholeBatchAssemblyLookupKey {
mlir::Operation* sourceOp = nullptr;
size_t resultIndex = 0;
ClassId classId = 0;
bool operator==(const WholeBatchAssemblyLookupKey& other) const {
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
}
};
struct WholeBatchAssemblyLookupKeyInfo {
static WholeBatchAssemblyLookupKey getEmptyKey() {
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
std::numeric_limits<ClassId>::max()};
}
static WholeBatchAssemblyLookupKey getTombstoneKey() {
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
std::numeric_limits<ClassId>::max()};
}
static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
return llvm::hash_combine(llvm::DenseMapInfo<mlir::Operation*>::getHashValue(key.sourceOp),
key.resultIndex,
key.classId);
}
static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
return lhs == rhs;
}
};
using ClassSlotKey = std::pair<ClassId, SlotId>;
struct ProjectedBatchInputKey {
mlir::Operation* consumerOp = nullptr;
unsigned inputIndex = 0;
bool operator==(const ProjectedBatchInputKey& other) const {
return consumerOp == other.consumerOp && inputIndex == other.inputIndex;
}
};
struct ProjectedBatchInputKeyInfo {
static ProjectedBatchInputKey getEmptyKey() {
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<unsigned>::max()};
}
static ProjectedBatchInputKey getTombstoneKey() {
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<unsigned>::max()};
}
static unsigned getHashValue(const ProjectedBatchInputKey& key) {
return llvm::hash_combine(key.consumerOp, key.inputIndex);
}
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; }
};
} // namespace onnx_mlir::spatial
@@ -0,0 +1,104 @@
#include "ProjectedFragments.hpp"
#include "mlir/IR/BuiltinTypes.h"
namespace onnx_mlir::spatial {
static mlir::FailureOr<mlir::RankedTensorType> getPackedBatchTensorType(mlir::Type laneType, size_t laneCount) {
auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(laneType);
if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
return mlir::failure();
llvm::SmallVector<int64_t, 4> shape(tensorType.getShape());
shape[0] *= static_cast<int64_t>(laneCount);
return mlir::RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding());
}
unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef<int64_t> loopTripCounts) {
unsigned fragmentsPerLogicalSlot = 1;
for (int64_t tripCount : loopTripCounts) {
assert(tripCount > 0 && "projected loop trip counts must be positive");
fragmentsPerLogicalSlot *= static_cast<unsigned>(tripCount);
}
return fragmentsPerLogicalSlot;
}
mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout) {
if (!layout.fragmentType || layout.fragmentShape.empty())
return anchor->emitError("projected fragment layout is missing fragment type metadata");
if (layout.fragmentShape.size() != static_cast<size_t>(layout.fragmentType.getRank()))
return anchor->emitError("projected fragment layout rank does not match fragment type");
if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0)
return anchor->emitError("projected fragment layout has an invalid fragment count");
if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0)
return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots");
return mlir::success();
}
mlir::FailureOr<mlir::RankedTensorType>
getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount) {
auto packedType = getPackedBatchTensorType(fragmentType, payloadFragmentCount);
if (mlir::failed(packedType)) {
anchor->emitError("cannot create projected payload type");
return mlir::failure();
}
return *packedType;
}
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4>
buildProjectedFragmentOffsetsByDim(llvm::ArrayRef<llvm::SmallVector<int64_t, 4>> fragmentOffsets, size_t rank) {
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4> fragmentOffsetsByDim(rank);
for (llvm::ArrayRef<int64_t> offsets : fragmentOffsets) {
assert(offsets.size() == rank && "projected offset rank mismatch");
for (size_t dim = 0; dim < rank; ++dim)
fragmentOffsetsByDim[dim].push_back(offsets[dim]);
}
return fragmentOffsetsByDim;
}
mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor,
const ProjectedTransferDescriptor& descriptor) {
if (mlir::failed(verifyProjectedFragmentLayout(anchor, descriptor.layout)))
return mlir::failure();
if (!descriptor.payloadType)
return anchor->emitError("projected transfer descriptor is missing payload type");
if (descriptor.fragmentOffsets.empty())
return anchor->emitError("projected transfer descriptor expected at least one fragment offset");
if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size())
return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
for (llvm::ArrayRef<int64_t> dimOffsets : descriptor.fragmentOffsetsByDim)
if (dimOffsets.size() != descriptor.fragmentOffsets.size())
return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
for (llvm::ArrayRef<int64_t> offsets : descriptor.fragmentOffsets)
if (offsets.size() != descriptor.layout.fragmentShape.size())
return anchor->emitError("projected transfer offset rank does not match fragment rank");
return mlir::success();
}
mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor,
const ProjectedTransferDescriptor& descriptor,
const MessageVector& messages) {
if (mlir::failed(verifyProjectedTransferDescriptor(anchor, descriptor)))
return mlir::failure();
if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size())
return anchor->emitError("projected send descriptor metadata is inconsistent");
return mlir::success();
}
mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor,
ProjectedTransferDescriptor& descriptor) {
descriptor.fragmentOffsetsByDim =
buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size());
auto payloadType =
getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount);
if (mlir::failed(payloadType))
return mlir::failure();
if (descriptor.payloadType && descriptor.payloadType != *payloadType)
return anchor->emitError("projected transfer descriptor payload type does not match projected layout");
descriptor.payloadType = *payloadType;
return verifyProjectedTransferDescriptor(anchor, descriptor);
}
} // namespace onnx_mlir::spatial
@@ -0,0 +1,87 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
#include "MergeMessages.hpp"
#include "MergeScheduleKeys.hpp"
namespace onnx_mlir::spatial {
struct ProjectedFragmentLayout {
mlir::RankedTensorType fragmentType;
llvm::SmallVector<int64_t, 4> fragmentShape;
unsigned fragmentsPerLogicalSlot = 1;
unsigned payloadFragmentCount = 1;
llvm::SmallVector<int64_t, 4> loopLowerBounds;
llvm::SmallVector<int64_t, 4> loopSteps;
llvm::SmallVector<int64_t, 4> loopTripCounts;
};
struct StaticProjectedLoopInfo {
mlir::BlockArgument iv;
int64_t lowerBound = 0;
int64_t step = 1;
int64_t tripCount = 1;
};
struct ProjectedTransferDescriptor {
ProjectedBatchInputKey inputKey;
mlir::Operation* extractOp = nullptr;
ProjectedFragmentLayout layout;
mlir::RankedTensorType payloadType;
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 16> fragmentOffsets;
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4> fragmentOffsetsByDim;
};
struct ProjectedExtractReplacement {
mlir::Value payload;
ProjectedFragmentLayout layout;
};
struct PendingProjectedHostOutputFragment {
mlir::Value originalOutput;
ClassId sourceClass = 0;
ProducerKey producerKey;
unsigned publicationResultIndex = 0;
int64_t sourceFragmentOrdinal = 0;
int64_t sourceElementOffset = 0;
llvm::SmallVector<int64_t, 4> offsets;
llvm::SmallVector<int64_t, 4> sizes;
llvm::SmallVector<int64_t, 4> strides;
uint32_t sourceLane = 0;
mlir::Location loc;
};
struct AffineProjectedInputSliceMatch {
mlir::tensor::ExtractSliceOp extract;
mlir::RankedTensorType sourceType;
mlir::RankedTensorType fragmentType;
llvm::SmallVector<int64_t, 4> fragmentShape;
llvm::SmallVector<mlir::OpFoldResult, 4> offsets;
llvm::SmallVector<StaticProjectedLoopInfo, 4> loops;
};
unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef<int64_t> loopTripCounts);
mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout);
mlir::FailureOr<mlir::RankedTensorType>
getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount);
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4>
buildProjectedFragmentOffsetsByDim(llvm::ArrayRef<llvm::SmallVector<int64_t, 4>> fragmentOffsets, size_t rank);
mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor,
const ProjectedTransferDescriptor& descriptor);
mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor,
const ProjectedTransferDescriptor& descriptor,
const MessageVector& messages);
mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor,
ProjectedTransferDescriptor& descriptor);
} // namespace onnx_mlir::spatial
@@ -12,7 +12,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using CPU = int; using CPU = int;
-7
View File
@@ -23,13 +23,6 @@ function(add_pim_unittest test_name)
set_tests_properties(${test_name} PROPERTIES LABELS pim-unittest) set_tests_properties(${test_name} PROPERTIES LABELS pim-unittest)
endfunction() endfunction()
add_pim_unittest(LabeledListTest
LabeledListTest.cpp
LINK_LIBS PRIVATE
OMPimCommon
)
add_pim_unittest(PimMemoryLivenessPlannerTest add_pim_unittest(PimMemoryLivenessPlannerTest
PimMemoryLivenessPlannerTest.cpp PimMemoryLivenessPlannerTest.cpp
-162
View File
@@ -1,162 +0,0 @@
#include <cassert>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <vector>
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
using onnx_mlir::LabeledList;
using onnx_mlir::LabeledListNode;
namespace {
struct TestNode : public LabeledListNode<TestNode> {
explicit TestNode(int id)
: id(id) {}
int id;
};
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
auto expectedIt = expectedOrder.begin();
for (auto& node : list) {
assert(expectedIt != expectedOrder.end());
assert(node.id == *expectedIt);
++expectedIt;
}
assert(expectedIt == expectedOrder.end());
}
void assertStrictlyIncreasingLabels(LabeledList<TestNode>& list) {
auto it = list.begin();
if (it == list.end())
return;
auto previousLabel = it->getOrderLabel();
++it;
for (; it != list.end(); ++it) {
assert(previousLabel < it->getOrderLabel());
previousLabel = it->getOrderLabel();
}
}
int testLabeledListBasicMutation() {
std::cout << "testLabeledListBasicMutation:" << std::endl;
LabeledList<TestNode> list;
TestNode n1(1);
TestNode n2(2);
TestNode n3(3);
TestNode n4(4);
TestNode n5(5);
assert(list.empty());
assert(list.front() == nullptr);
assert(list.back() == nullptr);
assert(!list.contains(&n1));
assert(LabeledList<TestNode>::previous(&n1) == nullptr);
assert(LabeledList<TestNode>::next(&n1) == nullptr);
list.pushBack(&n1);
list.pushBack(&n3);
list.insertAfter(&n1, &n2);
list.pushFront(&n4);
list.insertBefore(nullptr, &n5);
assert(list.size() == 5);
assert(list.front() == &n4);
assert(list.back() == &n5);
assert(list.contains(&n2));
assertOrder(list, {4, 1, 2, 3, 5});
assert(LabeledList<TestNode>::next(&n4) == &n1);
assert(LabeledList<TestNode>::previous(&n1) == &n4);
assert(LabeledList<TestNode>::next(&n5) == nullptr);
assert(list.comesBefore(&n1, &n3));
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
list.moveBefore(&n5, &n2);
assertOrder(list, {4, 1, 5, 2, 3});
list.moveAfter(&n4, &n3);
assertOrder(list, {1, 5, 2, 3, 4});
list.remove(&n2);
assert(!n2.isLinked());
assert(!list.contains(&n2));
assertOrder(list, {1, 5, 3, 4});
list.clear();
assert(list.empty());
assert(list.size() == 0);
assert(list.front() == nullptr);
assert(list.back() == nullptr);
assert(!n1.isLinked());
assert(!n3.isLinked());
assert(!n4.isLinked());
assert(!n5.isLinked());
return 0;
}
int testLabeledListRelabelingAndNoopMoves() {
std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl;
constexpr int insertedNodeCount = 80;
LabeledList<TestNode> list;
TestNode head(0);
TestNode tail(999);
std::vector<TestNode> insertedNodes;
insertedNodes.reserve(insertedNodeCount);
for (int i = 0; i < insertedNodeCount; ++i)
insertedNodes.emplace_back(i + 1);
list.pushBack(&head);
list.pushBack(&tail);
for (auto& node : insertedNodes)
list.insertAfter(&head, &node);
assert(list.size() == insertedNodeCount + 2);
assert(list.front() == &head);
assert(list.back() == &tail);
assert(LabeledList<TestNode>::previous(&head) == nullptr);
assert(LabeledList<TestNode>::next(&tail) == nullptr);
assertStrictlyIncreasingLabels(list);
auto* firstInserted = LabeledList<TestNode>::next(&head);
auto* secondInserted = LabeledList<TestNode>::next(firstInserted);
list.moveBefore(firstInserted, secondInserted);
list.moveAfter(&head, nullptr);
list.moveAfter(&tail, LabeledList<TestNode>::previous(&tail));
assert(list.front() == &head);
assert(list.back() == &tail);
assert(firstInserted == &insertedNodes.back());
assert(secondInserted == &insertedNodes[insertedNodeCount - 2]);
assertStrictlyIncreasingLabels(list);
int expectedId = insertedNodeCount;
auto it = std::next(list.begin());
for (; it != list.end() && &*it != &tail; ++it, --expectedId)
assert(it->id == expectedId);
assert(expectedId == 0);
list.clear();
return 0;
}
} // namespace
int main(int argc, char* argv[]) {
(void) argc;
(void) argv;
int failures = 0;
failures += testLabeledListBasicMutation();
failures += testLabeledListRelabelingAndNoopMoves();
if (failures != 0) {
std::cerr << failures << " test failures\n";
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}