This commit is contained in:
@@ -5,6 +5,7 @@ add_pim_library(OMPimCommon
|
||||
IR/ConstantUtils.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/IndexingUtils.cpp
|
||||
IR/LoopUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
IR/SubviewUtils.cpp
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -9,6 +10,65 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
/// Reads the optional `kCoreIdAttrName` integer attribute of `computeOp`.
///
/// Returns an empty optional when the attribute is absent, the core id when it
/// is present and passes `pim::checkedI32`, and failure (after emitting an op
/// error for the negative case) otherwise. `fieldName` is used in diagnostics.
mlir::FailureOr<std::optional<int32_t>>
getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) {
  using OptionalCoreId = std::optional<int32_t>;

  auto attr = computeOp->getAttrOfType<mlir::IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!attr)
    return OptionalCoreId {};

  int64_t rawCoreId = attr.getInt();
  if (rawCoreId < 0) {
    computeOp.emitOpError() << fieldName << " must be non-negative";
    return mlir::failure();
  }

  auto narrowed = pim::checkedI32(rawCoreId, computeOp, fieldName);
  if (mlir::failed(narrowed))
    return mlir::failure();
  return OptionalCoreId {*narrowed};
}
|
||||
|
||||
/// Like getOptionalScheduledCoreId, but a missing attribute is an error:
/// emits "missing required <attr name>" on `computeOp` and returns failure.
mlir::FailureOr<int32_t> getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) {
  auto maybeCoreId = getOptionalScheduledCoreId(computeOp, fieldName);
  if (mlir::failed(maybeCoreId))
    return mlir::failure();

  const std::optional<int32_t>& scheduled = *maybeCoreId;
  if (scheduled.has_value())
    return *scheduled;

  computeOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdAttrName;
  return mlir::failure();
}
|
||||
|
||||
/// Reads the optional `kCoreIdsAttrName` dense i32 array of `computeBatchOp`.
///
/// Returns an empty optional when the attribute is absent. Otherwise every
/// entry must be non-negative and pass `pim::checkedI32`; the first offending
/// entry yields failure (with an op error for the negative case).
mlir::FailureOr<std::optional<llvm::SmallVector<int32_t>>>
getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) {
  using OptionalCoreIds = std::optional<llvm::SmallVector<int32_t>>;

  auto attr = computeBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  if (!attr)
    return OptionalCoreIds {};

  llvm::SmallVector<int32_t> validated;
  validated.reserve(attr.size());
  for (int32_t rawCoreId : attr.asArrayRef()) {
    if (rawCoreId < 0) {
      computeBatchOp.emitOpError() << fieldName << " values must be non-negative";
      return mlir::failure();
    }

    auto narrowed = pim::checkedI32(static_cast<int64_t>(rawCoreId), computeBatchOp, fieldName);
    if (mlir::failed(narrowed))
      return mlir::failure();
    validated.push_back(*narrowed);
  }
  return OptionalCoreIds {std::move(validated)};
}
|
||||
|
||||
/// Like getOptionalScheduledBatchCoreIds, but a missing attribute is an error:
/// emits "missing required <attr name>" on `computeBatchOp` and returns failure.
mlir::FailureOr<llvm::SmallVector<int32_t>>
getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) {
  auto maybeCoreIds = getOptionalScheduledBatchCoreIds(computeBatchOp, fieldName);
  if (mlir::failed(maybeCoreIds))
    return mlir::failure();

  std::optional<llvm::SmallVector<int32_t>>& scheduled = *maybeCoreIds;
  if (scheduled.has_value())
    return std::move(*scheduled);

  computeBatchOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdsAttrName;
  return mlir::failure();
}
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
llvm::SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
|
||||
@@ -3,12 +3,26 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||
|
||||
mlir::FailureOr<std::optional<int32_t>>
|
||||
getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<int32_t> getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<std::optional<llvm::SmallVector<int32_t>>>
|
||||
getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int32_t>>
|
||||
getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName);
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/IndexingUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -163,4 +166,72 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true when every extent in `shape` is strictly positive
/// (vacuously true for an empty shape).
bool hasStaticPositiveShape(llvm::ArrayRef<int64_t> shape) {
  for (int64_t extent : shape)
    if (extent <= 0)
      return false;
  return true;
}
|
||||
|
||||
/// Returns true when `type` has a static shape whose extents are all positive.
bool hasStaticPositiveShape(mlir::RankedTensorType type) {
  if (!type.hasStaticShape())
    return false;
  return hasStaticPositiveShape(type.getShape());
}
|
||||
|
||||
/// Returns the product of all extents in `shape` (1 for an empty shape).
///
/// Written as a plain loop so the file does not depend on <numeric>, which it
/// does not include. The multiplication is not overflow-checked.
int64_t getStaticShapeElementCount(llvm::ArrayRef<int64_t> shape) {
  int64_t elementCount = 1;
  for (int64_t dim : shape)
    elementCount *= dim;
  return elementCount;
}
|
||||
|
||||
/// Builds the shape whose i-th extent is `shape[permutation[i]]`.
llvm::SmallVector<int64_t> permuteShape(llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> permutation) {
  const size_t count = permutation.size();
  llvm::SmallVector<int64_t> result;
  result.reserve(count);
  for (size_t position = 0; position != count; ++position) {
    int64_t sourceAxis = permutation[position];
    result.push_back(shape[sourceAxis]);
  }
  return result;
}
|
||||
|
||||
/// Returns the inverse mapping: `result[permutation[i]] == i` for every i.
llvm::SmallVector<int64_t> invertPermutation(llvm::ArrayRef<int64_t> permutation) {
  const size_t count = permutation.size();
  llvm::SmallVector<int64_t> inverse(count);
  for (size_t position = 0; position != count; ++position) {
    int64_t sourceAxis = permutation[position];
    inverse[sourceAxis] = static_cast<int64_t>(position);
  }
  return inverse;
}
|
||||
|
||||
/// Resolves a transpose permutation for a tensor of the given `rank`.
///
/// With no (or a null) `permAttr` the result is the axis reversal
/// [rank-1, ..., 0]. Otherwise the attribute must hold exactly `rank` integer
/// attributes forming a permutation of [0, rank). Returns failure for a
/// negative rank, a size mismatch, a non-integer element, an out-of-range axis
/// or a repeated axis.
mlir::FailureOr<llvm::SmallVector<int64_t>>
getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr, int64_t rank) {
  // A negative rank can never be valid and would otherwise be converted to a
  // huge unsigned size by reserve() / the `seen` vector below.
  if (rank < 0)
    return mlir::failure();

  llvm::SmallVector<int64_t> permutation;
  permutation.reserve(rank);
  if (!permAttr || !*permAttr) {
    for (int64_t dim = rank - 1; dim >= 0; --dim)
      permutation.push_back(dim);
    return permutation;
  }

  if (static_cast<int64_t>(permAttr->size()) != rank)
    return mlir::failure();

  llvm::SmallVector<bool> seen(rank, false);
  for (mlir::Attribute element : permAttr->getValue()) {
    // dyn_cast instead of getAsRange<IntegerAttr>(): a malformed element must
    // be reported as failure rather than tripping a cast assertion.
    auto axisAttr = mlir::dyn_cast<mlir::IntegerAttr>(element);
    if (!axisAttr)
      return mlir::failure();
    int64_t axis = axisAttr.getInt();
    if (axis < 0 || axis >= rank || seen[axis])
      return mlir::failure();
    seen[axis] = true;
    permutation.push_back(axis);
  }
  return permutation;
}
|
||||
|
||||
/// Returns `rank` copies of the index attribute 1.
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank) {
  mlir::OpFoldResult unitStride = rewriter.getIndexAttr(1);
  llvm::SmallVector<mlir::OpFoldResult> strides(rank, unitStride);
  return strides;
}
|
||||
|
||||
/// Returns `rank` copies of the index attribute 0.
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank) {
  mlir::OpFoldResult zeroOffset = rewriter.getIndexAttr(0);
  llvm::SmallVector<mlir::OpFoldResult> offsets(rank, zeroOffset);
  return offsets;
}
|
||||
|
||||
/// Converts each extent of `shape` into an index attribute, preserving order.
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  llvm::SmallVector<mlir::OpFoldResult> result;
  result.reserve(rank);
  for (size_t axis = 0; axis != rank; ++axis)
    result.push_back(rewriter.getIndexAttr(shape[axis]));
  return result;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,15 +2,23 @@
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
@@ -36,4 +44,67 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides);
|
||||
|
||||
/// Returns ceil(a / b), computed in the common type `C` of both operands.
///
/// The quotient is rounded toward positive infinity for every sign
/// combination, and a zero dividend yields zero (the previous
/// `1 + (a - 1) / b` form returned 1 for `a == 0` and wrapped for unsigned
/// zero). `b` must be non-zero.
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
  static_assert(std::is_integral_v<A>, "A must be an integer type");
  static_assert(std::is_integral_v<B>, "B must be an integer type");
  const C ac = static_cast<C>(a);
  const C bc = static_cast<C>(b);
  const C quotient = ac / bc;
  const C remainder = ac % bc;
  if (remainder == 0)
    return quotient;
  // Truncating division rounds toward zero. When the exact quotient is
  // negative (remainder and divisor have opposite signs) that is already the
  // ceiling; only a positive inexact quotient must be bumped.
  if constexpr (std::is_signed_v<C>) {
    if ((remainder < 0) != (bc < 0))
      return quotient;
  }
  return quotient + 1;
}
|
||||
|
||||
/// Returns `{ceilIntegerDivide(a, b), a % b}` with both operands first
/// converted to their common type `C`.
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
  static_assert(std::is_integral_v<A>, "A must be an integer type");
  static_assert(std::is_integral_v<B>, "B must be an integer type");
  const C dividend = static_cast<C>(a);
  const C divisor = static_cast<C>(b);
  return std::pair<C, C>(ceilIntegerDivide(dividend, divisor), dividend % divisor);
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
/// Returns the shape of `tensor`; its type is cast (not dyn_cast) to
/// RankedTensorType, so any other type trips the cast check.
inline auto getTensorShape(mlir::Value tensor) {
  auto rankedType = mlir::cast<mlir::RankedTensorType>(tensor.getType());
  return rankedType.getShape();
}
|
||||
|
||||
/// True when both values are ranked tensors with static, identical shapes.
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
  auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
  auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
  if (!lhsType || !rhsType)
    return false;
  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape())
    return false;
  return lhsType.getShape() == rhsType.getShape();
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ilist_node.h"
|
||||
#include "llvm/ADT/simple_ilist.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledList;
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledListNode : public llvm::ilist_node<NodeT> {
|
||||
friend class LabeledList<NodeT>;
|
||||
|
||||
public:
|
||||
using Label = uint64_t;
|
||||
|
||||
LabeledListNode() = default;
|
||||
LabeledListNode(const LabeledListNode&) = delete;
|
||||
LabeledListNode(LabeledListNode&&) = default;
|
||||
LabeledListNode& operator=(LabeledListNode&&) = delete;
|
||||
|
||||
~LabeledListNode() { assert(owner_ == nullptr && "destroying a linked LabeledListNode"); }
|
||||
|
||||
bool isLinked() const { return owner_ != nullptr; }
|
||||
Label getOrderLabel() const { return label; }
|
||||
|
||||
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt) { return lft.label < rgt.label; }
|
||||
|
||||
private:
|
||||
const void* owner_ = nullptr;
|
||||
Label label = 0;
|
||||
};
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledList {
|
||||
|
||||
using Label = typename NodeT::Label;
|
||||
|
||||
static constexpr Label kLowerSentinel = 0;
|
||||
static constexpr Label kUpperSentinel = std::numeric_limits<Label>::max();
|
||||
static constexpr Label kRelabelGap = 2;
|
||||
|
||||
public:
|
||||
using List = llvm::simple_ilist<NodeT>;
|
||||
using Iterator = typename List::iterator;
|
||||
using RIterator = typename List::reverse_iterator;
|
||||
using ConstIterator = typename List::const_iterator;
|
||||
|
||||
LabeledList() = default;
|
||||
LabeledList(const LabeledList&) = delete;
|
||||
LabeledList& operator=(const LabeledList&) = delete;
|
||||
LabeledList(LabeledList&&) = delete;
|
||||
LabeledList& operator=(LabeledList&&) = delete;
|
||||
|
||||
~LabeledList() { clear(); }
|
||||
|
||||
bool empty() const { return size_ == 0; }
|
||||
size_t size() const { return size_; }
|
||||
|
||||
NodeT* front() { return empty() ? nullptr : &nodes_.front(); }
|
||||
const NodeT* front() const { return empty() ? nullptr : &nodes_.front(); }
|
||||
|
||||
NodeT* back() { return empty() ? nullptr : &nodes_.back(); }
|
||||
const NodeT* back() const { return empty() ? nullptr : &nodes_.back(); }
|
||||
|
||||
static NodeT* previous(NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
auto* list = owner(node);
|
||||
auto it = node->getIterator();
|
||||
if (it == list->nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
static const NodeT* previous(const NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
const auto* list = owner(node);
|
||||
auto it = const_cast<NodeT*>(node)->getIterator();
|
||||
if (it == list->nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
static NodeT* next(NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
auto* list = owner(node);
|
||||
auto it = std::next(node->getIterator());
|
||||
if (it == list->nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
static const NodeT* next(const NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
const auto* list = owner(node);
|
||||
auto it = std::next(const_cast<NodeT*>(node)->getIterator());
|
||||
if (it == list->nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
bool contains(const NodeT* node) const { return node && node->owner_ == this; }
|
||||
|
||||
Label getOrderLabel(const NodeT* node) const {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
return node->label;
|
||||
}
|
||||
|
||||
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
|
||||
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
|
||||
return lhs->label < rhs->label;
|
||||
}
|
||||
|
||||
void pushFront(NodeT* node) { insertBefore(front(), node); }
|
||||
|
||||
void pushBack(NodeT* node) { insertBefore(nullptr, node); }
|
||||
|
||||
void insertBefore(NodeT* nextNode, NodeT* node) {
|
||||
assert(node && "cannot insert a null node");
|
||||
assert(!node->owner_ && "node is already linked");
|
||||
assert(nextNode == nullptr || contains(nextNode));
|
||||
|
||||
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
|
||||
nodes_.insert(nextIt, *node);
|
||||
node->owner_ = this;
|
||||
++size_;
|
||||
assignLabel(getIteratorFor(node));
|
||||
}
|
||||
|
||||
void insertAfter(NodeT* prevNode, NodeT* node) {
|
||||
assert(prevNode == nullptr || contains(prevNode));
|
||||
if (prevNode == nullptr)
|
||||
insertBefore(front(), node);
|
||||
else
|
||||
insertBefore(next(prevNode), node);
|
||||
}
|
||||
|
||||
void remove(NodeT* node) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
nodes_.remove(*node);
|
||||
node->owner_ = nullptr;
|
||||
node->label = 0;
|
||||
--size_;
|
||||
}
|
||||
|
||||
void moveBefore(NodeT* node, NodeT* nextNode) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
assert(nextNode == nullptr || contains(nextNode));
|
||||
|
||||
Iterator nodeIt = getIteratorFor(node);
|
||||
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
|
||||
if (nodeIt == nextIt || std::next(nodeIt) == nextIt)
|
||||
return;
|
||||
|
||||
nodes_.splice(nextIt, nodes_, nodeIt);
|
||||
assignLabel(getIteratorFor(node));
|
||||
}
|
||||
|
||||
void moveAfter(NodeT* node, NodeT* prevNode) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
assert(prevNode == nullptr || contains(prevNode));
|
||||
|
||||
Iterator nextIt = prevNode ? std::next(getIteratorFor(prevNode)) : nodes_.begin();
|
||||
if (getIteratorFor(node) == nextIt)
|
||||
return;
|
||||
moveBefore(node, nextIt == nodes_.end() ? nullptr : &*nextIt);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
while (!nodes_.empty()) {
|
||||
NodeT* node = &nodes_.front();
|
||||
node->owner_ = nullptr;
|
||||
node->label = 0;
|
||||
nodes_.remove(*node);
|
||||
}
|
||||
size_ = 0;
|
||||
}
|
||||
|
||||
Iterator begin() { return nodes_.begin(); }
|
||||
Iterator end() { return nodes_.end(); }
|
||||
|
||||
RIterator rbegin() { return nodes_.rbegin(); }
|
||||
RIterator rend() { return nodes_.rend(); }
|
||||
|
||||
private:
|
||||
static const LabeledList* owner(const NodeT* node) { return static_cast<const LabeledList*>(node->owner_); }
|
||||
static LabeledList* owner(NodeT* node) { return static_cast<LabeledList*>(const_cast<void*>(node->owner_)); }
|
||||
|
||||
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
|
||||
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }
|
||||
|
||||
static Label labelGap(Label lower, Label upper) {
|
||||
assert(lower < upper && "labels must be strictly ordered");
|
||||
return upper - lower;
|
||||
}
|
||||
|
||||
static bool hasMidpoint(Label lower, Label upper) { return labelGap(lower, upper) > 1; }
|
||||
|
||||
static bool hasRelabelSlack(Label lower, Label upper, size_t nodeCount) {
|
||||
Label gap = labelGap(lower, upper);
|
||||
return gap / static_cast<Label>(nodeCount + 1) >= kRelabelGap;
|
||||
}
|
||||
|
||||
Iterator getIteratorFor(NodeT* node) { return node->getIterator(); }
|
||||
ConstIterator getiteratorFor(const NodeT* node) const { return node->getIterator(); }
|
||||
|
||||
NodeT* previousNode(Iterator it) {
|
||||
if (it == nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
const NodeT* previousNode(ConstIterator it) const {
|
||||
if (it == nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
NodeT* nextNode(Iterator it) {
|
||||
++it;
|
||||
if (it == nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
const NodeT* nextNode(ConstIterator it) const {
|
||||
++it;
|
||||
if (it == nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
void assignLabel(Iterator it) {
|
||||
Label lower = lowerLabel(previousNode(it));
|
||||
Label upper = upperLabel(nextNode(it));
|
||||
if (hasMidpoint(lower, upper)) {
|
||||
(*it).label = lower + static_cast<Label>(labelGap(lower, upper) / 2);
|
||||
return;
|
||||
}
|
||||
|
||||
relabelAround(it);
|
||||
}
|
||||
|
||||
void relabelAround(Iterator center) {
|
||||
size_t targetCount = 1;
|
||||
while (true) {
|
||||
Iterator left = center;
|
||||
Iterator right = center;
|
||||
size_t actualCount = 1;
|
||||
expandWindow(center, targetCount, left, right, actualCount);
|
||||
|
||||
Label lower = lowerLabel(previousNode(left));
|
||||
Label upper = upperLabel(nextNode(right));
|
||||
if (hasRelabelSlack(lower, upper, actualCount)) {
|
||||
relabelWindow(left, actualCount, lower, upper);
|
||||
return;
|
||||
}
|
||||
|
||||
if (left == nodes_.begin() && nextNode(right) == nullptr) {
|
||||
assert(hasRelabelSlack(lower, upper, actualCount) && "label space exhausted");
|
||||
relabelWindow(left, actualCount, lower, upper);
|
||||
return;
|
||||
}
|
||||
|
||||
targetCount *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
void expandWindow(Iterator center, size_t targetCount, Iterator& left, Iterator& right, size_t& actualCount) {
|
||||
left = center;
|
||||
right = center;
|
||||
actualCount = 1;
|
||||
|
||||
while (actualCount < targetCount && (left != nodes_.begin() || nextNode(right) != nullptr)) {
|
||||
if (left != nodes_.begin()) {
|
||||
--left;
|
||||
++actualCount;
|
||||
if (actualCount == targetCount)
|
||||
break;
|
||||
}
|
||||
if (nextNode(right) != nullptr) {
|
||||
++right;
|
||||
++actualCount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void relabelWindow(Iterator left, size_t nodeCount, Label lower, Label upper) {
|
||||
assert(nodeCount > 0 && "relabel window must not be empty");
|
||||
Label step = labelGap(lower, upper) / static_cast<Label>(nodeCount + 1);
|
||||
assert(step >= 1 && "relabel step must be positive");
|
||||
|
||||
Iterator it = left;
|
||||
for (size_t index = 1; index <= nodeCount; ++index) {
|
||||
(*it).label = lower + step * index;
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
List nodes_;
|
||||
size_t size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
|
||||
@@ -10,6 +10,7 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Post.cpp
|
||||
Patterns/GeneratedConversion.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/ConvGeometry.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
@@ -30,7 +31,7 @@ add_pim_library(OMONNXToSpatial
|
||||
LowerSpatialPlansPass.cpp
|
||||
Common/AttributeUtils.cpp
|
||||
Common/ComputeRegionBuilder.cpp
|
||||
Common/IndexingUtils.cpp
|
||||
Common/MatrixProductLowering.cpp
|
||||
Common/ShapeTilingUtils.cpp
|
||||
Common/WeightMaterialization.cpp
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "MatrixProductLowering.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
#include "MatrixProductLowering.hpp"
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Pads `value` with zeros at the high end of every dimension so that the
/// result has type `resultType`, and returns the tensor.pad result.
///
/// The high pad of each dimension is `resultDim - sourceDim`; low pads are 0.
/// Source and result must have the same rank (checked by zip_equal in assert
/// builds). On return the rewriter's insertion point is right after the pad op.
Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  const int64_t rank = sourceType.getRank();

  SmallVector<OpFoldResult> lowPads(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highPads;
  highPads.reserve(rank);
  // zip_equal (unlike zip) asserts that both shapes have the same length
  // instead of silently truncating to the shorter one.
  for (auto [sourceDim, resultDim] : llvm::zip_equal(sourceType.getShape(), resultType.getShape()))
    highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);

  // Create the pad body through the rewriter (one index argument per
  // dimension) so the block is owned by the region from the start and any
  // rewriter listener is notified. createBlock leaves the insertion point
  // inside the new, empty block.
  SmallVector<Type> indexTypes(rank, rewriter.getIndexType());
  SmallVector<Location> indexLocs(rank, loc);
  rewriter.createBlock(&padOp.getRegion(), padOp.getRegion().end(), indexTypes, indexLocs);

  auto zero = getOrCreateConstant(
      rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
  tensor::YieldOp::create(rewriter, loc, zero);
  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
|
||||
|
||||
/// Returns `input` unchanged when it already has type `paddedInputType`;
/// otherwise wraps a zero-padding of `input` in a single-input spatial compute
/// op and returns that op's first result.
Value createPaddedInputCompute(Value input,
                               RankedTensorType paddedInputType,
                               PatternRewriter& rewriter,
                               Location loc) {
  if (cast<RankedTensorType>(input.getType()) == paddedInputType)
    return input;

  // Body of the compute region: pad the region argument and yield it.
  auto buildPaddedBody = [&](Value computeInput) {
    Value padded = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
    spatial::SpatYieldOp::create(rewriter, loc, padded);
  };

  auto paddingCompute = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, buildPaddedBody);
  return paddingCompute.getResult(0);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Value createZeroPaddedTensor(mlir::Value value,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
mlir::Value createPaddedInputCompute(mlir::Value input,
|
||||
mlir::RankedTensorType paddedInputType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,9 +3,6 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
@@ -15,73 +12,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns true when every extent in `shape` is strictly positive
/// (vacuously true for an empty shape).
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
  for (int64_t extent : shape)
    if (extent <= 0)
      return false;
  return true;
}
|
||||
|
||||
/// Returns true when `type` has a static shape whose extents are all positive.
bool hasStaticPositiveShape(RankedTensorType type) {
  if (!type.hasStaticShape())
    return false;
  return hasStaticPositiveShape(type.getShape());
}
|
||||
|
||||
/// Returns the product of all extents in `shape` (1 for an empty shape).
/// The multiplication is not overflow-checked.
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
  int64_t elementCount = 1;
  for (int64_t dim : shape)
    elementCount = elementCount * dim;
  return elementCount;
}
|
||||
|
||||
/// Builds the shape whose i-th extent is `shape[permutation[i]]`.
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
  const size_t count = permutation.size();
  SmallVector<int64_t> result;
  result.reserve(count);
  for (size_t position = 0; position != count; ++position) {
    int64_t sourceAxis = permutation[position];
    result.push_back(shape[sourceAxis]);
  }
  return result;
}
|
||||
|
||||
/// Returns the inverse mapping: `result[permutation[i]] == i` for every i.
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
  const size_t count = permutation.size();
  SmallVector<int64_t> inverse(count);
  for (size_t position = 0; position != count; ++position) {
    int64_t sourceAxis = permutation[position];
    inverse[sourceAxis] = static_cast<int64_t>(position);
  }
  return inverse;
}
|
||||
|
||||
/// Resolves a transpose permutation for a tensor of the given `rank`.
///
/// With no (or a null) `permAttr` the result is the axis reversal
/// [rank-1, ..., 0]. Otherwise the attribute must hold exactly `rank` integer
/// attributes forming a permutation of [0, rank). Returns failure for a
/// negative rank, a size mismatch, a non-integer element, an out-of-range axis
/// or a repeated axis.
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
  // A negative rank can never be valid and would otherwise be converted to a
  // huge unsigned size by reserve() / the `seen` vector below.
  if (rank < 0)
    return failure();

  SmallVector<int64_t> permutation;
  permutation.reserve(rank);
  if (!permAttr || !*permAttr) {
    for (int64_t dim = rank - 1; dim >= 0; --dim)
      permutation.push_back(dim);
    return permutation;
  }

  if (static_cast<int64_t>(permAttr->size()) != rank)
    return failure();

  SmallVector<bool> seen(rank, false);
  for (Attribute element : permAttr->getValue()) {
    // dyn_cast instead of getAsRange<IntegerAttr>(): a malformed element must
    // be reported as failure rather than tripping a cast assertion.
    auto axisAttr = dyn_cast<IntegerAttr>(element);
    if (!axisAttr)
      return failure();
    int64_t axis = axisAttr.getInt();
    if (axis < 0 || axis >= rank || seen[axis])
      return failure();
    seen[axis] = true;
    permutation.push_back(axis);
  }
  return permutation;
}
|
||||
|
||||
/// Returns `rank` copies of the index attribute 1.
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
  OpFoldResult unitStride = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, unitStride);
  return strides;
}
|
||||
|
||||
/// Returns `rank` copies of the index attribute 0.
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
  OpFoldResult zeroOffset = rewriter.getIndexAttr(0);
  SmallVector<OpFoldResult> offsets(rank, zeroOffset);
  return offsets;
}
|
||||
|
||||
/// Converts each extent of `shape` into an index attribute, preserving order.
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  SmallVector<OpFoldResult> result;
  result.reserve(rank);
  for (size_t axis = 0; axis != rank; ++axis)
    result.push_back(rewriter.getIndexAttr(shape[axis]));
  return result;
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
|
||||
@@ -1,89 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
/// Returns ceil(a / b), computed in the common type `C` of both operands.
///
/// The quotient is rounded toward positive infinity for every sign
/// combination, and a zero dividend yields zero (the previous
/// `1 + (a - 1) / b` form returned 1 for `a == 0` and wrapped for unsigned
/// zero). `b` must be non-zero.
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
  static_assert(std::is_integral_v<A>, "A must be an integer type");
  static_assert(std::is_integral_v<B>, "B must be an integer type");
  const C ac = static_cast<C>(a);
  const C bc = static_cast<C>(b);
  const C quotient = ac / bc;
  const C remainder = ac % bc;
  if (remainder == 0)
    return quotient;
  // Truncating division rounds toward zero. When the exact quotient is
  // negative (remainder and divisor have opposite signs) that is already the
  // ceiling; only a positive inexact quotient must be bumped.
  if constexpr (std::is_signed_v<C>) {
    if ((remainder < 0) != (bc < 0))
      return quotient;
  }
  return quotient + 1;
}
|
||||
|
||||
/// Returns `{ceilIntegerDivide(a, b), a % b}` with both operands first
/// converted to their common type `C`.
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
  static_assert(std::is_integral_v<A>, "A must be an integer type");
  static_assert(std::is_integral_v<B>, "B must be an integer type");
  const C dividend = static_cast<C>(a);
  const C divisor = static_cast<C>(b);
  return std::pair<C, C>(ceilIntegerDivide(dividend, divisor), dividend % divisor);
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
/// Returns the shape of `tensor`. The value's type is converted with
/// `mlir::cast`, so the caller must pass a value of `RankedTensorType`.
inline auto getTensorShape(mlir::Value tensor) {
  auto rankedType = mlir::cast<mlir::RankedTensorType>(tensor.getType());
  return rankedType.getShape();
}
|
||||
|
||||
/// True only when both values are ranked tensors with fully static shapes and
/// those shapes are equal. Any other combination yields false.
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
  auto lhsTensor = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
  auto rhsTensor = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
  if (!lhsTensor || !rhsTensor)
    return false;
  if (!lhsTensor.hasStaticShape() || !rhsTensor.hasStaticShape())
    return false;
  return lhsTensor.getShape() == rhsTensor.getShape();
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
/// at most `sliceSize` elements.
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -42,59 +43,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvLoweringState {
|
||||
Value x;
|
||||
Value w;
|
||||
Value b;
|
||||
RankedTensorType xType;
|
||||
RankedTensorType wType;
|
||||
RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
struct ConvGeometry {
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t k;
|
||||
int64_t c;
|
||||
int64_t p;
|
||||
int64_t xbarSize;
|
||||
int64_t pack;
|
||||
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
bool isDepthwise;
|
||||
};
|
||||
|
||||
struct ConvLoweringDecision {
|
||||
PimConvLoweringType strategy;
|
||||
std::string reason;
|
||||
@@ -108,19 +56,6 @@ struct PreparedConvInput {
|
||||
RankedTensorType type;
|
||||
};
|
||||
|
||||
/// A range of row indices. The convolution row helpers in this file treat
/// `end` as exclusive (they use `end - 1` as the last row).
struct RowInterval {
  int64_t begin {0};
  int64_t end {0};
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
struct ConvStrategyEstimate {
|
||||
uint64_t estimatedMvmCount = 0;
|
||||
uint64_t estimatedReductionVAddCount = 0;
|
||||
@@ -291,9 +226,6 @@ static FailureOr<Value> createRowStripPackedRows(Value rows,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc);
|
||||
|
||||
static bool
|
||||
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
|
||||
|
||||
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
|
||||
@@ -391,34 +323,6 @@ static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
|
||||
return estimate;
|
||||
}
|
||||
|
||||
// Derives the ConvGeometry summary from an analysed convolution. Every field is
// assigned by name; `pack` and `im2colElements` are computed last because they
// depend on `k`, `c`, `p` and `xbarSize`.
static ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
  ConvGeometry geo {};

  // Fields copied straight from the lowering state.
  geo.batchSize = state.batchSize;
  geo.numChannelsIn = state.numChannelsIn;
  geo.xHeight = state.xHeight;
  geo.xWidth = state.xWidth;
  geo.numChannelsOut = state.numChannelsOut;
  geo.wHeight = state.wHeight;
  geo.wWidth = state.wWidth;
  geo.outHeight = state.outHeight;
  geo.outWidth = state.outWidth;
  geo.group = state.group;
  geo.numChannelsInPerGroup = state.numChannelsInPerGroup;
  geo.numChannelsOutPerGroup = state.numChannelsOutPerGroup;
  geo.hasBias = state.hasBias;

  // Derived sizes.
  geo.k = state.numChannelsInPerGroup * state.wHeight * state.wWidth;
  geo.c = state.numChannelsOutPerGroup;
  geo.p = state.batchSize * state.outHeight * state.outWidth;
  geo.xbarSize = static_cast<int64_t>(crossbarSize.getValue());
  geo.isDepthwise =
      isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup);

  // Crossbar size divided by the larger of k and c, never below 1.
  const int64_t largerOperandDim = std::max<int64_t>(geo.k, geo.c);
  geo.pack = std::max<int64_t>(1, geo.xbarSize / largerOperandDim);

  // p * k, with negative inputs treated as 0.
  const auto clampedP = static_cast<uint64_t>(std::max<int64_t>(0, geo.p));
  const auto clampedK = static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
  geo.im2colElements = clampedP * clampedK;
  return geo;
}
|
||||
|
||||
static std::string formatShape(ArrayRef<int64_t> dims) {
|
||||
std::string text;
|
||||
llvm::raw_string_ostream os(text);
|
||||
@@ -563,36 +467,10 @@ classifyDistributedBinaryConsumer(Operation* user,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Maps a range of output rows to the input rows a convolution reads for them
// (stride, dilation, kernel height and top padding applied), clamped to
// [0, inputHeight].
static RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows,
                                                     int64_t inputHeight,
                                                     int64_t kernelH,
                                                     int64_t strideH,
                                                     int64_t dilationH,
                                                     int64_t padTop) {
  const int64_t lastOutputRow = outputRows.end - 1;
  const int64_t dilatedKernelSpan = dilationH * (kernelH - 1);
  const int64_t firstInputRow = outputRows.begin * strideH - padTop;
  const int64_t pastLastInputRow = lastOutputRow * strideH - padTop + dilatedKernelSpan + 1;

  RowInterval inputRows;
  inputRows.begin = std::max<int64_t>(0, firstInputRow);
  inputRows.end = std::min<int64_t>(inputHeight, pastLastInputRow);
  return inputRows;
}
|
||||
|
||||
// True when `acquired` contains `needed`: it starts no later and ends no
// earlier.
static bool covers(RowInterval acquired, RowInterval needed) {
  const bool startsTooLate = needed.begin < acquired.begin;
  const bool endsTooEarly = needed.end > acquired.end;
  return !(startsTooLate || endsTooEarly);
}
|
||||
|
||||
// Builds the row demand for `outputRows`: the clamped input rows needed (also
// recorded as the acquired rows) and how far the unclamped span sticks out
// above row 0 and below the input height.
static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
  ConvRowDemand demand;
  demand.outputRows = outputRows;
  demand.neededInputRows = computeConvInputRowsForOutputRows(
      outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin);
  demand.acquiredInputRows = demand.neededInputRows;

  // Same span as above, but before clamping to the input height.
  const int64_t unclampedBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
  const int64_t unclampedEnd =
      (outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;

  // The halo fields default to 0, so only the overhanging cases are written.
  if (unclampedBegin < 0)
    demand.topHaloRows = -unclampedBegin;
  if (unclampedEnd > state.xHeight)
    demand.bottomHaloRows = unclampedEnd - state.xHeight;
  return demand;
}
|
||||
|
||||
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
|
||||
if (state.batchSize != 1) {
|
||||
failureReason = "unsupported_batch";
|
||||
@@ -1250,19 +1128,6 @@ static void reportConvLoweringDecision(ONNXConvOp convOp,
|
||||
rewriteConvLoweringReport(reportEntries);
|
||||
}
|
||||
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
|
||||
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
|
||||
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
|
||||
|
||||
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
|
||||
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
|
||||
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
|
||||
}
|
||||
return std::max<uint64_t>(1, chunkPositions);
|
||||
}
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
@@ -1278,11 +1143,6 @@ static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location
|
||||
});
|
||||
}
|
||||
|
||||
// A convolution counts as depthwise when there is one group per input channel,
// each group has a single input channel, and the output channel count is a
// multiple of the group count.
static bool
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
  if (group != numChannelsIn)
    return false;
  if (numChannelsInPerGroup != 1)
    return false;
  return numChannelsOut % group == 0;
}
|
||||
|
||||
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
|
||||
assert(value > 0 && "expected positive value");
|
||||
limit = std::min(value, limit);
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
#include "ConvGeometry.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Reports whether the given channel/group configuration is a depthwise
/// convolution: one group per input channel, a single input channel per
/// group, and an output channel count that is a multiple of the group count.
///
/// A group count of zero is never depthwise. It is rejected up front so that
/// the `numChannelsOut % group` check below cannot divide by zero.
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
  if (group == 0)
    return false;
  return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
}
|
||||
|
||||
/// Derives the ConvGeometry summary from an analysed convolution. Every field
/// is assigned by name; `pack` and `im2colElements` are computed last because
/// they depend on `k`, `c`, `p` and `xbarSize`.
ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
  ConvGeometry geo {};

  // Fields copied straight from the lowering state.
  geo.batchSize = state.batchSize;
  geo.numChannelsIn = state.numChannelsIn;
  geo.xHeight = state.xHeight;
  geo.xWidth = state.xWidth;
  geo.numChannelsOut = state.numChannelsOut;
  geo.wHeight = state.wHeight;
  geo.wWidth = state.wWidth;
  geo.outHeight = state.outHeight;
  geo.outWidth = state.outWidth;
  geo.group = state.group;
  geo.numChannelsInPerGroup = state.numChannelsInPerGroup;
  geo.numChannelsOutPerGroup = state.numChannelsOutPerGroup;
  geo.hasBias = state.hasBias;

  // Derived sizes.
  geo.k = state.numChannelsInPerGroup * state.wHeight * state.wWidth;
  geo.c = state.numChannelsOutPerGroup;
  geo.p = state.batchSize * state.outHeight * state.outWidth;
  geo.xbarSize = static_cast<int64_t>(crossbarSize.getValue());
  geo.isDepthwise =
      isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup);

  // Crossbar size divided by the larger of k and c, never below 1.
  const int64_t largerOperandDim = std::max<int64_t>(geo.k, geo.c);
  geo.pack = std::max<int64_t>(1, geo.xbarSize / largerOperandDim);

  // p * k, with negative inputs treated as 0.
  const auto clampedP = static_cast<uint64_t>(std::max<int64_t>(0, geo.p));
  const auto clampedK = static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
  geo.im2colElements = clampedP * clampedK;
  return geo;
}
|
||||
|
||||
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
|
||||
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
|
||||
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
|
||||
|
||||
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
|
||||
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
|
||||
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
|
||||
}
|
||||
return std::max<uint64_t>(1, chunkPositions);
|
||||
}
|
||||
|
||||
/// Maps a range of output rows to the input rows the convolution reads for
/// them, using the vertical stride, dilation, kernel height and top padding
/// from `state`. The result is clamped to [0, state.xHeight].
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state) {
  const int64_t stride = state.strideHeight;
  const int64_t padTop = state.padHeightBegin;

  const int64_t firstInputRow = outputRows.begin * stride - padTop;
  const int64_t pastLastInputRow =
      (outputRows.end - 1) * stride - padTop + state.dilationHeight * (state.wHeight - 1) + 1;

  RowInterval inputRows;
  inputRows.begin = std::max<int64_t>(0, firstInputRow);
  inputRows.end = std::min<int64_t>(state.xHeight, pastLastInputRow);
  return inputRows;
}
|
||||
|
||||
/// Builds the row demand for `outputRows`.
///
/// `neededInputRows` is the input span clamped to the tensor height, and
/// `acquiredInputRows` is set to the same interval. `topHaloRows` and
/// `bottomHaloRows` record how many rows of the unclamped span fall above
/// row 0 and past `state.xHeight` respectively.
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
  // Unclamped span; the clamped version comes from
  // computeConvInputRowsForOutputRows below.
  const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
  const int64_t rawEnd =
      (outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;

  ConvRowDemand demand;
  demand.outputRows = outputRows;
  demand.neededInputRows = computeConvInputRowsForOutputRows(outputRows, state);
  // Assigned exactly once (the earlier duplicate assignment was redundant).
  demand.acquiredInputRows = demand.neededInputRows;
  demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
  demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
  return demand;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,86 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
// Operands, types and scalar parameters of a convolution being lowered.
// buildConvGeometry and the row-demand helpers read the scalar fields.
// NOTE(review): field meanings below are taken from how those helpers use
// them; confirm against analyzeConvLoweringState, which fills this struct.
struct ConvLoweringState {
|
||||
// NOTE(review): presumably the convolution input, weight and bias operands — confirm.
mlir::Value x;
|
||||
mlir::Value w;
|
||||
mlir::Value b;
|
||||
mlir::RankedTensorType xType;
|
||||
mlir::RankedTensorType wType;
|
||||
mlir::RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
// Input height; used as the upper clamp for input row intervals.
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
// Kernel height; enters the input-row span as dilationHeight * (wHeight - 1).
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
// Subtracted from the strided output row index when computing input rows.
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
// Multiplies the output row index when computing input rows.
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
// Size summary of a convolution, produced by buildConvGeometry.
// buildConvGeometry initialises this struct positionally, so the field order
// must stay in sync with that initialiser list.
struct ConvGeometry {
|
||||
// The fields up to numChannelsOutPerGroup are copied unchanged from
// ConvLoweringState.
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
// numChannelsInPerGroup * wHeight * wWidth.
int64_t k;
|
||||
// numChannelsOutPerGroup.
int64_t c;
|
||||
// batchSize * outHeight * outWidth.
int64_t p;
|
||||
// Value of the crossbarSize option at the time the geometry was built.
int64_t xbarSize;
|
||||
// max(1, xbarSize / max(k, c)).
int64_t pack;
|
||||
// p * k, with negative p or k treated as 0.
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
// Result of isDepthwiseConv for this convolution.
bool isDepthwise;
|
||||
};
|
||||
|
||||
/// A range of row indices. The convolution row helpers treat `end` as
/// exclusive (they use `end - 1` as the last row).
struct RowInterval {
  int64_t begin {0};
  int64_t end {0};
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
|
||||
ConvGeometry buildConvGeometry(const ConvLoweringState& state);
|
||||
|
||||
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
|
||||
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -87,28 +87,6 @@ static Value createGemmBatchHOffset(Value lane,
|
||||
rewriter.getInsertionBlock()->getParentOp());
|
||||
}
|
||||
|
||||
// Pads `value` with zeros at the high end of every dimension so that it has
// type `resultType`. Emits a tensor.pad whose body yields a zero constant of
// the source element type, and leaves the rewriter positioned right after the
// pad op.
static Value
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(value.getType());
  const int64_t rank = inputType.getRank();
  auto elementType = inputType.getElementType();

  // No padding in front; the whole size difference goes to the back.
  SmallVector<OpFoldResult> leadingPads(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> trailingPads;
  trailingPads.reserve(rank);
  for (auto [inputDim, paddedDim] : llvm::zip(inputType.getShape(), resultType.getShape()))
    trailingPads.push_back(rewriter.getIndexAttr(paddedDim - inputDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, leadingPads, trailingPads);

  // The pad body takes one index argument per dimension; the region owns the
  // block once it has been pushed.
  auto* body = new Block();
  for (int64_t dim = 0; dim < rank; ++dim)
    body->addArgument(rewriter.getIndexType(), loc);
  padOp.getRegion().push_back(body);

  rewriter.setInsertionPointToStart(body);
  auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(elementType), elementType);
  tensor::YieldOp::create(rewriter, loc, zero);

  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
|
||||
|
||||
static FailureOr<Value> materializePaddedConstantMatrix(Value value,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
@@ -232,22 +210,6 @@ static Value extractATile(
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
// Returns `input` unchanged when it already has type `paddedInputType`.
// Otherwise builds a single-input spatial compute that zero-pads its argument
// to `paddedInputType` and yields it, and returns that compute's first result.
static Value createPaddedInputCompute(Value input,
                                      RankedTensorType paddedInputType,
                                      ConversionPatternRewriter& rewriter,
                                      Location loc) {
  const bool alreadyPadded = cast<RankedTensorType>(input.getType()) == paddedInputType;
  if (alreadyPadded)
    return input;

  auto padCompute = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value blockInput) {
    Value padded = createZeroPaddedTensor(blockInput, paddedInputType, rewriter, loc);
    spatial::SpatYieldOp::create(rewriter, loc, padded);
  });
  return padCompute.getResult(0);
}
|
||||
|
||||
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
|
||||
@@ -255,42 +255,6 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
||||
return createONNXTranspose(resultType, {0, 2, 1});
|
||||
}
|
||||
|
||||
// Pads `value` with zeros at the high end of every dimension so that it has
// type `resultType`. Emits a tensor.pad whose body yields a zero constant of
// the source element type, and leaves the rewriter positioned right after the
// pad op.
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(value.getType());
  const int64_t rank = inputType.getRank();
  auto elementType = inputType.getElementType();

  // No padding in front; the whole size difference goes to the back.
  SmallVector<OpFoldResult> leadingPads(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> trailingPads;
  trailingPads.reserve(rank);
  for (auto [inputDim, paddedDim] : llvm::zip(inputType.getShape(), resultType.getShape()))
    trailingPads.push_back(rewriter.getIndexAttr(paddedDim - inputDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, leadingPads, trailingPads);

  // The pad body takes one index argument per dimension; the region owns the
  // block once it has been pushed.
  auto* body = new Block();
  for (int64_t dim = 0; dim < rank; ++dim)
    body->addArgument(rewriter.getIndexType(), loc);
  padOp.getRegion().push_back(body);

  rewriter.setInsertionPointToStart(body);
  auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(elementType), elementType);
  tensor::YieldOp::create(rewriter, loc, zero);

  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
|
||||
|
||||
// Returns `input` unchanged when it already has type `paddedInputType`.
// Otherwise builds a single-input spatial compute that zero-pads its argument
// to `paddedInputType` and yields it, and returns that compute's first result.
static Value createPaddedBatchedInputCompute(Value input,
                                             RankedTensorType paddedInputType,
                                             PatternRewriter& rewriter,
                                             Location loc) {
  const bool alreadyPadded = cast<RankedTensorType>(input.getType()) == paddedInputType;
  if (alreadyPadded)
    return input;

  auto padCompute = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value blockInput) {
    Value padded = createZeroPaddedTensor(blockInput, paddedInputType, rewriter, loc);
    spatial::SpatYieldOp::create(rewriter, loc, padded);
  });
  return padCompute.getResult(0);
}
|
||||
|
||||
static FailureOr<Value> materializePaddedBatchedWeight(Value value,
|
||||
ArrayRef<int64_t> sourceBatchShape,
|
||||
ArrayRef<int64_t> targetBatchShape,
|
||||
@@ -1055,7 +1019,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
auto paddedRhs =
|
||||
materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
|
||||
if (succeeded(paddedRhs)) {
|
||||
Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
|
||||
Value paddedLhs = createPaddedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
|
||||
const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
|
||||
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
shapeInfo->outType.getElementType());
|
||||
|
||||
@@ -29,100 +29,6 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||
});
|
||||
}
|
||||
|
||||
// True for the op kinds that materializeExternalTensorValue is allowed to
// clone: spatial channel receives and row extractions, plus tensor
// slice/expand/collapse ops.
static bool isMaterializableExternalTensorOp(Operation* op) {
  const bool isSpatialProducer = isa<spatial::SpatChannelReceiveOp, spatial::SpatExtractRowsOp>(op);
  const bool isTensorReshapeOrSlice =
      isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
  return isSpatialProducer || isTensorReshapeOrSlice;
}
|
||||
|
||||
//TODO REMOVE THIS UGLY FIX
|
||||
//TODO: Remove this helper once compute_batch external tensor captures are
|
||||
// fixed at the producer side.
|
||||
//
|
||||
// This function is a temporary SpatialToPim repair path. It clones selected
|
||||
// external tensor producers, such as channel_receive and tensor view/slice ops,
|
||||
// into the new pim.core_batch body when the old spat.compute_batch body refers
|
||||
// to tensor values defined outside the batch.
|
||||
//
|
||||
// The real invariant should be stronger:
|
||||
//
|
||||
// A spat.compute_batch body must not capture external tensor values.
|
||||
// Every tensor used inside the body must be either:
|
||||
// - a compute_batch block argument,
|
||||
// - defined inside the compute_batch body,
|
||||
// - or a legal constant-like value.
|
||||
//
|
||||
// If this invariant is violated, the responsible producer, most likely merge
|
||||
// schedule materialization, should emit verifier-clean Spatial IR instead of
|
||||
// relying on SpatialToPim to clone external producer chains later.
|
||||
//
|
||||
// After that producer-side fix:
|
||||
// 1. remove isMaterializableExternalTensorOp,
|
||||
// 2. remove materializeExternalTensorValue,
|
||||
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
|
||||
// tensor operand,
|
||||
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
|
||||
// before SpatialToPim.
|
||||
//
|
||||
// Be careful not to replace every external tensor capture with a normal
|
||||
// compute_batch input blindly: host-backed tensors and explicit inter-core
|
||||
// communication have different semantics. In particular, channel_receive-like
|
||||
// values should be materialized through the communication model, not silently
|
||||
// treated as host inputs.
|
||||
// Resolves `value` for use in the rewritten body, cloning its producer chain
// at the rewriter's insertion point when necessary.
//
//  - Values already present in `mapper` resolve to their mapping.
//  - Non-tensor values are returned as-is.
//  - Tensor values fail unless their defining op exists, is not constant-like,
//    lives outside `oldBlock`, and is accepted by
//    isMaterializableExternalTensorOp.
//
// For an accepted producer, each operand is resolved recursively (operands
// that fail are left unmapped), the producer is cloned through `mapper`, and
// all of its results are recorded in `mapper`.
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
                                                       Location loc,
                                                       Block& oldBlock,
                                                       Value value,
                                                       IRMapping& mapper) {
  if (mapper.contains(value))
    return mapper.lookup(value);
  if (!isa<TensorType>(value.getType()))
    return value;

  Operation* producer = value.getDefiningOp();
  const bool isCloneable = producer && !producer->hasTrait<OpTrait::ConstantLike>()
                        && producer->getBlock() != &oldBlock && isMaterializableExternalTensorOp(producer);
  if (!isCloneable)
    return failure();

  // Resolve the producer's own operands first so the clone below picks them up.
  for (Value producerOperand : producer->getOperands()) {
    FailureOr<Value> resolved = materializeExternalTensorValue(rewriter, loc, oldBlock, producerOperand, mapper);
    if (failed(resolved))
      continue;
    mapper.map(producerOperand, *resolved);
  }

  Operation* clone = rewriter.clone(*producer, mapper);
  for (auto [oldResult, newResult] : llvm::zip(producer->getResults(), clone->getResults()))
    mapper.map(oldResult, newResult);

  return mapper.lookup(value);
}
|
||||
|
||||
// Returns one core id per lane of `computeBatchOp`. If the op carries the
// core-ids array attribute, its contents are returned verbatim. Otherwise
// consecutive ids are drawn from `fallbackCoreId`, which is advanced once for
// every id handed out; an id that does not pass pim::checkedI32 aborts with
// failure.
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
                                                               size_t& fallbackCoreId) {
  auto explicitIds = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  if (explicitIds) {
    auto ids = explicitIds.asArrayRef();
    return SmallVector<int32_t>(ids.begin(), ids.end());
  }

  SmallVector<int32_t> assignedIds;
  assignedIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
  uint32_t lane = 0;
  while (lane < computeBatchOp.getLaneCount()) {
    auto nextId =
        pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id");
    if (failed(nextId))
      return failure();
    assignedIds.push_back(*nextId);
    ++fallbackCoreId;
    ++lane;
  }
  return assignedIds;
}
|
||||
|
||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
if (!result.hasOneUse())
|
||||
return failure();
|
||||
@@ -386,7 +292,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||
}
|
||||
|
||||
auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
auto coreIds = getRequiredScheduledBatchCoreIds(computeBatchOp, "spatial compute_batch core id");
|
||||
if (failed(coreIds))
|
||||
return failure();
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
@@ -638,9 +544,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
continue;
|
||||
|
||||
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic =
|
||||
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
@@ -141,17 +142,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the core id for `computeOp`. An explicit core-id attribute wins and
// is range-checked with pim::checkedI32. Without it, `fallbackCoreId` is
// checked and used, and is advanced only when that check succeeds.
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
  auto explicitId = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (explicitId)
    return pim::checkedI32(explicitId.getInt(), computeOp, "spatial compute core id");

  auto fallbackId =
      pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeOp, "fallback spatial compute core id");
  if (failed(fallbackId))
    return failure();

  ++fallbackCoreId;
  return *fallbackId;
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
@@ -311,7 +301,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCom
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId);
|
||||
auto checkedCoreId = getRequiredScheduledCoreId(computeOp, "spatial compute core id");
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast<int64_t>(*checkedCoreId), "pim core id");
|
||||
|
||||
@@ -44,121 +44,29 @@ using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
// Returns a constant memref.global in the enclosing module whose type matches
// `tensorType` and whose initial value is all zeros. An existing matching
// global is reused. Otherwise a new private constant global named
// "__pim_zero_<rank>d_<numElements>" is created at the start of the module
// body, with "_<n>" appended while the name is already taken.
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
  auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));

  // Reuse an existing constant global with the same type and zero initializer.
  for (auto existing : moduleOp.getOps<memref::GlobalOp>()) {
    const bool isCandidate = existing.getConstant() && existing.getType() == memRefType && existing.getInitialValue();
    if (isCandidate && dyn_cast<DenseElementsAttr>(*existing.getInitialValue()) == zeroAttr)
      return existing;
  }

  const std::string baseName =
      ("__pim_zero_" + Twine(tensorType.getRank()) + "d_" + Twine(tensorType.getNumElements())).str();

  // Probe baseName, baseName_0, baseName_1, ... until an unused symbol is found.
  std::string symbolName = baseName;
  for (unsigned suffix = 0; SymbolTable::lookupSymbolIn(moduleOp, symbolName); ++suffix)
    symbolName = (baseName + "_" + Twine(suffix)).str();

  // The guard restores the caller's insertion point when this function returns.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  return memref::GlobalOp::create(rewriter,
                                  loc,
                                  rewriter.getStringAttr(symbolName),
                                  rewriter.getStringAttr("private"),
                                  TypeAttr::get(memRefType),
                                  zeroAttr,
                                  rewriter.getUnitAttr(),
                                  IntegerAttr {});
}
|
||||
|
||||
// Creates an empty tensor of `tensorType` and fills it through a
// PimMemCopyHostToDevOp whose source is the module's zero global (both
// offsets are the index constant 0). Fails if the byte size of `tensorType`
// cannot be computed or does not fit the checked i32 size attribute.
static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
                                                  Location loc,
                                                  RankedTensorType tensorType,
                                                  OperationFolder& constantFolder) {
  auto deviceBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
  auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
  auto hostZeros = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
  auto zeroOffset = getOrCreateIndexConstant(constantFolder, deviceBuffer.getOperation(), 0);

  auto copyBytes =
      pim::getCheckedShapedTypeSizeInBytes(tensorType, deviceBuffer.getOperation(), "host-to-device zero copy byte size");
  if (failed(copyBytes))
    return failure();

  auto copyBytesAttr =
      pim::getCheckedI32Attr(rewriter, deviceBuffer.getOperation(), *copyBytes, "host-to-device zero copy byte size");
  if (failed(copyBytesAttr))
    return failure();

  auto copyOp = PimMemCopyHostToDevOp::create(
      rewriter, loc, tensorType, zeroOffset, zeroOffset, deviceBuffer, hostZeros, *copyBytesAttr);
  return copyOp.getOutput();
}
|
||||
|
||||
/// Returns true when `value`, after looking through any chain of
/// `memref.subview`, `memref.cast`, `memref.collapse_shape` and
/// `memref.expand_shape` producers, is defined by a `memref.get_global`.
///
/// Values without a defining operation (e.g. block arguments) and values
/// produced by any other operation yield false.
static bool isHostBackedMemRefValue(Value value) {
  for (Operation* producer = value.getDefiningOp(); producer; producer = value.getDefiningOp()) {
    Value underlying;
    if (auto subview = dyn_cast<memref::SubViewOp>(producer))
      underlying = subview.getSource();
    else if (auto cast = dyn_cast<memref::CastOp>(producer))
      underlying = cast.getSource();
    else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(producer))
      underlying = collapse.getSrc();
    else if (auto expand = dyn_cast<memref::ExpandShapeOp>(producer))
      underlying = expand.getSrc();
    else
      return isa<memref::GetGlobalOp>(producer); // First non-view producer decides.

    // Keep peeling view-like operations.
    value = underlying;
  }
  return false;
}
|
||||
|
||||
/// Returns true when `value` traces back, through tensor view-like operations,
/// to a `bufferization.to_tensor` whose buffer satisfies
/// `isHostBackedMemRefValue`.
///
/// Looked-through producers: `tensor.extract_slice` (only when source and
/// result are statically shaped ranked tensors and the slice passes
/// `isContiguousSubviewWithDynamicOffsets`), `tensor.collapse_shape`,
/// `tensor.expand_shape` and `tensor.cast`. Anything else yields false.
static bool isHostBackedTensorValue(Value value) {
  for (Operation* producer = value.getDefiningOp(); producer; producer = value.getDefiningOp()) {
    if (auto toTensor = dyn_cast<bufferization::ToTensorOp>(producer))
      return isHostBackedMemRefValue(toTensor.getBuffer());

    Value underlying;
    if (auto slice = dyn_cast<tensor::ExtractSliceOp>(producer)) {
      auto srcType = dyn_cast<RankedTensorType>(slice.getSource().getType());
      auto dstType = dyn_cast<RankedTensorType>(slice.getResult().getType());
      const bool staticShapes = srcType && dstType && srcType.hasStaticShape() && dstType.hasStaticShape();
      if (!staticShapes)
        return false;

      const bool contiguous = onnx_mlir::isContiguousSubviewWithDynamicOffsets(
          srcType.getShape(), slice.getMixedOffsets(), slice.getStaticSizes(), slice.getStaticStrides());
      if (!contiguous)
        return false;

      underlying = slice.getSource();
    }
    else if (auto collapse = dyn_cast<tensor::CollapseShapeOp>(producer)) {
      underlying = collapse.getSrc();
    }
    else if (auto expand = dyn_cast<tensor::ExpandShapeOp>(producer)) {
      underlying = expand.getSrc();
    }
    else if (auto cast = dyn_cast<tensor::CastOp>(producer)) {
      underlying = cast.getSource();
    }
    else {
      return false;
    }

    value = underlying;
  }
  return false;
}
|
||||
|
||||
static FailureOr<Value>
|
||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||
createZeroPaddedTensor(IRRewriter& rewriter, Location loc, Value value, RankedTensorType resultType) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||
@@ -169,26 +77,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
if (failed(zeroed))
|
||||
return failure();
|
||||
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0);
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
if (isHostBackedTensorValue(vector)) {
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
||||
return createZeroPaddedTensor(rewriter, loc, vector, paddedType);
|
||||
}
|
||||
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
coreId = 0;
|
||||
outputTensors.clear();
|
||||
operationsToRemove.clear();
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -362,7 +254,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
bool hasFailure = false;
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
@@ -371,7 +262,7 @@ LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func:
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||
if (failed(paddedInput)) {
|
||||
hasFailure = true;
|
||||
return WalkResult::interrupt();
|
||||
|
||||
@@ -36,7 +36,6 @@ private:
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
llvm::SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
@@ -8,7 +8,9 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsCanonicalization.cpp
|
||||
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/HostOutputFinalization.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/ProjectedFragments.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
#include "HostOutputFinalization.hpp"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "MaterializedClassState.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
/// Assembles the pending projected host-output fragments stored in `state`
/// into one `SpatBlueprintOp` per logical host output.
///
/// Fragments are grouped by `originalOutput`. Every group key must be an
/// operand of the function's `func.return`; groups are processed in return
/// operand order. Within a group the fragments are sorted by
/// (sourceClass, publicationResultIndex, sourceFragmentOrdinal, offsets) so the
/// emitted attributes are deterministic. The blueprint is inserted right before
/// the return and recorded in `state.hostReplacements[originalOutput]`.
///
/// Returns success when there is nothing pending or all blueprints were
/// created; otherwise emits a diagnostic and returns failure.
LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
  if (state.pendingProjectedHostOutputFragments.empty())
    return success();

  // Group the pending fragments by the logical output they contribute to.
  DenseMap<Value, SmallVector<PendingProjectedHostOutputFragment*, 16>> byOutput;
  for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments)
    byOutput[fragment.originalOutput].push_back(&fragment);

  SmallVector<Value, 8> outputs;
  outputs.reserve(byOutput.size());

  auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
  if (!returnOp)
    return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");

  // Visit outputs in return-operand order (DenseMap iteration order is not stable),
  // skipping operands without fragments and duplicates.
  DenseSet<Value> seenOutputs;
  for (Value returned : returnOp.getOperands()) {
    if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second)
      continue;
    outputs.push_back(returned);
  }
  if (outputs.size() != byOutput.size())
    return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs");

  for (Value originalOutput : outputs) {
    if (isa_and_present<SpatScheduledCompute, SpatScheduledComputeBatch>(originalOutput.getDefiningOp())) {
      return state.func.emitError(
          "projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result");
    }

    auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
    if (!resultType || !resultType.hasStaticShape())
      return state.func.emitError("projected host output must have static ranked tensor type");

    SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
    llvm::sort(fragments,
               [](const PendingProjectedHostOutputFragment* lhs, const PendingProjectedHostOutputFragment* rhs) {
                 if (lhs->sourceClass != rhs->sourceClass)
                   return lhs->sourceClass < rhs->sourceClass;
                 if (lhs->publicationResultIndex != rhs->publicationResultIndex)
                   return lhs->publicationResultIndex < rhs->publicationResultIndex;
                 if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
                   return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
                 return std::lexicographical_compare(
                     lhs->offsets.begin(), lhs->offsets.end(), rhs->offsets.begin(), rhs->offsets.end());
               });

    state.rewriter.setInsertionPoint(returnOp);
    Location loc = fragments.front()->loc;
    SmallVector<Value, 16> blueprintOperands;
    SmallVector<int64_t, 16> fragmentOperandIndices;
    SmallVector<int64_t, 16> fragmentSourceOffsets;
    SmallVector<int64_t, 64> flatOffsets;
    SmallVector<int64_t, 64> flatSizes;
    SmallVector<int64_t, 64> flatStrides;
    // Deduplicates operands: several fragments may read the same publication result.
    DenseMap<Value, int64_t> operandIndicesByValue;

    for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
      if (fragmentRecord->sourceClass >= state.classes.size())
        return state.func.emitError("projected host output fragment references an invalid source class");

      MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
      // `MaterializedClass::op` defaults to nullptr; report it instead of dereferencing.
      if (!sourceClass.op) {
        return state.func.emitError(
                   "projected host output fragment references a source class without a materialized operation")
            << " sourceClass=" << sourceClass.id;
      }
      if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) {
        return sourceClass.op->emitError("projected host output fragment references an invalid publication result")
            << " sourceClass=" << sourceClass.id
            << " resultIndex=" << fragmentRecord->publicationResultIndex
            << " resultCount=" << sourceClass.op->getNumResults();
      }

      Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex);
      auto [operandIt, inserted] =
          operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(blueprintOperands.size()));
      if (inserted)
        blueprintOperands.push_back(operand);
      fragmentOperandIndices.push_back(operandIt->second);
      fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
      llvm::append_range(flatOffsets, fragmentRecord->offsets);
      llvm::append_range(flatSizes, fragmentRecord->sizes);
      llvm::append_range(flatStrides, fragmentRecord->strides);

      auto operandType = dyn_cast<RankedTensorType>(operand.getType());
      if (!operandType || !operandType.hasStaticShape())
        return state.func.emitError("projected host output assembly requires static ranked tensor operands");
    }

    if (blueprintOperands.empty())
      return state.func.emitError("missing projected host output fragments");

    // The first operand is passed as the blueprint input, the rest as extra fragments.
    Value input = blueprintOperands.front();
    ValueRange extraFragments = ValueRange(blueprintOperands).drop_front();
    auto blueprint = SpatBlueprintOp::create(state.rewriter,
                                             loc,
                                             resultType,
                                             input,
                                             extraFragments,
                                             state.rewriter.getStringAttr("nchw"),
                                             state.rewriter.getStringAttr("fragmented"),
                                             state.rewriter.getDenseI64ArrayAttr(flatOffsets),
                                             state.rewriter.getDenseI64ArrayAttr(flatSizes),
                                             state.rewriter.getStringAttr("identity"),
                                             state.rewriter.getStringAttr("fragment_assembly"),
                                             state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices),
                                             state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets),
                                             state.rewriter.getDenseI64ArrayAttr(flatStrides),
                                             state.rewriter.getStringAttr("disjoint"),
                                             state.rewriter.getStringAttr("complete"));

    state.hostReplacements[originalOutput] = blueprint.getOutput();
  }

  return success();
}
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
struct MaterializerState;
|
||||
|
||||
/// Turns the fragments queued in `state.pendingProjectedHostOutputFragments`
/// into one blueprint op per returned host output and records each result in
/// `state.hostReplacements`. Succeeds trivially when nothing is pending;
/// emits a diagnostic and returns failure on inconsistent fragment records.
mlir::LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state);
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
+122
-791
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,252 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "MergeMessages.hpp"
|
||||
#include "MergeScheduleKeys.hpp"
|
||||
#include "ProjectedFragments.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
struct MaterializedClass {
|
||||
ClassId id = 0;
|
||||
llvm::SmallVector<CpuId, 8> cpus;
|
||||
mlir::Operation* op = nullptr;
|
||||
mlir::Block* body = nullptr;
|
||||
bool isBatch = false;
|
||||
|
||||
llvm::DenseMap<CpuId, unsigned> cpuToLane;
|
||||
llvm::SmallVector<mlir::Value, 8> weights;
|
||||
llvm::SmallVector<mlir::Value, 8> inputs;
|
||||
llvm::SmallVector<mlir::Value, 4> hostOutputs;
|
||||
llvm::DenseMap<mlir::Value, unsigned> publicationOutputToResultIndex;
|
||||
llvm::DenseMap<mlir::Value, mlir::BlockArgument> weightArgs;
|
||||
llvm::DenseMap<mlir::Value, mlir::BlockArgument> inputArgs;
|
||||
llvm::DenseMap<mlir::Value, unsigned> hostOutputToResultIndex;
|
||||
};
|
||||
|
||||
struct PackedScalarRunSlot {
|
||||
llvm::SmallVector<ProducerKey, 8> keys;
|
||||
};
|
||||
|
||||
/// State of a packed scalar run value.
/// NOTE(review): enumerator meanings are inferred from their names; confirm.
enum class PackedScalarRunKind {
  Materialized,        ///< Default kind used by PackedScalarRunValue.
  DeferredReceive,
  DeferredLocalCompute
};
|
||||
|
||||
struct PackedScalarRunValue {
|
||||
ClassId targetClass = 0;
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
PackedScalarRunKind kind = PackedScalarRunKind::Materialized;
|
||||
|
||||
mlir::Value packed;
|
||||
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<PackedScalarRunSlot, 8> slots;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
struct IndexedBatchRunValue {
|
||||
ClassId targetClass = 0;
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
mlir::Value packed;
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<PackedScalarRunSlot, 8> slots;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
struct LogicalSlotRange {
|
||||
SlotId start = 0;
|
||||
SlotId count = 0;
|
||||
};
|
||||
|
||||
struct MaterializationRunSlot {
|
||||
llvm::SmallVector<ComputeInstance, 8> peers;
|
||||
};
|
||||
|
||||
using MaterializationRun = llvm::SmallVector<MaterializationRunSlot, 8>;
|
||||
|
||||
struct OutputDestinationGroup {
|
||||
llvm::SmallVector<size_t, 4> resultIndices;
|
||||
llvm::SmallVector<ClassId, 4> destinationClasses;
|
||||
};
|
||||
|
||||
struct BatchRunSendPlan {
|
||||
size_t resultIndex = 0;
|
||||
ClassId destinationClass = 0;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
/// Kind tag of a TensorDemandAction.
/// NOTE(review): enumerator meanings are inferred from their names; confirm.
enum class TensorDemandActionKind {
  DestinationFanout,            ///< Default kind used by TensorDemandAction.
  SameClassIndexedFragment,
  TerminalBlueprintPublication,
  WholeTensorBarrier
};
|
||||
|
||||
/// Reason attached to a whole-tensor barrier action.
/// NOTE(review): enumerator meanings are inferred from their names; confirm.
enum class WholeTensorBarrierReason {
  FunctionReturnWithoutBlueprint,
  DenseLogicalConsumer
};
|
||||
|
||||
struct TensorDemandAction {
|
||||
TensorDemandActionKind kind = TensorDemandActionKind::DestinationFanout;
|
||||
std::optional<ClassId> destinationClass;
|
||||
std::optional<WholeTensorBarrierReason> barrierReason;
|
||||
};
|
||||
|
||||
struct RunOutputDemand {
|
||||
size_t resultIndex = 0;
|
||||
mlir::Value originalOutput;
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<TensorDemandAction, 4> actions;
|
||||
};
|
||||
|
||||
struct CompactRunPlan {
|
||||
llvm::SmallVector<RunOutputDemand, 4> outputs;
|
||||
};
|
||||
|
||||
/// Kind tag of a BatchInputDemand.
/// NOTE(review): enumerator meanings are inferred from their names; confirm.
enum class BatchInputDemandKind {
  LaneFragment,       ///< Default kind used by BatchInputDemand.
  ProjectedFragment,
  WholeTensorBarrier
};
|
||||
|
||||
struct BatchInputDemand {
|
||||
BatchInputDemandKind kind = BatchInputDemandKind::LaneFragment;
|
||||
std::optional<ProducerKey> wholeTensorProducer;
|
||||
};
|
||||
|
||||
struct CloneIndexingContext {
|
||||
std::optional<mlir::Value> runSlotIndex;
|
||||
std::optional<mlir::Value> projectionSlotIndex;
|
||||
};
|
||||
|
||||
struct MaterializerState;
|
||||
|
||||
class AvailableValueStore {
|
||||
public:
|
||||
struct ExactBatchFragmentRecord {
|
||||
ProducerKey key;
|
||||
mlir::Value value;
|
||||
};
|
||||
|
||||
void record(ProducerKey key, ClassId classId, mlir::Value value) {
|
||||
exactValues[key][classId] = value;
|
||||
|
||||
auto batch = mlir::dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||
if (!batch || key.instance.laneCount == 0)
|
||||
return;
|
||||
|
||||
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
|
||||
llvm::SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
|
||||
for (ExactBatchFragmentRecord& record : bucket) {
|
||||
if (!(record.key == key))
|
||||
continue;
|
||||
record.value = value;
|
||||
return;
|
||||
}
|
||||
bucket.push_back({key, value});
|
||||
}
|
||||
|
||||
void recordPackedRun(PackedScalarRunValue run) {
|
||||
size_t runIndex = packedScalarRuns.size();
|
||||
packedScalarRuns.push_back(std::move(run));
|
||||
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
|
||||
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
|
||||
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
|
||||
}
|
||||
|
||||
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
|
||||
|
||||
std::optional<mlir::Value> lookupExact(ProducerKey key, ClassId classId) const;
|
||||
std::optional<mlir::Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
|
||||
|
||||
llvm::ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = packedRunsByProducerResultClass.find(key);
|
||||
if (it == packedRunsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
llvm::ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = exactBatchFragmentsByProducerResultClass.find(key);
|
||||
if (it == exactBatchFragmentsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
|
||||
|
||||
private:
|
||||
std::optional<mlir::Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
|
||||
llvm::DenseMap<ProducerKey, llvm::DenseMap<ClassId, mlir::Value>, ProducerKeyInfo> exactValues;
|
||||
llvm::SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
|
||||
llvm::SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
|
||||
llvm::DenseMap<WholeBatchAssemblyLookupKey,
|
||||
llvm::SmallVector<ExactBatchFragmentRecord, 16>,
|
||||
WholeBatchAssemblyLookupKeyInfo>
|
||||
exactBatchFragmentsByProducerResultClass;
|
||||
llvm::DenseMap<WholeBatchAssemblyLookupKey, llvm::SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||
packedRunsByProducerResultClass;
|
||||
};
|
||||
|
||||
struct MaterializerState {
|
||||
mlir::func::FuncOp func;
|
||||
const MergeScheduleResult& schedule;
|
||||
mlir::IRRewriter rewriter;
|
||||
mlir::OperationFolder constantFolder;
|
||||
int64_t& nextChannelId;
|
||||
llvm::SmallVector<MaterializedClass, 8> classes;
|
||||
llvm::DenseMap<CpuId, ClassId> cpuToClass;
|
||||
llvm::DenseMap<CpuId, llvm::SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
|
||||
llvm::DenseMap<ComputeInstance, LogicalSlotRange> scheduledInstanceToLogicalSlots;
|
||||
llvm::DenseMap<ComputeInstance, ComputeInstance> logicalInstanceToScheduledChunk;
|
||||
llvm::DenseSet<ClassSlotKey> materializedLogicalSlots;
|
||||
|
||||
llvm::DenseMap<ProducerKey, llvm::SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||
llvm::DenseMap<SameClassConsumerLookupKey, llvm::SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
|
||||
sameClassConsumerIndex;
|
||||
llvm::DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo>
|
||||
projectedInputMatches;
|
||||
llvm::DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
|
||||
llvm::DenseMap<mlir::Value, bool> liveExternalUseCache;
|
||||
llvm::DenseMap<mlir::Operation*, llvm::SmallVector<mlir::Type, 4>> batchOutputFragmentTypesCache;
|
||||
llvm::DenseMap<ComputeInstance, llvm::SmallVector<mlir::Value, 4>, llvm::DenseMapInfo<ComputeInstance>>
|
||||
computeInstanceOutputsCache;
|
||||
llvm::DenseMap<ProducerKey, llvm::DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo>
|
||||
projectedTransfers;
|
||||
llvm::DenseMap<mlir::Operation*, llvm::DenseMap<ClassId, ProjectedExtractReplacement>>
|
||||
projectedExtractReplacements;
|
||||
AvailableValueStore availableValues;
|
||||
llvm::DenseMap<mlir::Value, mlir::Value> hostReplacements;
|
||||
llvm::DenseMap<mlir::Value, ClassId> hostOutputOwners;
|
||||
llvm::SmallVector<PendingProjectedHostOutputFragment, 32> pendingProjectedHostOutputFragments;
|
||||
llvm::DenseSet<mlir::Operation*> oldComputeOps;
|
||||
|
||||
MaterializerState(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
|
||||
: func(func),
|
||||
schedule(schedule),
|
||||
rewriter(func.getContext()),
|
||||
constantFolder(func.getContext()),
|
||||
nextChannelId(nextChannelId) {}
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -28,6 +28,7 @@
|
||||
#include "Scheduling/ComputeGraph.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
@@ -43,16 +44,6 @@ using namespace onnx_mlir::compact_asm;
|
||||
using SpatCompute = spatial::SpatGraphCompute;
|
||||
using SpatComputeBatch = spatial::SpatGraphComputeBatch;
|
||||
|
||||
/// Reads the integer core-id attribute (`kCoreIdAttrName`) of `compute`.
/// Returns std::nullopt both when the attribute is absent and when the value
/// fails the checked conversion to int32_t.
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
  auto attr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!attr)
    return std::nullopt;

  auto narrowed = pim::checkedI32(attr.getInt(), compute, "merge compute core id");
  if (failed(narrowed))
    return std::nullopt;
  return *narrowed;
}
|
||||
|
||||
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
||||
if (!compute->hasOneUse())
|
||||
return false;
|
||||
@@ -213,8 +204,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreId = getComputeCoreId(spatCompute))
|
||||
coreIds.push_back(*coreId);
|
||||
auto coreId = getOptionalScheduledCoreId(spatCompute, "merge compute core id");
|
||||
if (failed(coreId))
|
||||
return;
|
||||
if (*coreId)
|
||||
coreIds.push_back(**coreId);
|
||||
uint64_t computeId = totalComputeOps++;
|
||||
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||
uint64_t maxConcatOperands = 0;
|
||||
@@ -234,8 +228,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
|
||||
auto optionalCoreIds = getOptionalScheduledBatchCoreIds(batch, "merge compute_batch core id");
|
||||
if (failed(optionalCoreIds))
|
||||
return;
|
||||
if (*optionalCoreIds)
|
||||
coreIds = std::move(**optionalCoreIds);
|
||||
collectedData.push_back(
|
||||
{nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
|
||||
totalComputeOps += 1;
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
using CpuId = size_t;
|
||||
|
||||
/// Converts `cpu` to an int32_t core id through `pim::checkedI32`, passing
/// `anchor` and `fieldName` along to the check.
inline mlir::FailureOr<int32_t> getCheckedCoreId(mlir::Operation* anchor, CpuId cpu, llvm::StringRef fieldName) {
  const auto widenedCpu = static_cast<uint64_t>(cpu);
  return pim::checkedI32(widenedCpu, anchor, fieldName);
}
|
||||
|
||||
/// Converts every entry of `cpus` with `getCheckedCoreId`, preserving order.
/// Fails as soon as one conversion fails.
inline mlir::FailureOr<llvm::SmallVector<int32_t, 8>>
getCheckedCoreIds(mlir::Operation* anchor, llvm::ArrayRef<CpuId> cpus, llvm::StringRef fieldName) {
  llvm::SmallVector<int32_t, 8> converted;
  converted.reserve(cpus.size());

  for (size_t position = 0, total = cpus.size(); position != total; ++position) {
    mlir::FailureOr<int32_t> coreId = getCheckedCoreId(anchor, cpus[position], fieldName);
    if (mlir::failed(coreId))
      return mlir::failure();
    converted.push_back(*coreId);
  }
  return converted;
}
|
||||
|
||||
struct MessageVector {
|
||||
llvm::SmallVector<int64_t, 16> channelIds;
|
||||
llvm::SmallVector<int32_t, 16> sourceCoreIds;
|
||||
llvm::SmallVector<int32_t, 16> targetCoreIds;
|
||||
|
||||
size_t size() const { return channelIds.size(); }
|
||||
bool empty() const { return channelIds.empty(); }
|
||||
|
||||
mlir::LogicalResult verify(mlir::Operation* anchor) const {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return anchor->emitError("message metadata is inconsistent");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) {
|
||||
channelIds.push_back(channelId);
|
||||
sourceCoreIds.push_back(sourceCoreId);
|
||||
targetCoreIds.push_back(targetCoreId);
|
||||
}
|
||||
|
||||
void append(llvm::ArrayRef<int64_t> channels, llvm::ArrayRef<int32_t> sources, llvm::ArrayRef<int32_t> targets) {
|
||||
assert(channels.size() == sources.size() && "channel/source count mismatch");
|
||||
assert(channels.size() == targets.size() && "channel/target count mismatch");
|
||||
llvm::append_range(channelIds, channels);
|
||||
llvm::append_range(sourceCoreIds, sources);
|
||||
llvm::append_range(targetCoreIds, targets);
|
||||
}
|
||||
|
||||
MessageVector slice(size_t offset, size_t count) const {
|
||||
MessageVector result;
|
||||
result.append(llvm::ArrayRef<int64_t>(channelIds).slice(offset, count),
|
||||
llvm::ArrayRef<int32_t>(sourceCoreIds).slice(offset, count),
|
||||
llvm::ArrayRef<int32_t>(targetCoreIds).slice(offset, count));
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,134 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
using ClassId = size_t;
|
||||
using SlotId = size_t;
|
||||
|
||||
struct ProducerKey {
|
||||
ComputeInstance instance;
|
||||
size_t resultIndex = 0;
|
||||
|
||||
bool operator==(const ProducerKey& other) const {
|
||||
return instance == other.instance && resultIndex == other.resultIndex;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProducerKeyInfo {
|
||||
static ProducerKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<ComputeInstance>::getEmptyKey(), std::numeric_limits<size_t>::max()};
|
||||
}
|
||||
|
||||
static ProducerKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<ComputeInstance>::getTombstoneKey(), std::numeric_limits<size_t>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ProducerKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<ComputeInstance>::getHashValue(key.instance), key.resultIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKey {
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const SameClassConsumerLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKeyInfo {
|
||||
static SameClassConsumerLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static SameClassConsumerLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<mlir::Operation*>::getHashValue(key.sourceOp),
|
||||
key.resultIndex,
|
||||
key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKey {
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const WholeBatchAssemblyLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKeyInfo {
|
||||
static WholeBatchAssemblyLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static WholeBatchAssemblyLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<mlir::Operation*>::getHashValue(key.sourceOp),
|
||||
key.resultIndex,
|
||||
key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
using ClassSlotKey = std::pair<ClassId, SlotId>;
|
||||
|
||||
struct ProjectedBatchInputKey {
|
||||
mlir::Operation* consumerOp = nullptr;
|
||||
unsigned inputIndex = 0;
|
||||
|
||||
bool operator==(const ProjectedBatchInputKey& other) const {
|
||||
return consumerOp == other.consumerOp && inputIndex == other.inputIndex;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProjectedBatchInputKeyInfo {
|
||||
static ProjectedBatchInputKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static ProjectedBatchInputKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ProjectedBatchInputKey& key) {
|
||||
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,104 @@
|
||||
#include "ProjectedFragments.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
/// Builds the tensor type obtained by scaling dimension 0 of `laneType` by `laneCount`.
///
/// The element type and encoding of `laneType` are preserved.
///
/// Fails when `laneType` is null, is not a ranked tensor, has a non-static shape, or has rank 0
/// (there is no leading dimension to scale).
static mlir::FailureOr<mlir::RankedTensorType> getPackedBatchTensorType(mlir::Type laneType, size_t laneCount) {
  // dyn_cast must not be applied to a null Type (it asserts in debug builds), so reject it up front.
  if (!laneType)
    return mlir::failure();

  auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(laneType);
  if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
    return mlir::failure();

  llvm::SmallVector<int64_t, 4> shape(tensorType.getShape());
  shape[0] *= static_cast<int64_t>(laneCount);
  return mlir::RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding());
}
|
||||
|
||||
/// Returns the product of all loop trip counts (1 when `loopTripCounts` is empty).
///
/// Every trip count must be positive and the running product must fit in `unsigned`. Both
/// preconditions are enforced with assertions because the return type cannot report failure.
unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef<int64_t> loopTripCounts) {
  unsigned fragmentsPerLogicalSlot = 1;
  for (int64_t tripCount : loopTripCounts) {
    assert(tripCount > 0 && "projected loop trip counts must be positive");
    // Without this check the narrowing cast and the multiplication below would silently
    // truncate or wrap around.
    assert(static_cast<uint64_t>(tripCount) <= ~0u / fragmentsPerLogicalSlot
           && "projected fragment count does not fit in unsigned");
    fragmentsPerLogicalSlot *= static_cast<unsigned>(tripCount);
  }
  return fragmentsPerLogicalSlot;
}
|
||||
|
||||
/// Checks the internal consistency of a projected fragment layout.
///
/// Emits an error on `anchor` and fails when the fragment type or shape is missing, when the shape
/// rank differs from the type rank, when either fragment count is zero, or when the payload
/// fragment count is not a multiple of the fragments per logical slot.
mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout) {
  const bool hasTypeMetadata = layout.fragmentType && !layout.fragmentShape.empty();
  if (!hasTypeMetadata)
    return anchor->emitError("projected fragment layout is missing fragment type metadata");

  // Safe to query the rank only after the null check above.
  const auto typeRank = static_cast<size_t>(layout.fragmentType.getRank());
  if (layout.fragmentShape.size() != typeRank)
    return anchor->emitError("projected fragment layout rank does not match fragment type");

  const unsigned payloadCount = layout.payloadFragmentCount;
  const unsigned perSlotCount = layout.fragmentsPerLogicalSlot;
  if (payloadCount == 0 || perSlotCount == 0)
    return anchor->emitError("projected fragment layout has an invalid fragment count");

  if (payloadCount % perSlotCount != 0)
    return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots");

  return mlir::success();
}
|
||||
|
||||
/// Derives the payload tensor type for `payloadFragmentCount` fragments of `fragmentType`.
///
/// Delegates to getPackedBatchTensorType; when that fails, an error is emitted on `anchor` and
/// failure is returned.
mlir::FailureOr<mlir::RankedTensorType>
getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount) {
  mlir::FailureOr<mlir::RankedTensorType> packed = getPackedBatchTensorType(fragmentType, payloadFragmentCount);
  if (mlir::succeeded(packed))
    return *packed;

  anchor->emitError("cannot create projected payload type");
  return mlir::failure();
}
|
||||
|
||||
/// Transposes per-fragment offsets into per-dimension offset lists.
///
/// The result has `rank` entries; entry `d` holds the offset along dimension `d` of every
/// fragment, in fragment order. Each fragment must provide exactly `rank` offsets (asserted).
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4>
buildProjectedFragmentOffsetsByDim(llvm::ArrayRef<llvm::SmallVector<int64_t, 4>> fragmentOffsets, size_t rank) {
  llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4> byDim(rank);
  for (const llvm::SmallVector<int64_t, 4>& fragment : fragmentOffsets) {
    assert(fragment.size() == rank && "projected offset rank mismatch");
    size_t dim = 0;
    for (llvm::SmallVector<int64_t, 16>& column : byDim)
      column.push_back(fragment[dim++]);
  }
  return byDim;
}
|
||||
|
||||
/// Checks a projected transfer descriptor for internal consistency.
///
/// After verifying the embedded layout, this requires a payload type, at least one fragment
/// offset, one dimension-major offset list per fragment dimension with one entry per fragment,
/// and one offset per fragment dimension in every fragment. Errors are emitted on `anchor`.
mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor,
                                                      const ProjectedTransferDescriptor& descriptor) {
  if (mlir::failed(verifyProjectedFragmentLayout(anchor, descriptor.layout)))
    return mlir::failure();

  if (!descriptor.payloadType)
    return anchor->emitError("projected transfer descriptor is missing payload type");

  const auto fragmentCount = descriptor.fragmentOffsets.size();
  if (fragmentCount == 0)
    return anchor->emitError("projected transfer descriptor expected at least one fragment offset");

  const auto fragmentRank = descriptor.layout.fragmentShape.size();
  if (descriptor.fragmentOffsetsByDim.size() != fragmentRank)
    return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");

  for (const auto& dimOffsets : descriptor.fragmentOffsetsByDim) {
    if (dimOffsets.size() != fragmentCount)
      return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
  }

  for (const auto& offsets : descriptor.fragmentOffsets) {
    if (offsets.size() != fragmentRank)
      return anchor->emitError("projected transfer offset rank does not match fragment rank");
  }

  return mlir::success();
}
|
||||
|
||||
/// Verifies a transfer descriptor that is used for sending `messages`.
///
/// In addition to the generic descriptor checks, the number of fragment offsets must equal the
/// number of messages multiplied by the layout's payload fragment count.
mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor,
                                                  const ProjectedTransferDescriptor& descriptor,
                                                  const MessageVector& messages) {
  if (mlir::failed(verifyProjectedTransferDescriptor(anchor, descriptor)))
    return mlir::failure();

  const auto expectedOffsetCount = messages.size() * descriptor.layout.payloadFragmentCount;
  if (expectedOffsetCount == descriptor.fragmentOffsets.size())
    return mlir::success();

  return anchor->emitError("projected send descriptor metadata is inconsistent");
}
|
||||
|
||||
/// Fills in the derived fields of `descriptor` and verifies the result.
///
/// Recomputes `fragmentOffsetsByDim` from `fragmentOffsets` and derives `payloadType` from the
/// layout. A payload type that is already set must match the derived one. Errors are emitted on
/// `anchor`.
///
/// The layout and the per-fragment offset ranks are validated before they are used, so a malformed
/// descriptor is reported as an error instead of tripping an assertion (or indexing out of bounds
/// in release builds) while the derived fields are computed. The descriptor is left untouched when
/// these up-front checks fail.
mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor,
                                                        ProjectedTransferDescriptor& descriptor) {
  // The payload type derivation below needs a non-null fragment type and a sane fragment count.
  if (mlir::failed(verifyProjectedFragmentLayout(anchor, descriptor.layout)))
    return mlir::failure();

  // The transpose below indexes every fragment with [0, rank), so the ranks must match first.
  const size_t rank = descriptor.layout.fragmentShape.size();
  for (llvm::ArrayRef<int64_t> offsets : descriptor.fragmentOffsets)
    if (offsets.size() != rank)
      return anchor->emitError("projected transfer offset rank does not match fragment rank");

  descriptor.fragmentOffsetsByDim = buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, rank);

  auto payloadType =
      getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount);
  if (mlir::failed(payloadType))
    return mlir::failure();
  if (descriptor.payloadType && descriptor.payloadType != *payloadType)
    return anchor->emitError("projected transfer descriptor payload type does not match projected layout");
  descriptor.payloadType = *payloadType;

  return verifyProjectedTransferDescriptor(anchor, descriptor);
}
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,87 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "MergeMessages.hpp"
|
||||
#include "MergeScheduleKeys.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
struct ProjectedFragmentLayout {
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<int64_t, 4> fragmentShape;
|
||||
unsigned fragmentsPerLogicalSlot = 1;
|
||||
unsigned payloadFragmentCount = 1;
|
||||
llvm::SmallVector<int64_t, 4> loopLowerBounds;
|
||||
llvm::SmallVector<int64_t, 4> loopSteps;
|
||||
llvm::SmallVector<int64_t, 4> loopTripCounts;
|
||||
};
|
||||
|
||||
struct StaticProjectedLoopInfo {
|
||||
mlir::BlockArgument iv;
|
||||
int64_t lowerBound = 0;
|
||||
int64_t step = 1;
|
||||
int64_t tripCount = 1;
|
||||
};
|
||||
|
||||
struct ProjectedTransferDescriptor {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
mlir::Operation* extractOp = nullptr;
|
||||
ProjectedFragmentLayout layout;
|
||||
mlir::RankedTensorType payloadType;
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 16> fragmentOffsets;
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4> fragmentOffsetsByDim;
|
||||
};
|
||||
|
||||
/// Pairs a payload value with the fragment layout that describes it.
/// NOTE(review): presumably the value that stands in for a projected extract — confirm against users.
struct ProjectedExtractReplacement {
|
||||
mlir::Value payload;
|
||||
ProjectedFragmentLayout layout;
|
||||
};
|
||||
|
||||
struct PendingProjectedHostOutputFragment {
|
||||
mlir::Value originalOutput;
|
||||
ClassId sourceClass = 0;
|
||||
ProducerKey producerKey;
|
||||
unsigned publicationResultIndex = 0;
|
||||
int64_t sourceFragmentOrdinal = 0;
|
||||
int64_t sourceElementOffset = 0;
|
||||
llvm::SmallVector<int64_t, 4> offsets;
|
||||
llvm::SmallVector<int64_t, 4> sizes;
|
||||
llvm::SmallVector<int64_t, 4> strides;
|
||||
uint32_t sourceLane = 0;
|
||||
mlir::Location loc;
|
||||
};
|
||||
|
||||
/// Result record of matching a tensor.extract_slice together with its source/fragment types,
/// fragment shape, slice offsets, and the static loops involved.
/// NOTE(review): presumably `loops` lists the loops whose induction variables feed `offsets` — confirm.
struct AffineProjectedInputSliceMatch {
|
||||
mlir::tensor::ExtractSliceOp extract;
|
||||
mlir::RankedTensorType sourceType;
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<int64_t, 4> fragmentShape;
|
||||
llvm::SmallVector<mlir::OpFoldResult, 4> offsets;
|
||||
llvm::SmallVector<StaticProjectedLoopInfo, 4> loops;
|
||||
};
|
||||
|
||||
/// Returns the product of all trip counts (1 for an empty list); each trip count is asserted to be positive.
unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef<int64_t> loopTripCounts);
|
||||
/// Fails with an error on `anchor` when the layout lacks a fragment type or shape, the shape rank differs
/// from the type rank, a fragment count is zero, or the payload fragment count is not a multiple of the
/// fragments per logical slot.
mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout);
|
||||
/// Returns `fragmentType` with dimension 0 multiplied by `payloadFragmentCount`; emits an error on `anchor`
/// and fails when the fragment type is not a static-shaped ranked tensor of rank >= 1.
mlir::FailureOr<mlir::RankedTensorType>
|
||||
getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount);
|
||||
/// Transposes fragment-major offsets into `rank` dimension-major lists; every fragment is asserted to
/// provide exactly `rank` offsets.
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4>
|
||||
buildProjectedFragmentOffsetsByDim(llvm::ArrayRef<llvm::SmallVector<int64_t, 4>> fragmentOffsets, size_t rank);
|
||||
/// Verifies the layout, then requires a payload type, at least one fragment offset, and fragment-major and
/// dimension-major offsets whose sizes agree with the fragment rank and fragment count.
mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor,
|
||||
const ProjectedTransferDescriptor& descriptor);
|
||||
/// Runs verifyProjectedTransferDescriptor and additionally requires
/// `messages.size() * payloadFragmentCount == fragmentOffsets.size()`.
mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor,
|
||||
const ProjectedTransferDescriptor& descriptor,
|
||||
const MessageVector& messages);
|
||||
/// Recomputes `fragmentOffsetsByDim` and `payloadType` from the layout (rejecting a preset payload type that
/// differs from the derived one) and then verifies the descriptor.
mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor,
|
||||
ProjectedTransferDescriptor& descriptor);
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -12,7 +12,6 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using CPU = int;
|
||||
|
||||
@@ -23,13 +23,6 @@ function(add_pim_unittest test_name)
|
||||
set_tests_properties(${test_name} PROPERTIES LABELS pim-unittest)
|
||||
endfunction()
|
||||
|
||||
add_pim_unittest(LabeledListTest
|
||||
LabeledListTest.cpp
|
||||
|
||||
LINK_LIBS PRIVATE
|
||||
OMPimCommon
|
||||
)
|
||||
|
||||
add_pim_unittest(PimMemoryLivenessPlannerTest
|
||||
PimMemoryLivenessPlannerTest.cpp
|
||||
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||
|
||||
using onnx_mlir::LabeledList;
|
||||
using onnx_mlir::LabeledListNode;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Minimal list node used by the tests; carries only an integer identifier.
struct TestNode : public LabeledListNode<TestNode> {
  explicit TestNode(int nodeId)
      : id(nodeId) {}

  int id;
};
|
||||
|
||||
/// Asserts that iterating `list` yields exactly the ids in `expectedOrder`, in the same order.
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
  auto expected = expectedOrder.begin();
  auto node = list.begin();
  while (node != list.end()) {
    assert(expected != expectedOrder.end());
    assert(node->id == *expected);
    ++expected;
    ++node;
  }
  // Every expected id must have been consumed.
  assert(expected == expectedOrder.end());
}
|
||||
|
||||
/// Asserts that order labels strictly increase from the front of `list` to its back.
/// An empty list trivially satisfies the check.
void assertStrictlyIncreasingLabels(LabeledList<TestNode>& list) {
  auto cursor = list.begin();
  if (cursor == list.end())
    return;

  auto lastLabel = cursor->getOrderLabel();
  for (++cursor; cursor != list.end(); ++cursor) {
    const auto label = cursor->getOrderLabel();
    assert(lastLabel < label);
    lastLabel = label;
  }
}
|
||||
|
||||
int testLabeledListBasicMutation() {
|
||||
std::cout << "testLabeledListBasicMutation:" << std::endl;
|
||||
|
||||
LabeledList<TestNode> list;
|
||||
TestNode n1(1);
|
||||
TestNode n2(2);
|
||||
TestNode n3(3);
|
||||
TestNode n4(4);
|
||||
TestNode n5(5);
|
||||
|
||||
assert(list.empty());
|
||||
assert(list.front() == nullptr);
|
||||
assert(list.back() == nullptr);
|
||||
assert(!list.contains(&n1));
|
||||
assert(LabeledList<TestNode>::previous(&n1) == nullptr);
|
||||
assert(LabeledList<TestNode>::next(&n1) == nullptr);
|
||||
|
||||
list.pushBack(&n1);
|
||||
list.pushBack(&n3);
|
||||
list.insertAfter(&n1, &n2);
|
||||
list.pushFront(&n4);
|
||||
list.insertBefore(nullptr, &n5);
|
||||
|
||||
assert(list.size() == 5);
|
||||
assert(list.front() == &n4);
|
||||
assert(list.back() == &n5);
|
||||
assert(list.contains(&n2));
|
||||
assertOrder(list, {4, 1, 2, 3, 5});
|
||||
assert(LabeledList<TestNode>::next(&n4) == &n1);
|
||||
assert(LabeledList<TestNode>::previous(&n1) == &n4);
|
||||
assert(LabeledList<TestNode>::next(&n5) == nullptr);
|
||||
assert(list.comesBefore(&n1, &n3));
|
||||
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
|
||||
|
||||
list.moveBefore(&n5, &n2);
|
||||
assertOrder(list, {4, 1, 5, 2, 3});
|
||||
|
||||
list.moveAfter(&n4, &n3);
|
||||
assertOrder(list, {1, 5, 2, 3, 4});
|
||||
|
||||
list.remove(&n2);
|
||||
assert(!n2.isLinked());
|
||||
assert(!list.contains(&n2));
|
||||
assertOrder(list, {1, 5, 3, 4});
|
||||
|
||||
list.clear();
|
||||
assert(list.empty());
|
||||
assert(list.size() == 0);
|
||||
assert(list.front() == nullptr);
|
||||
assert(list.back() == nullptr);
|
||||
assert(!n1.isLinked());
|
||||
assert(!n3.isLinked());
|
||||
assert(!n4.isLinked());
|
||||
assert(!n5.isLinked());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int testLabeledListRelabelingAndNoopMoves() {
|
||||
std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl;
|
||||
|
||||
constexpr int insertedNodeCount = 80;
|
||||
LabeledList<TestNode> list;
|
||||
TestNode head(0);
|
||||
TestNode tail(999);
|
||||
std::vector<TestNode> insertedNodes;
|
||||
insertedNodes.reserve(insertedNodeCount);
|
||||
for (int i = 0; i < insertedNodeCount; ++i)
|
||||
insertedNodes.emplace_back(i + 1);
|
||||
|
||||
list.pushBack(&head);
|
||||
list.pushBack(&tail);
|
||||
for (auto& node : insertedNodes)
|
||||
list.insertAfter(&head, &node);
|
||||
|
||||
assert(list.size() == insertedNodeCount + 2);
|
||||
assert(list.front() == &head);
|
||||
assert(list.back() == &tail);
|
||||
assert(LabeledList<TestNode>::previous(&head) == nullptr);
|
||||
assert(LabeledList<TestNode>::next(&tail) == nullptr);
|
||||
assertStrictlyIncreasingLabels(list);
|
||||
|
||||
auto* firstInserted = LabeledList<TestNode>::next(&head);
|
||||
auto* secondInserted = LabeledList<TestNode>::next(firstInserted);
|
||||
list.moveBefore(firstInserted, secondInserted);
|
||||
list.moveAfter(&head, nullptr);
|
||||
list.moveAfter(&tail, LabeledList<TestNode>::previous(&tail));
|
||||
|
||||
assert(list.front() == &head);
|
||||
assert(list.back() == &tail);
|
||||
assert(firstInserted == &insertedNodes.back());
|
||||
assert(secondInserted == &insertedNodes[insertedNodeCount - 2]);
|
||||
assertStrictlyIncreasingLabels(list);
|
||||
|
||||
int expectedId = insertedNodeCount;
|
||||
auto it = std::next(list.begin());
|
||||
for (; it != list.end() && &*it != &tail; ++it, --expectedId)
|
||||
assert(it->id == expectedId);
|
||||
assert(expectedId == 0);
|
||||
list.clear();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Runs both LabeledList tests in order and reports the accumulated failure count.
int main(int argc, char* argv[]) {
  (void) argc;
  (void) argv;

  int failureCount = testLabeledListBasicMutation();
  failureCount += testLabeledListRelabelingAndNoopMoves();
  if (failureCount == 0)
    return EXIT_SUCCESS;

  std::cerr << failureCount << " test failures\n";
  return EXIT_FAILURE;
}
|
||||
Reference in New Issue
Block a user