This commit is contained in:
Submodule backend-simulators/pim/pimsim-nn updated: 3e3442b663...6d3b898e6b
+1
-1
Submodule onnx-mlir updated: 82018d7ce5...eb54c2afc4
@@ -5,9 +5,11 @@ add_pim_library(OMPimCommon
|
||||
IR/ConstantUtils.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/IndexingUtils.cpp
|
||||
IR/LoopUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
IR/SubviewUtils.cpp
|
||||
IR/TensorSliceUtils.cpp
|
||||
IR/WeightUtils.cpp
|
||||
Support/CheckedArithmetic.cpp
|
||||
Support/DebugDump.cpp
|
||||
|
||||
@@ -69,6 +69,15 @@ Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineAddConst(RewriterBase& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
if (offset == 0)
|
||||
return value;
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(divisor > 0 && "expected a positive affine.mod divisor");
|
||||
@@ -90,6 +99,34 @@ Value affineFloorDivConst(
|
||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineAddModConst(
|
||||
RewriterBase& rewriter, Location loc, Value value, int64_t offset, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(divisor > 0 && "expected a positive affine.mod divisor");
|
||||
if (divisor == 1)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
AffineExpr expr = d0;
|
||||
if (offset != 0)
|
||||
expr = expr + offset;
|
||||
return createOrFoldAffineApply(rewriter, loc, expr % divisor, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineAddFloorDivConst(
|
||||
RewriterBase& rewriter, Location loc, Value value, int64_t offset, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
|
||||
if (divisor == 1)
|
||||
return offset == 0 ? value : affineAddConst(rewriter, loc, value, offset, constantAnchor);
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
AffineExpr expr = d0;
|
||||
if (offset != 0)
|
||||
expr = expr + offset;
|
||||
return createOrFoldAffineApply(rewriter, loc, expr.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
|
||||
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
|
||||
return constant.getValue();
|
||||
|
||||
@@ -29,6 +29,12 @@ mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
|
||||
int64_t multiplier,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineAddConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t offset,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
@@ -41,6 +47,20 @@ mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
|
||||
int64_t divisor,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineAddModConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t offset,
|
||||
int64_t divisor,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineAddFloorDivConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t offset,
|
||||
int64_t divisor,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
llvm::FailureOr<int64_t>
|
||||
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -9,6 +10,65 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
mlir::FailureOr<std::optional<int32_t>>
|
||||
getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) {
|
||||
auto coreIdAttr = computeOp->getAttrOfType<mlir::IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
return std::optional<int32_t> {};
|
||||
if (coreIdAttr.getInt() < 0) {
|
||||
computeOp.emitOpError() << fieldName << " must be non-negative";
|
||||
return mlir::failure();
|
||||
}
|
||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), computeOp, fieldName);
|
||||
if (mlir::failed(checkedCoreId))
|
||||
return mlir::failure();
|
||||
return std::optional<int32_t> {*checkedCoreId};
|
||||
}
|
||||
|
||||
mlir::FailureOr<int32_t> getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName) {
|
||||
auto coreId = getOptionalScheduledCoreId(computeOp, fieldName);
|
||||
if (mlir::failed(coreId))
|
||||
return mlir::failure();
|
||||
if (!*coreId) {
|
||||
computeOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdAttrName;
|
||||
return mlir::failure();
|
||||
}
|
||||
return **coreId;
|
||||
}
|
||||
|
||||
mlir::FailureOr<std::optional<llvm::SmallVector<int32_t>>>
|
||||
getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) {
|
||||
auto coreIdsAttr = computeBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
if (!coreIdsAttr)
|
||||
return std::optional<llvm::SmallVector<int32_t>> {};
|
||||
|
||||
llvm::SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(coreIdsAttr.size());
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef()) {
|
||||
if (coreId < 0) {
|
||||
computeBatchOp.emitOpError() << fieldName << " values must be non-negative";
|
||||
return mlir::failure();
|
||||
}
|
||||
auto checkedCoreId = pim::checkedI32(static_cast<int64_t>(coreId), computeBatchOp, fieldName);
|
||||
if (mlir::failed(checkedCoreId))
|
||||
return mlir::failure();
|
||||
coreIds.push_back(*checkedCoreId);
|
||||
}
|
||||
return std::optional<llvm::SmallVector<int32_t>> {std::move(coreIds)};
|
||||
}
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int32_t>>
|
||||
getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName) {
|
||||
auto coreIds = getOptionalScheduledBatchCoreIds(computeBatchOp, fieldName);
|
||||
if (mlir::failed(coreIds))
|
||||
return mlir::failure();
|
||||
if (!*coreIds) {
|
||||
computeBatchOp.emitOpError() << "missing required " << onnx_mlir::kCoreIdsAttrName;
|
||||
return mlir::failure();
|
||||
}
|
||||
return std::move(**coreIds);
|
||||
}
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
llvm::SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
|
||||
@@ -3,12 +3,26 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||
|
||||
mlir::FailureOr<std::optional<int32_t>>
|
||||
getOptionalScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<int32_t> getRequiredScheduledCoreId(spatial::SpatScheduledCompute computeOp, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<std::optional<llvm::SmallVector<int32_t>>>
|
||||
getOptionalScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int32_t>>
|
||||
getRequiredScheduledBatchCoreIds(spatial::SpatScheduledComputeBatch computeBatchOp, llvm::StringRef fieldName);
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/IndexingUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -163,4 +166,80 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(llvm::ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type) {
|
||||
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(llvm::ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> permutation) {
|
||||
llvm::SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(llvm::ArrayRef<int64_t> permutation) {
|
||||
llvm::SmallVector<int64_t> inversePermutation(permutation.size());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
return inversePermutation;
|
||||
}
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>>
|
||||
getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr, int64_t rank) {
|
||||
llvm::SmallVector<int64_t> permutation;
|
||||
if (!permAttr) {
|
||||
permutation.reserve(rank);
|
||||
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||
return mlir::failure();
|
||||
|
||||
permutation.reserve(permAttr->size());
|
||||
llvm::SmallVector<bool> seen(rank, false);
|
||||
for (mlir::IntegerAttr attr : permAttr->getAsRange<mlir::IntegerAttr>()) {
|
||||
int64_t axis = attr.getInt();
|
||||
if (axis < 0 || axis >= rank || seen[axis])
|
||||
return mlir::failure();
|
||||
seen[axis] = true;
|
||||
permutation.push_back(axis);
|
||||
}
|
||||
return permutation;
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticIndexAttrs(mlir::Builder& builder, llvm::ArrayRef<int64_t> values) {
|
||||
llvm::SmallVector<mlir::OpFoldResult> attrs;
|
||||
attrs.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
attrs.push_back(builder.getIndexAttr(value));
|
||||
return attrs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank) {
|
||||
return llvm::SmallVector<mlir::OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank) {
|
||||
return llvm::SmallVector<mlir::OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef<int64_t> shape) {
|
||||
llvm::SmallVector<mlir::OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (int64_t dim : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
return sizes;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,15 +2,23 @@
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
@@ -36,4 +44,69 @@ bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides);
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return 1 + (ac - 1) / bc;
|
||||
}
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||
&& lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticIndexAttrs(mlir::Builder& builder, llvm::ArrayRef<int64_t> values);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value extractAxisSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
||||
Location loc,
|
||||
Value source,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
size_t rank = static_cast<size_t>(sourceType.getRank());
|
||||
|
||||
bool isIdentitySlice =
|
||||
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
||||
&& strides.size() == rank;
|
||||
if (isIdentitySlice) {
|
||||
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
||||
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
||||
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
||||
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
||||
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim
|
||||
|| *staticStride != 1) {
|
||||
isIdentitySlice = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isIdentitySlice)
|
||||
return source;
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
Value insertStaticSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
return tensor::InsertSliceOp::create(rewriter,
|
||||
loc,
|
||||
source,
|
||||
dest,
|
||||
offsets,
|
||||
getStaticSizes(rewriter, sourceType.getShape()),
|
||||
getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::RankedTensorType resultType,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
||||
|
||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,315 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ilist_node.h"
|
||||
#include "llvm/ADT/simple_ilist.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledList;
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledListNode : public llvm::ilist_node<NodeT> {
|
||||
friend class LabeledList<NodeT>;
|
||||
|
||||
public:
|
||||
using Label = uint64_t;
|
||||
|
||||
LabeledListNode() = default;
|
||||
LabeledListNode(const LabeledListNode&) = delete;
|
||||
LabeledListNode(LabeledListNode&&) = default;
|
||||
LabeledListNode& operator=(LabeledListNode&&) = delete;
|
||||
|
||||
~LabeledListNode() { assert(owner_ == nullptr && "destroying a linked LabeledListNode"); }
|
||||
|
||||
bool isLinked() const { return owner_ != nullptr; }
|
||||
Label getOrderLabel() const { return label; }
|
||||
|
||||
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt) { return lft.label < rgt.label; }
|
||||
|
||||
private:
|
||||
const void* owner_ = nullptr;
|
||||
Label label = 0;
|
||||
};
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledList {
|
||||
|
||||
using Label = typename NodeT::Label;
|
||||
|
||||
static constexpr Label kLowerSentinel = 0;
|
||||
static constexpr Label kUpperSentinel = std::numeric_limits<Label>::max();
|
||||
static constexpr Label kRelabelGap = 2;
|
||||
|
||||
public:
|
||||
using List = llvm::simple_ilist<NodeT>;
|
||||
using Iterator = typename List::iterator;
|
||||
using RIterator = typename List::reverse_iterator;
|
||||
using ConstIterator = typename List::const_iterator;
|
||||
|
||||
LabeledList() = default;
|
||||
LabeledList(const LabeledList&) = delete;
|
||||
LabeledList& operator=(const LabeledList&) = delete;
|
||||
LabeledList(LabeledList&&) = delete;
|
||||
LabeledList& operator=(LabeledList&&) = delete;
|
||||
|
||||
~LabeledList() { clear(); }
|
||||
|
||||
bool empty() const { return size_ == 0; }
|
||||
size_t size() const { return size_; }
|
||||
|
||||
NodeT* front() { return empty() ? nullptr : &nodes_.front(); }
|
||||
const NodeT* front() const { return empty() ? nullptr : &nodes_.front(); }
|
||||
|
||||
NodeT* back() { return empty() ? nullptr : &nodes_.back(); }
|
||||
const NodeT* back() const { return empty() ? nullptr : &nodes_.back(); }
|
||||
|
||||
static NodeT* previous(NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
auto* list = owner(node);
|
||||
auto it = node->getIterator();
|
||||
if (it == list->nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
static const NodeT* previous(const NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
const auto* list = owner(node);
|
||||
auto it = const_cast<NodeT*>(node)->getIterator();
|
||||
if (it == list->nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
static NodeT* next(NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
auto* list = owner(node);
|
||||
auto it = std::next(node->getIterator());
|
||||
if (it == list->nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
static const NodeT* next(const NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
const auto* list = owner(node);
|
||||
auto it = std::next(const_cast<NodeT*>(node)->getIterator());
|
||||
if (it == list->nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
bool contains(const NodeT* node) const { return node && node->owner_ == this; }
|
||||
|
||||
Label getOrderLabel(const NodeT* node) const {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
return node->label;
|
||||
}
|
||||
|
||||
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
|
||||
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
|
||||
return lhs->label < rhs->label;
|
||||
}
|
||||
|
||||
void pushFront(NodeT* node) { insertBefore(front(), node); }
|
||||
|
||||
void pushBack(NodeT* node) { insertBefore(nullptr, node); }
|
||||
|
||||
void insertBefore(NodeT* nextNode, NodeT* node) {
|
||||
assert(node && "cannot insert a null node");
|
||||
assert(!node->owner_ && "node is already linked");
|
||||
assert(nextNode == nullptr || contains(nextNode));
|
||||
|
||||
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
|
||||
nodes_.insert(nextIt, *node);
|
||||
node->owner_ = this;
|
||||
++size_;
|
||||
assignLabel(getIteratorFor(node));
|
||||
}
|
||||
|
||||
void insertAfter(NodeT* prevNode, NodeT* node) {
|
||||
assert(prevNode == nullptr || contains(prevNode));
|
||||
if (prevNode == nullptr)
|
||||
insertBefore(front(), node);
|
||||
else
|
||||
insertBefore(next(prevNode), node);
|
||||
}
|
||||
|
||||
void remove(NodeT* node) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
nodes_.remove(*node);
|
||||
node->owner_ = nullptr;
|
||||
node->label = 0;
|
||||
--size_;
|
||||
}
|
||||
|
||||
void moveBefore(NodeT* node, NodeT* nextNode) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
assert(nextNode == nullptr || contains(nextNode));
|
||||
|
||||
Iterator nodeIt = getIteratorFor(node);
|
||||
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
|
||||
if (nodeIt == nextIt || std::next(nodeIt) == nextIt)
|
||||
return;
|
||||
|
||||
nodes_.splice(nextIt, nodes_, nodeIt);
|
||||
assignLabel(getIteratorFor(node));
|
||||
}
|
||||
|
||||
void moveAfter(NodeT* node, NodeT* prevNode) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
assert(prevNode == nullptr || contains(prevNode));
|
||||
|
||||
Iterator nextIt = prevNode ? std::next(getIteratorFor(prevNode)) : nodes_.begin();
|
||||
if (getIteratorFor(node) == nextIt)
|
||||
return;
|
||||
moveBefore(node, nextIt == nodes_.end() ? nullptr : &*nextIt);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
while (!nodes_.empty()) {
|
||||
NodeT* node = &nodes_.front();
|
||||
node->owner_ = nullptr;
|
||||
node->label = 0;
|
||||
nodes_.remove(*node);
|
||||
}
|
||||
size_ = 0;
|
||||
}
|
||||
|
||||
Iterator begin() { return nodes_.begin(); }
|
||||
Iterator end() { return nodes_.end(); }
|
||||
|
||||
RIterator rbegin() { return nodes_.rbegin(); }
|
||||
RIterator rend() { return nodes_.rend(); }
|
||||
|
||||
private:
|
||||
static const LabeledList* owner(const NodeT* node) { return static_cast<const LabeledList*>(node->owner_); }
|
||||
static LabeledList* owner(NodeT* node) { return static_cast<LabeledList*>(const_cast<void*>(node->owner_)); }
|
||||
|
||||
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
|
||||
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }
|
||||
|
||||
static Label labelGap(Label lower, Label upper) {
|
||||
assert(lower < upper && "labels must be strictly ordered");
|
||||
return upper - lower;
|
||||
}
|
||||
|
||||
static bool hasMidpoint(Label lower, Label upper) { return labelGap(lower, upper) > 1; }
|
||||
|
||||
static bool hasRelabelSlack(Label lower, Label upper, size_t nodeCount) {
|
||||
Label gap = labelGap(lower, upper);
|
||||
return gap / static_cast<Label>(nodeCount + 1) >= kRelabelGap;
|
||||
}
|
||||
|
||||
Iterator getIteratorFor(NodeT* node) { return node->getIterator(); }
|
||||
ConstIterator getiteratorFor(const NodeT* node) const { return node->getIterator(); }
|
||||
|
||||
NodeT* previousNode(Iterator it) {
|
||||
if (it == nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
const NodeT* previousNode(ConstIterator it) const {
|
||||
if (it == nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
NodeT* nextNode(Iterator it) {
|
||||
++it;
|
||||
if (it == nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
const NodeT* nextNode(ConstIterator it) const {
|
||||
++it;
|
||||
if (it == nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
void assignLabel(Iterator it) {
|
||||
Label lower = lowerLabel(previousNode(it));
|
||||
Label upper = upperLabel(nextNode(it));
|
||||
if (hasMidpoint(lower, upper)) {
|
||||
(*it).label = lower + static_cast<Label>(labelGap(lower, upper) / 2);
|
||||
return;
|
||||
}
|
||||
|
||||
relabelAround(it);
|
||||
}
|
||||
|
||||
void relabelAround(Iterator center) {
|
||||
size_t targetCount = 1;
|
||||
while (true) {
|
||||
Iterator left = center;
|
||||
Iterator right = center;
|
||||
size_t actualCount = 1;
|
||||
expandWindow(center, targetCount, left, right, actualCount);
|
||||
|
||||
Label lower = lowerLabel(previousNode(left));
|
||||
Label upper = upperLabel(nextNode(right));
|
||||
if (hasRelabelSlack(lower, upper, actualCount)) {
|
||||
relabelWindow(left, actualCount, lower, upper);
|
||||
return;
|
||||
}
|
||||
|
||||
if (left == nodes_.begin() && nextNode(right) == nullptr) {
|
||||
assert(hasRelabelSlack(lower, upper, actualCount) && "label space exhausted");
|
||||
relabelWindow(left, actualCount, lower, upper);
|
||||
return;
|
||||
}
|
||||
|
||||
targetCount *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
void expandWindow(Iterator center, size_t targetCount, Iterator& left, Iterator& right, size_t& actualCount) {
|
||||
left = center;
|
||||
right = center;
|
||||
actualCount = 1;
|
||||
|
||||
while (actualCount < targetCount && (left != nodes_.begin() || nextNode(right) != nullptr)) {
|
||||
if (left != nodes_.begin()) {
|
||||
--left;
|
||||
++actualCount;
|
||||
if (actualCount == targetCount)
|
||||
break;
|
||||
}
|
||||
if (nextNode(right) != nullptr) {
|
||||
++right;
|
||||
++actualCount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void relabelWindow(Iterator left, size_t nodeCount, Label lower, Label upper) {
|
||||
assert(nodeCount > 0 && "relabel window must not be empty");
|
||||
Label step = labelGap(lower, upper) / static_cast<Label>(nodeCount + 1);
|
||||
assert(step >= 1 && "relabel step must be positive");
|
||||
|
||||
Iterator it = left;
|
||||
for (size_t index = 1; index <= nodeCount; ++index) {
|
||||
(*it).label = lower + step * index;
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
List nodes_;
|
||||
size_t size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
|
||||
@@ -10,6 +10,7 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Post.cpp
|
||||
Patterns/GeneratedConversion.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/ConvGeometry.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
@@ -30,7 +31,7 @@ add_pim_library(OMONNXToSpatial
|
||||
LowerSpatialPlansPass.cpp
|
||||
Common/AttributeUtils.cpp
|
||||
Common/ComputeRegionBuilder.cpp
|
||||
Common/IndexingUtils.cpp
|
||||
Common/MatrixProductLowering.cpp
|
||||
Common/ShapeTilingUtils.cpp
|
||||
Common/WeightMaterialization.cpp
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "MatrixProductLowering.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
#include "MatrixProductLowering.hpp"
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
Value createPaddedInputCompute(Value input,
|
||||
RankedTensorType paddedInputType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType == paddedInputType)
|
||||
return input;
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Value createZeroPaddedTensor(mlir::Value value,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
mlir::Value createPaddedInputCompute(mlir::Value input,
|
||||
mlir::RankedTensorType paddedInputType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,9 +3,6 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
@@ -15,73 +12,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(RankedTensorType type) {
|
||||
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> inversePermutation(permutation.size());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
return inversePermutation;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
|
||||
SmallVector<int64_t> permutation;
|
||||
if (!permAttr) {
|
||||
permutation.reserve(rank);
|
||||
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||
return failure();
|
||||
|
||||
permutation.reserve(permAttr->size());
|
||||
SmallVector<bool> seen(rank, false);
|
||||
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
|
||||
int64_t axis = attr.getInt();
|
||||
if (axis < 0 || axis >= rank || seen[axis])
|
||||
return failure();
|
||||
seen[axis] = true;
|
||||
permutation.push_back(axis);
|
||||
}
|
||||
return permutation;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (int64_t dim : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
return sizes;
|
||||
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
@@ -147,65 +77,4 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewri
|
||||
return slicesPerCore;
|
||||
}
|
||||
|
||||
Value extractAxisSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
||||
Location loc,
|
||||
Value source,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
size_t rank = static_cast<size_t>(sourceType.getRank());
|
||||
|
||||
bool isIdentitySlice =
|
||||
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
||||
&& strides.size() == rank;
|
||||
if (isIdentitySlice) {
|
||||
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
||||
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
||||
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
||||
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
||||
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
|
||||
isIdentitySlice = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isIdentitySlice)
|
||||
return source;
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
Value insertStaticSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
return tensor::InsertSliceOp::create(rewriter,
|
||||
loc,
|
||||
source,
|
||||
dest,
|
||||
offsets,
|
||||
getStaticSizes(rewriter, sourceType.getShape()),
|
||||
getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,89 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return 1 + (ac - 1) / bc;
|
||||
}
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||
&& lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
/// at most `sliceSize` elements.
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
@@ -102,21 +28,4 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::RankedTensorType resultType,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
||||
|
||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -42,59 +43,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvLoweringState {
|
||||
Value x;
|
||||
Value w;
|
||||
Value b;
|
||||
RankedTensorType xType;
|
||||
RankedTensorType wType;
|
||||
RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
struct ConvGeometry {
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t k;
|
||||
int64_t c;
|
||||
int64_t p;
|
||||
int64_t xbarSize;
|
||||
int64_t pack;
|
||||
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
bool isDepthwise;
|
||||
};
|
||||
|
||||
struct ConvLoweringDecision {
|
||||
PimConvLoweringType strategy;
|
||||
std::string reason;
|
||||
@@ -108,19 +56,6 @@ struct PreparedConvInput {
|
||||
RankedTensorType type;
|
||||
};
|
||||
|
||||
struct RowInterval {
|
||||
int64_t begin = 0;
|
||||
int64_t end = 0;
|
||||
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
struct ConvStrategyEstimate {
|
||||
uint64_t estimatedMvmCount = 0;
|
||||
uint64_t estimatedReductionVAddCount = 0;
|
||||
@@ -291,9 +226,6 @@ static FailureOr<Value> createRowStripPackedRows(Value rows,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc);
|
||||
|
||||
static bool
|
||||
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
|
||||
|
||||
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
|
||||
@@ -391,34 +323,6 @@ static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
|
||||
return estimate;
|
||||
}
|
||||
|
||||
static ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
|
||||
ConvGeometry geo {
|
||||
state.batchSize,
|
||||
state.numChannelsIn,
|
||||
state.xHeight,
|
||||
state.xWidth,
|
||||
state.numChannelsOut,
|
||||
state.wHeight,
|
||||
state.wWidth,
|
||||
state.outHeight,
|
||||
state.outWidth,
|
||||
state.group,
|
||||
state.numChannelsInPerGroup,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.batchSize * state.outHeight * state.outWidth,
|
||||
static_cast<int64_t>(crossbarSize.getValue()),
|
||||
1,
|
||||
0,
|
||||
state.hasBias,
|
||||
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
|
||||
};
|
||||
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
|
||||
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
|
||||
return geo;
|
||||
}
|
||||
|
||||
static std::string formatShape(ArrayRef<int64_t> dims) {
|
||||
std::string text;
|
||||
llvm::raw_string_ostream os(text);
|
||||
@@ -563,36 +467,10 @@ classifyDistributedBinaryConsumer(Operation* user,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows,
|
||||
int64_t inputHeight,
|
||||
int64_t kernelH,
|
||||
int64_t strideH,
|
||||
int64_t dilationH,
|
||||
int64_t padTop) {
|
||||
const int64_t rawBegin = outputRows.begin * strideH - padTop;
|
||||
const int64_t rawEnd = (outputRows.end - 1) * strideH - padTop + dilationH * (kernelH - 1) + 1;
|
||||
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(inputHeight, rawEnd)};
|
||||
}
|
||||
|
||||
static bool covers(RowInterval acquired, RowInterval needed) {
|
||||
return acquired.begin <= needed.begin && acquired.end >= needed.end;
|
||||
}
|
||||
|
||||
static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
|
||||
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
|
||||
const int64_t rawEnd =
|
||||
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
|
||||
RowInterval neededInputRows = computeConvInputRowsForOutputRows(
|
||||
outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin);
|
||||
ConvRowDemand demand;
|
||||
demand.outputRows = outputRows;
|
||||
demand.neededInputRows = neededInputRows;
|
||||
demand.acquiredInputRows = neededInputRows;
|
||||
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
|
||||
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
|
||||
return demand;
|
||||
}
|
||||
|
||||
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
|
||||
if (state.batchSize != 1) {
|
||||
failureReason = "unsupported_batch";
|
||||
@@ -1250,19 +1128,6 @@ static void reportConvLoweringDecision(ONNXConvOp convOp,
|
||||
rewriteConvLoweringReport(reportEntries);
|
||||
}
|
||||
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
|
||||
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
|
||||
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
|
||||
|
||||
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
|
||||
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
|
||||
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
|
||||
}
|
||||
return std::max<uint64_t>(1, chunkPositions);
|
||||
}
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
@@ -1278,11 +1143,6 @@ static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location
|
||||
});
|
||||
}
|
||||
|
||||
static bool
|
||||
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
|
||||
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
|
||||
}
|
||||
|
||||
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
|
||||
assert(value > 0 && "expected positive value");
|
||||
limit = std::min(value, limit);
|
||||
@@ -1324,48 +1184,6 @@ static Value createZeroPaddedTensor(Value value,
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static Value affineAddConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) {
|
||||
if (offset == 0)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineMulConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
|
||||
if (factor == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineFloorDivConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(divisor > 0 && "expected positive affine floordiv divisor");
|
||||
if (divisor == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineModConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
|
||||
assert(modulus > 0 && "expected positive affine mod divisor");
|
||||
if (modulus == 1)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value createConvInputPatch(Value input,
|
||||
RankedTensorType patchType,
|
||||
Value batchIndex,
|
||||
@@ -2456,11 +2274,10 @@ static Value createIm2colRows(const ConvLoweringState& state,
|
||||
ValueRange {im2colInit},
|
||||
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value im2colAcc = iterArgs.front();
|
||||
Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
|
||||
Value batchIndex =
|
||||
affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||
affineAddFloorDivConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
|
||||
Value batchPatchIndex =
|
||||
affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||
affineAddModConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
|
||||
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||
Value inputHeightOffset =
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
#include "ConvGeometry.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
|
||||
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
|
||||
}
|
||||
|
||||
ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
|
||||
ConvGeometry geo {
|
||||
state.batchSize,
|
||||
state.numChannelsIn,
|
||||
state.xHeight,
|
||||
state.xWidth,
|
||||
state.numChannelsOut,
|
||||
state.wHeight,
|
||||
state.wWidth,
|
||||
state.outHeight,
|
||||
state.outWidth,
|
||||
state.group,
|
||||
state.numChannelsInPerGroup,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.batchSize * state.outHeight * state.outWidth,
|
||||
static_cast<int64_t>(crossbarSize.getValue()),
|
||||
1,
|
||||
0,
|
||||
state.hasBias,
|
||||
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
|
||||
};
|
||||
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
|
||||
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
|
||||
return geo;
|
||||
}
|
||||
|
||||
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
|
||||
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
|
||||
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
|
||||
|
||||
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
|
||||
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
|
||||
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
|
||||
}
|
||||
return std::max<uint64_t>(1, chunkPositions);
|
||||
}
|
||||
|
||||
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state) {
|
||||
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
|
||||
const int64_t rawEnd =
|
||||
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
|
||||
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(state.xHeight, rawEnd)};
|
||||
}
|
||||
|
||||
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
|
||||
ConvRowDemand demand;
|
||||
demand.outputRows = outputRows;
|
||||
demand.neededInputRows = computeConvInputRowsForOutputRows(outputRows, state);
|
||||
demand.acquiredInputRows = demand.neededInputRows;
|
||||
|
||||
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
|
||||
const int64_t rawEnd =
|
||||
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
|
||||
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
|
||||
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
|
||||
demand.acquiredInputRows = demand.neededInputRows;
|
||||
return demand;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,86 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ConvLoweringState {
|
||||
mlir::Value x;
|
||||
mlir::Value w;
|
||||
mlir::Value b;
|
||||
mlir::RankedTensorType xType;
|
||||
mlir::RankedTensorType wType;
|
||||
mlir::RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
struct ConvGeometry {
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t k;
|
||||
int64_t c;
|
||||
int64_t p;
|
||||
int64_t xbarSize;
|
||||
int64_t pack;
|
||||
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
bool isDepthwise;
|
||||
};
|
||||
|
||||
struct RowInterval {
|
||||
int64_t begin = 0;
|
||||
int64_t end = 0;
|
||||
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
|
||||
ConvGeometry buildConvGeometry(const ConvLoweringState& state);
|
||||
|
||||
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
|
||||
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -87,28 +87,6 @@ static Value createGemmBatchHOffset(Value lane,
|
||||
rewriter.getInsertionBlock()->getParentOp());
|
||||
}
|
||||
|
||||
static Value
|
||||
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializePaddedConstantMatrix(Value value,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
@@ -232,22 +210,6 @@ static Value extractATile(
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
static Value createPaddedInputCompute(Value input,
|
||||
RankedTensorType paddedInputType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType == paddedInputType)
|
||||
return input;
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||
});
|
||||
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
|
||||
@@ -255,42 +255,6 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
||||
return createONNXTranspose(resultType, {0, 2, 1});
|
||||
}
|
||||
|
||||
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static Value createPaddedBatchedInputCompute(Value input,
|
||||
RankedTensorType paddedInputType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType == paddedInputType)
|
||||
return input;
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializePaddedBatchedWeight(Value value,
|
||||
ArrayRef<int64_t> sourceBatchShape,
|
||||
ArrayRef<int64_t> targetBatchShape,
|
||||
@@ -1055,7 +1019,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
auto paddedRhs =
|
||||
materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
|
||||
if (succeeded(paddedRhs)) {
|
||||
Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
|
||||
Value paddedLhs = createPaddedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
|
||||
const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
|
||||
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
shapeInfo->outType.getElementType());
|
||||
|
||||
@@ -29,100 +29,6 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||
});
|
||||
}
|
||||
|
||||
static bool isMaterializableExternalTensorOp(Operation* op) {
|
||||
return isa<spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatExtractRowsOp,
|
||||
tensor::ExtractSliceOp,
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CollapseShapeOp>(op);
|
||||
}
|
||||
|
||||
//TODO REMOVE THIS UGLY FIX
|
||||
//TODO: Remove this helper once compute_batch external tensor captures are
|
||||
// fixed at the producer side.
|
||||
//
|
||||
// This function is a temporary SpatialToPim repair path. It clones selected
|
||||
// external tensor producers, such as channel_receive and tensor view/slice ops,
|
||||
// into the new pim.core_batch body when the old spat.compute_batch body refers
|
||||
// to tensor values defined outside the batch.
|
||||
//
|
||||
// The real invariant should be stronger:
|
||||
//
|
||||
// A spat.compute_batch body must not capture external tensor values.
|
||||
// Every tensor used inside the body must be either:
|
||||
// - a compute_batch block argument,
|
||||
// - defined inside the compute_batch body,
|
||||
// - or a legal constant-like value.
|
||||
//
|
||||
// If this invariant is violated, the responsible producer, most likely merge
|
||||
// schedule materialization, should emit verifier-clean Spatial IR instead of
|
||||
// relying on SpatialToPim to clone external producer chains later.
|
||||
//
|
||||
// After that producer-side fix:
|
||||
// 1. remove isMaterializableExternalTensorOp,
|
||||
// 2. remove materializeExternalTensorValue,
|
||||
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
|
||||
// tensor operand,
|
||||
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
|
||||
// before SpatialToPim.
|
||||
//
|
||||
// Be careful not to replace every external tensor capture with a normal
|
||||
// compute_batch input blindly: host-backed tensors and explicit inter-core
|
||||
// communication have different semantics. In particular, channel_receive-like
|
||||
// values should be materialized through the communication model, not silently
|
||||
// treated as host inputs.
|
||||
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Block& oldBlock,
|
||||
Value value,
|
||||
IRMapping& mapper) {
|
||||
if (mapper.contains(value))
|
||||
return mapper.lookup(value);
|
||||
|
||||
if (!isa<TensorType>(value.getType()))
|
||||
return value;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return failure();
|
||||
|
||||
if (definingOp->getBlock() == &oldBlock)
|
||||
return failure();
|
||||
|
||||
if (!isMaterializableExternalTensorOp(definingOp))
|
||||
return failure();
|
||||
|
||||
for (Value operand : definingOp->getOperands()) {
|
||||
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
|
||||
if (succeeded(materializedOperand))
|
||||
mapper.map(operand, *materializedOperand);
|
||||
}
|
||||
|
||||
Operation* cloned = rewriter.clone(*definingOp, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
|
||||
return mapper.lookup(value);
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||
size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
|
||||
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) {
|
||||
auto checkedCoreId =
|
||||
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id");
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
coreIds.push_back(*checkedCoreId);
|
||||
++fallbackCoreId;
|
||||
}
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
if (!result.hasOneUse())
|
||||
return failure();
|
||||
@@ -386,7 +292,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||
}
|
||||
|
||||
auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
auto coreIds = getRequiredScheduledBatchCoreIds(computeBatchOp, "spatial compute_batch core id");
|
||||
if (failed(coreIds))
|
||||
return failure();
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
@@ -638,9 +544,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
continue;
|
||||
|
||||
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic =
|
||||
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
@@ -141,17 +142,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
||||
}
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
|
||||
auto checkedCoreId =
|
||||
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeOp, "fallback spatial compute core id");
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
++fallbackCoreId;
|
||||
return *checkedCoreId;
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
@@ -311,7 +301,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCom
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId);
|
||||
auto checkedCoreId = getRequiredScheduledCoreId(computeOp, "spatial compute core id");
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast<int64_t>(*checkedCoreId), "pim core id");
|
||||
|
||||
@@ -15,22 +15,6 @@ namespace raptor {
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
||||
SmallVector<OpFoldResult, 4> attrs;
|
||||
attrs.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
attrs.push_back(builder.getIndexAttr(value));
|
||||
return attrs;
|
||||
}
|
||||
|
||||
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
|
||||
SmallVector<OpFoldResult, 4> strides;
|
||||
strides.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
strides.push_back(builder.getIndexAttr(1));
|
||||
return strides;
|
||||
}
|
||||
|
||||
struct LowerFragmentAssemblyBlueprintPattern
|
||||
: OpConversionPattern<spatial::SpatBlueprintOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -44,121 +44,29 @@ using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));
|
||||
|
||||
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
|
||||
if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue())
|
||||
continue;
|
||||
if (dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()) == zeroAttr)
|
||||
return globalOp;
|
||||
}
|
||||
|
||||
std::string nameStem;
|
||||
llvm::raw_string_ostream nameStream(nameStem);
|
||||
nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements();
|
||||
nameStream.flush();
|
||||
|
||||
std::string symbolName = nameStem;
|
||||
unsigned suffix = 0;
|
||||
while (SymbolTable::lookupSymbolIn(moduleOp, symbolName))
|
||||
symbolName = (nameStem + "_" + Twine(suffix++)).str();
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getStringAttr(symbolName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
zeroAttr,
|
||||
rewriter.getUnitAttr(),
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
OperationFolder& constantFolder) {
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(tensorType, outputBuffer.getOperation(), "host-to-device zero copy byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
auto sizeAttr =
|
||||
pim::getCheckedI32Attr(rewriter, outputBuffer.getOperation(), *byteSize, "host-to-device zero copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
return PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static bool isHostBackedMemRefValue(Value value) {
|
||||
while (Operation* definingOp = value.getDefiningOp()) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return isa<memref::GetGlobalOp>(definingOp);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isHostBackedTensorValue(Value value) {
|
||||
while (Operation* definingOp = value.getDefiningOp()) {
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
|
||||
extractSliceOp.getMixedOffsets(),
|
||||
extractSliceOp.getStaticSizes(),
|
||||
extractSliceOp.getStaticStrides())) {
|
||||
return false;
|
||||
}
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
|
||||
return isHostBackedMemRefValue(toTensorOp.getBuffer());
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||
createZeroPaddedTensor(IRRewriter& rewriter, Location loc, Value value, RankedTensorType resultType) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||
@@ -169,26 +77,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
if (failed(zeroed))
|
||||
return failure();
|
||||
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0);
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
if (isHostBackedTensorValue(vector)) {
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
||||
return createZeroPaddedTensor(rewriter, loc, vector, paddedType);
|
||||
}
|
||||
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
coreId = 0;
|
||||
outputTensors.clear();
|
||||
operationsToRemove.clear();
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -362,7 +254,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
bool hasFailure = false;
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
@@ -371,7 +262,7 @@ LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func:
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||
if (failed(paddedInput)) {
|
||||
hasFailure = true;
|
||||
return WalkResult::interrupt();
|
||||
|
||||
@@ -36,7 +36,6 @@ private:
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
llvm::SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
@@ -8,7 +8,9 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsCanonicalization.cpp
|
||||
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/HostOutputFinalization.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/ProjectedFragments.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
#include "HostOutputFinalization.hpp"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "MaterializedClassState.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
||||
if (state.pendingProjectedHostOutputFragments.empty())
|
||||
return success();
|
||||
|
||||
DenseMap<Value, SmallVector<PendingProjectedHostOutputFragment*, 16>> byOutput;
|
||||
for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments)
|
||||
byOutput[fragment.originalOutput].push_back(&fragment);
|
||||
|
||||
SmallVector<Value, 8> outputs;
|
||||
outputs.reserve(byOutput.size());
|
||||
|
||||
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
|
||||
if (!returnOp)
|
||||
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
|
||||
|
||||
DenseSet<Value> seenOutputs;
|
||||
for (Value returned : returnOp.getOperands()) {
|
||||
if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second)
|
||||
continue;
|
||||
outputs.push_back(returned);
|
||||
}
|
||||
if (outputs.size() != byOutput.size())
|
||||
return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs");
|
||||
|
||||
for (Value originalOutput : outputs) {
|
||||
if (isa_and_present<SpatScheduledCompute, SpatScheduledComputeBatch>(originalOutput.getDefiningOp())) {
|
||||
return state.func.emitError(
|
||||
"projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result");
|
||||
}
|
||||
|
||||
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return state.func.emitError("projected host output must have static ranked tensor type");
|
||||
|
||||
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
|
||||
llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs,
|
||||
const PendingProjectedHostOutputFragment* rhs) {
|
||||
if (lhs->sourceClass != rhs->sourceClass)
|
||||
return lhs->sourceClass < rhs->sourceClass;
|
||||
if (lhs->publicationResultIndex != rhs->publicationResultIndex)
|
||||
return lhs->publicationResultIndex < rhs->publicationResultIndex;
|
||||
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
|
||||
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
|
||||
return std::lexicographical_compare(lhs->offsets.begin(),
|
||||
lhs->offsets.end(),
|
||||
rhs->offsets.begin(),
|
||||
rhs->offsets.end());
|
||||
});
|
||||
|
||||
state.rewriter.setInsertionPoint(returnOp);
|
||||
Location loc = fragments.front()->loc;
|
||||
SmallVector<Value, 16> blueprintOperands;
|
||||
SmallVector<int64_t, 16> fragmentOperandIndices;
|
||||
SmallVector<int64_t, 16> fragmentSourceOffsets;
|
||||
SmallVector<int64_t, 64> flatOffsets;
|
||||
SmallVector<int64_t, 64> flatSizes;
|
||||
SmallVector<int64_t, 64> flatStrides;
|
||||
DenseMap<Value, int64_t> operandIndicesByValue;
|
||||
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||
if (fragmentRecord->sourceClass >= state.classes.size())
|
||||
return state.func.emitError("projected host output fragment references an invalid source class");
|
||||
|
||||
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||
if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) {
|
||||
return sourceClass.op->emitError("projected host output fragment references an invalid publication result")
|
||||
<< " sourceClass=" << sourceClass.id
|
||||
<< " resultIndex=" << fragmentRecord->publicationResultIndex
|
||||
<< " resultCount=" << sourceClass.op->getNumResults();
|
||||
}
|
||||
|
||||
Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex);
|
||||
auto [operandIt, inserted] =
|
||||
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(blueprintOperands.size()));
|
||||
if (inserted)
|
||||
blueprintOperands.push_back(operand);
|
||||
fragmentOperandIndices.push_back(operandIt->second);
|
||||
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
|
||||
llvm::append_range(flatOffsets, fragmentRecord->offsets);
|
||||
llvm::append_range(flatSizes, fragmentRecord->sizes);
|
||||
llvm::append_range(flatStrides, fragmentRecord->strides);
|
||||
|
||||
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
|
||||
if (!operandType || !operandType.hasStaticShape())
|
||||
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
|
||||
}
|
||||
|
||||
if (blueprintOperands.empty())
|
||||
return state.func.emitError("missing projected host output fragments");
|
||||
|
||||
Value input = blueprintOperands.front();
|
||||
ValueRange extraFragments = ValueRange(blueprintOperands).drop_front();
|
||||
auto blueprint = SpatBlueprintOp::create(
|
||||
state.rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
input,
|
||||
extraFragments,
|
||||
state.rewriter.getStringAttr("nchw"),
|
||||
state.rewriter.getStringAttr("fragmented"),
|
||||
state.rewriter.getDenseI64ArrayAttr(flatOffsets),
|
||||
state.rewriter.getDenseI64ArrayAttr(flatSizes),
|
||||
state.rewriter.getStringAttr("identity"),
|
||||
state.rewriter.getStringAttr("fragment_assembly"),
|
||||
state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices),
|
||||
state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets),
|
||||
state.rewriter.getDenseI64ArrayAttr(flatStrides),
|
||||
state.rewriter.getStringAttr("disjoint"),
|
||||
state.rewriter.getStringAttr("complete"));
|
||||
|
||||
state.hostReplacements[originalOutput] = blueprint.getOutput();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
struct MaterializerState;
|
||||
|
||||
mlir::LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state);
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
+125
-729
@@ -24,14 +24,18 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "HostOutputFinalization.hpp"
|
||||
#include "MaterializedClassState.hpp"
|
||||
#include "MergeMessages.hpp"
|
||||
#include "MergeScheduleKeys.hpp"
|
||||
#include "ProjectedFragments.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -39,353 +43,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
namespace {
|
||||
|
||||
using CpuId = size_t;
|
||||
using ClassId = size_t;
|
||||
using SlotId = size_t;
|
||||
|
||||
static FailureOr<int32_t> getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) {
|
||||
return pim::checkedI32(static_cast<uint64_t>(cpu), anchor, fieldName);
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t, 8>>
|
||||
getCheckedCoreIds(Operation* anchor, ArrayRef<CpuId> cpus, StringRef fieldName) {
|
||||
SmallVector<int32_t, 8> coreIds;
|
||||
coreIds.reserve(cpus.size());
|
||||
for (CpuId cpu : cpus) {
|
||||
auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName);
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
coreIds.push_back(*checkedCoreId);
|
||||
}
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
struct MessageVector {
|
||||
SmallVector<int64_t, 16> channelIds;
|
||||
SmallVector<int32_t, 16> sourceCoreIds;
|
||||
SmallVector<int32_t, 16> targetCoreIds;
|
||||
|
||||
size_t size() const { return channelIds.size(); }
|
||||
bool empty() const { return channelIds.empty(); }
|
||||
|
||||
LogicalResult verify(Operation* anchor) const {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return anchor->emitError("message metadata is inconsistent");
|
||||
return success();
|
||||
}
|
||||
|
||||
void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) {
|
||||
channelIds.push_back(channelId);
|
||||
sourceCoreIds.push_back(sourceCoreId);
|
||||
targetCoreIds.push_back(targetCoreId);
|
||||
}
|
||||
|
||||
void append(ArrayRef<int64_t> channels, ArrayRef<int32_t> sources, ArrayRef<int32_t> targets) {
|
||||
assert(channels.size() == sources.size() && "channel/source count mismatch");
|
||||
assert(channels.size() == targets.size() && "channel/target count mismatch");
|
||||
llvm::append_range(channelIds, channels);
|
||||
llvm::append_range(sourceCoreIds, sources);
|
||||
llvm::append_range(targetCoreIds, targets);
|
||||
}
|
||||
|
||||
MessageVector slice(size_t offset, size_t count) const {
|
||||
MessageVector result;
|
||||
result.append(ArrayRef<int64_t>(channelIds).slice(offset, count),
|
||||
ArrayRef<int32_t>(sourceCoreIds).slice(offset, count),
|
||||
ArrayRef<int32_t>(targetCoreIds).slice(offset, count));
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProducerKey {
|
||||
ComputeInstance instance;
|
||||
size_t resultIndex = 0;
|
||||
|
||||
bool operator==(const ProducerKey& other) const {
|
||||
return instance == other.instance && resultIndex == other.resultIndex;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProducerKeyInfo {
|
||||
static ProducerKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<ComputeInstance>::getEmptyKey(), std::numeric_limits<size_t>::max()};
|
||||
}
|
||||
|
||||
static ProducerKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<ComputeInstance>::getTombstoneKey(), std::numeric_limits<size_t>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ProducerKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<ComputeInstance>::getHashValue(key.instance), key.resultIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKey {
|
||||
Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const SameClassConsumerLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKeyInfo {
|
||||
static SameClassConsumerLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(),
|
||||
std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static SameClassConsumerLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(),
|
||||
std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp), key.resultIndex, key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKey {
|
||||
Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const WholeBatchAssemblyLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKeyInfo {
|
||||
static WholeBatchAssemblyLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(),
|
||||
std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static WholeBatchAssemblyLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(),
|
||||
std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp), key.resultIndex, key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
using ClassSlotKey = std::pair<ClassId, SlotId>;
|
||||
|
||||
struct MaterializedClass {
|
||||
ClassId id = 0;
|
||||
SmallVector<CpuId, 8> cpus;
|
||||
Operation* op = nullptr;
|
||||
Block* body = nullptr;
|
||||
bool isBatch = false;
|
||||
|
||||
DenseMap<CpuId, unsigned> cpuToLane;
|
||||
SmallVector<Value, 8> weights;
|
||||
SmallVector<Value, 8> inputs;
|
||||
SmallVector<Value, 4> hostOutputs;
|
||||
DenseMap<Value, unsigned> publicationOutputToResultIndex;
|
||||
DenseMap<Value, BlockArgument> weightArgs;
|
||||
DenseMap<Value, BlockArgument> inputArgs;
|
||||
DenseMap<Value, unsigned> hostOutputToResultIndex;
|
||||
};
|
||||
|
||||
struct PackedScalarRunSlot {
|
||||
SmallVector<ProducerKey, 8> keys;
|
||||
};
|
||||
|
||||
enum class PackedScalarRunKind {
|
||||
Materialized,
|
||||
DeferredReceive,
|
||||
DeferredLocalCompute
|
||||
};
|
||||
|
||||
struct PackedScalarRunValue {
|
||||
ClassId targetClass = 0;
|
||||
Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
PackedScalarRunKind kind = PackedScalarRunKind::Materialized;
|
||||
|
||||
Value packed;
|
||||
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<PackedScalarRunSlot, 8> slots;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
struct IndexedBatchRunValue {
|
||||
ClassId targetClass = 0;
|
||||
Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
Value packed;
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<PackedScalarRunSlot, 8> slots;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
struct LogicalSlotRange {
|
||||
SlotId start = 0;
|
||||
SlotId count = 0;
|
||||
};
|
||||
|
||||
struct MaterializationRunSlot {
|
||||
SmallVector<ComputeInstance, 8> peers;
|
||||
};
|
||||
|
||||
using MaterializationRun = SmallVector<MaterializationRunSlot, 8>;
|
||||
|
||||
struct OutputDestinationGroup {
|
||||
SmallVector<size_t, 4> resultIndices;
|
||||
SmallVector<ClassId, 4> destinationClasses;
|
||||
};
|
||||
|
||||
struct BatchRunSendPlan {
|
||||
size_t resultIndex = 0;
|
||||
ClassId destinationClass = 0;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
struct ProjectedBatchInputKey {
|
||||
Operation* consumerOp = nullptr;
|
||||
unsigned inputIndex = 0;
|
||||
|
||||
bool operator==(const ProjectedBatchInputKey& other) const {
|
||||
return consumerOp == other.consumerOp && inputIndex == other.inputIndex;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProjectedBatchInputKeyInfo {
|
||||
static ProjectedBatchInputKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static ProjectedBatchInputKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ProjectedBatchInputKey& key) {
|
||||
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
struct ProjectedFragmentLayout {
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
unsigned fragmentsPerLogicalSlot = 1;
|
||||
unsigned payloadFragmentCount = 1;
|
||||
SmallVector<int64_t, 4> loopLowerBounds;
|
||||
SmallVector<int64_t, 4> loopSteps;
|
||||
SmallVector<int64_t, 4> loopTripCounts;
|
||||
};
|
||||
|
||||
struct StaticProjectedLoopInfo {
|
||||
BlockArgument iv;
|
||||
int64_t lowerBound = 0;
|
||||
int64_t step = 1;
|
||||
int64_t tripCount = 1;
|
||||
};
|
||||
|
||||
struct ProjectedTransferDescriptor {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
Operation* extractOp = nullptr;
|
||||
ProjectedFragmentLayout layout;
|
||||
RankedTensorType payloadType;
|
||||
SmallVector<SmallVector<int64_t, 4>, 16> fragmentOffsets;
|
||||
SmallVector<SmallVector<int64_t, 16>, 4> fragmentOffsetsByDim;
|
||||
};
|
||||
|
||||
struct ProjectedExtractReplacement {
|
||||
Value payload;
|
||||
ProjectedFragmentLayout layout;
|
||||
};
|
||||
|
||||
struct PendingProjectedHostOutputFragment {
|
||||
Value originalOutput;
|
||||
ClassId sourceClass = 0;
|
||||
ProducerKey producerKey;
|
||||
unsigned publicationResultIndex = 0;
|
||||
int64_t sourceFragmentOrdinal = 0;
|
||||
int64_t sourceElementOffset = 0;
|
||||
SmallVector<int64_t, 4> offsets;
|
||||
SmallVector<int64_t, 4> sizes;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
uint32_t sourceLane = 0;
|
||||
Location loc;
|
||||
};
|
||||
|
||||
enum class TensorDemandActionKind {
|
||||
DestinationFanout,
|
||||
SameClassIndexedFragment,
|
||||
TerminalBlueprintPublication,
|
||||
WholeTensorBarrier
|
||||
};
|
||||
|
||||
enum class WholeTensorBarrierReason {
|
||||
FunctionReturnWithoutBlueprint,
|
||||
DenseLogicalConsumer
|
||||
};
|
||||
|
||||
struct TensorDemandAction {
|
||||
TensorDemandActionKind kind = TensorDemandActionKind::DestinationFanout;
|
||||
std::optional<ClassId> destinationClass;
|
||||
std::optional<WholeTensorBarrierReason> barrierReason;
|
||||
};
|
||||
|
||||
struct RunOutputDemand {
|
||||
size_t resultIndex = 0;
|
||||
Value originalOutput;
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<TensorDemandAction, 4> actions;
|
||||
};
|
||||
|
||||
struct CompactRunPlan {
|
||||
SmallVector<RunOutputDemand, 4> outputs;
|
||||
};
|
||||
|
||||
enum class BatchInputDemandKind {
|
||||
LaneFragment,
|
||||
ProjectedFragment,
|
||||
WholeTensorBarrier
|
||||
};
|
||||
|
||||
struct BatchInputDemand {
|
||||
BatchInputDemandKind kind = BatchInputDemandKind::LaneFragment;
|
||||
std::optional<ProducerKey> wholeTensorProducer;
|
||||
};
|
||||
|
||||
struct AffineProjectedInputSliceMatch {
|
||||
tensor::ExtractSliceOp extract;
|
||||
RankedTensorType sourceType;
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
SmallVector<OpFoldResult, 4> offsets;
|
||||
SmallVector<StaticProjectedLoopInfo, 4> loops;
|
||||
};
|
||||
|
||||
struct CloneIndexingContext {
|
||||
std::optional<Value> runSlotIndex;
|
||||
std::optional<Value> projectionSlotIndex;
|
||||
};
|
||||
|
||||
struct MaterializerState;
|
||||
FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
@@ -444,111 +101,6 @@ FailureOr<Value> materializeProjectedWholeBatchExtractReplacement(MaterializerSt
|
||||
tensor::ExtractSliceOp extract,
|
||||
ProducerKey producer,
|
||||
IRMapping* mapper = nullptr);
|
||||
|
||||
class AvailableValueStore {
|
||||
public:
|
||||
struct ExactBatchFragmentRecord {
|
||||
ProducerKey key;
|
||||
Value value;
|
||||
};
|
||||
|
||||
void record(ProducerKey key, ClassId classId, Value value) {
|
||||
exactValues[key][classId] = value;
|
||||
|
||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||
if (!batch || key.instance.laneCount == 0)
|
||||
return;
|
||||
|
||||
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
|
||||
SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
|
||||
for (ExactBatchFragmentRecord& record : bucket) {
|
||||
if (!(record.key == key))
|
||||
continue;
|
||||
record.value = value;
|
||||
return;
|
||||
}
|
||||
bucket.push_back({key, value});
|
||||
}
|
||||
|
||||
void recordPackedRun(PackedScalarRunValue run) {
|
||||
size_t runIndex = packedScalarRuns.size();
|
||||
packedScalarRuns.push_back(std::move(run));
|
||||
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
|
||||
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
|
||||
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
|
||||
}
|
||||
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
|
||||
|
||||
std::optional<Value> lookupExact(ProducerKey key, ClassId classId) const;
|
||||
|
||||
std::optional<Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
|
||||
|
||||
ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = packedRunsByProducerResultClass.find(key);
|
||||
if (it == packedRunsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = exactBatchFragmentsByProducerResultClass.find(key);
|
||||
if (it == exactBatchFragmentsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
|
||||
|
||||
private:
|
||||
std::optional<Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> exactValues;
|
||||
SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
|
||||
SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
|
||||
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<ExactBatchFragmentRecord, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||
exactBatchFragmentsByProducerResultClass;
|
||||
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||
packedRunsByProducerResultClass;
|
||||
};
|
||||
|
||||
struct MaterializerState {
|
||||
func::FuncOp func;
|
||||
const MergeScheduleResult& schedule;
|
||||
IRRewriter rewriter;
|
||||
OperationFolder constantFolder;
|
||||
int64_t& nextChannelId;
|
||||
SmallVector<MaterializedClass, 8> classes;
|
||||
DenseMap<CpuId, ClassId> cpuToClass;
|
||||
DenseMap<CpuId, SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
|
||||
DenseMap<ComputeInstance, LogicalSlotRange> scheduledInstanceToLogicalSlots;
|
||||
DenseMap<ComputeInstance, ComputeInstance> logicalInstanceToScheduledChunk;
|
||||
DenseSet<ClassSlotKey> materializedLogicalSlots;
|
||||
|
||||
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||
DenseMap<SameClassConsumerLookupKey, SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
|
||||
sameClassConsumerIndex;
|
||||
DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo> projectedInputMatches;
|
||||
DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
|
||||
DenseMap<Value, bool> liveExternalUseCache;
|
||||
DenseMap<Operation*, SmallVector<Type, 4>> batchOutputFragmentTypesCache;
|
||||
DenseMap<ComputeInstance, SmallVector<Value, 4>, llvm::DenseMapInfo<ComputeInstance>> computeInstanceOutputsCache;
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
|
||||
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
|
||||
AvailableValueStore availableValues;
|
||||
DenseMap<Value, Value> hostReplacements;
|
||||
DenseMap<Value, ClassId> hostOutputOwners;
|
||||
SmallVector<PendingProjectedHostOutputFragment, 32> pendingProjectedHostOutputFragments;
|
||||
DenseSet<Operation*> oldComputeOps;
|
||||
|
||||
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
|
||||
: func(func),
|
||||
schedule(schedule),
|
||||
rewriter(func.getContext()),
|
||||
constantFolder(func.getContext()),
|
||||
nextChannelId(nextChannelId) {}
|
||||
};
|
||||
|
||||
bool isConstantLike(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
@@ -1260,17 +812,6 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void setInsertionPointForNewMaterializedOp(MaterializerState& state) {
|
||||
Block& funcBlock = state.func.getBody().front();
|
||||
for (Operation& op : funcBlock) {
|
||||
if (state.oldComputeOps.contains(&op)) {
|
||||
state.rewriter.setInsertionPoint(&op);
|
||||
return;
|
||||
}
|
||||
}
|
||||
state.rewriter.setInsertionPointToEnd(&funcBlock);
|
||||
}
|
||||
|
||||
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
|
||||
auto it = materializedClass.weightArgs.find(weight);
|
||||
if (it != materializedClass.weightArgs.end())
|
||||
@@ -1411,19 +952,6 @@ FailureOr<unsigned> appendBatchPublicationResult(MaterializerState& state,
|
||||
// Materialized-class value localization helpers.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
Region* getParentRegion(Value value) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent();
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return definingOp->getParentRegion();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
Operation* getEnclosingSpatialComputeLikeOp(Value value) {
|
||||
Block* block = nullptr;
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
@@ -2169,14 +1697,6 @@ buildProjectedFragmentOffsetsInClass(MaterializerState& state,
|
||||
return fragmentOffsets;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
||||
SmallVector<OpFoldResult, 4> attrs;
|
||||
attrs.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
attrs.push_back(builder.getIndexAttr(value));
|
||||
return attrs;
|
||||
}
|
||||
|
||||
Value createDim0InsertSlice(
|
||||
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
|
||||
auto fragmentType = cast<RankedTensorType>(fragment.getType());
|
||||
@@ -2293,6 +1813,8 @@ std::optional<Value> extractPackedProducerSlice(MaterializerState& state,
|
||||
return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<Value> AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const {
|
||||
auto producerIt = exactValues.find(key);
|
||||
if (producerIt == exactValues.end())
|
||||
@@ -2305,6 +1827,32 @@ std::optional<Value> AvailableValueStore::lookupExact(ProducerKey key, ClassId c
|
||||
return valueIt->second;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using IndexedFragmentBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
|
||||
using IndexedInsertOffsetBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
|
||||
|
||||
SmallVector<ProducerKey, 16> flattenPackedScalarRunKeys(const PackedScalarRunValue& run);
|
||||
FailureOr<Value> emitIndexedFragmentInsertLoop(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
Value destination,
|
||||
int64_t itemCount,
|
||||
IndexedFragmentBuilder buildFragment,
|
||||
IndexedInsertOffsetBuilder buildOffset,
|
||||
Location loc);
|
||||
FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
Value laneValue,
|
||||
ArrayRef<size_t> resultIndices,
|
||||
CloneIndexingContext indexing);
|
||||
Value createIndexedChannelId(
|
||||
MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc);
|
||||
Value createIndexedSourceCoreId(
|
||||
MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc);
|
||||
Value createIndexedTargetCoreId(
|
||||
MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc);
|
||||
|
||||
Value getPackedSliceForRunIndex(MaterializerState& state,
|
||||
Operation* anchor,
|
||||
Value packed,
|
||||
@@ -2322,21 +1870,9 @@ Value getPackedSliceForDynamicRunIndex(
|
||||
return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0));
|
||||
}
|
||||
|
||||
FailureOr<Value> createReceiveConcatLoop(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
RankedTensorType concatType,
|
||||
RankedTensorType fragmentType,
|
||||
const MessageVector& messages,
|
||||
Location loc);
|
||||
|
||||
using IndexedFragmentBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
|
||||
using IndexedInsertOffsetBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
|
||||
|
||||
FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
PackedScalarRunValue& run,
|
||||
Location loc);
|
||||
|
||||
bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) {
|
||||
return run.kind == PackedScalarRunKind::DeferredLocalCompute;
|
||||
}
|
||||
@@ -2376,8 +1912,69 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
|
||||
if (run.kind == PackedScalarRunKind::Materialized)
|
||||
return targetClass.op->emitError("materialized packed scalar run has no packed value");
|
||||
|
||||
if (isDeferredLocalPackedScalarRun(run))
|
||||
return materializeDeferredLocalPackedScalarRunValue(state, targetClass, run, loc);
|
||||
if (isDeferredLocalPackedScalarRun(run)) {
|
||||
SmallVector<ProducerKey, 16> keys = flattenPackedScalarRunKeys(run);
|
||||
if (keys.empty())
|
||||
return failure();
|
||||
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(run.fragmentType, keys.size());
|
||||
if (failed(packedType))
|
||||
return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor");
|
||||
|
||||
SmallVector<int64_t, 16> sourceLanes;
|
||||
sourceLanes.reserve(keys.size());
|
||||
for (ProducerKey key : keys) {
|
||||
if (key.instance.laneCount != 1)
|
||||
return failure();
|
||||
sourceLanes.push_back(key.instance.laneStart);
|
||||
}
|
||||
|
||||
SmallVector<size_t, 1> resultIndices {run.resultIndex};
|
||||
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
Value init =
|
||||
tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult();
|
||||
|
||||
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
|
||||
Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(keys.size()));
|
||||
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
|
||||
|
||||
auto loop = buildNormalizedScfFor(
|
||||
state.rewriter,
|
||||
loc,
|
||||
lowerBound,
|
||||
upperBound,
|
||||
step,
|
||||
ValueRange {init},
|
||||
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value acc = iterArgs.front();
|
||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
|
||||
|
||||
FailureOr<SmallVector<Value, 4>> produced =
|
||||
cloneBatchBodyForLane(state,
|
||||
targetClass,
|
||||
keys.front().instance,
|
||||
sourceLane,
|
||||
resultIndices,
|
||||
CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex});
|
||||
if (failed(produced) || produced->size() != 1)
|
||||
return failure();
|
||||
|
||||
FailureOr<Value> firstOffset =
|
||||
scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc);
|
||||
if (failed(firstOffset))
|
||||
return failure();
|
||||
FailureOr<Value> next =
|
||||
createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset);
|
||||
if (failed(next))
|
||||
return failure();
|
||||
yielded.push_back(*next);
|
||||
return success();
|
||||
});
|
||||
if (failed(loop))
|
||||
return failure();
|
||||
run.packed = loop->results.front();
|
||||
return run.packed;
|
||||
}
|
||||
|
||||
if (failed(validatePackedScalarRunMetadata(targetClass.op, run)))
|
||||
return failure();
|
||||
@@ -2387,13 +1984,34 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
|
||||
if (failed(fullPackedType))
|
||||
return targetClass.op->emitError("cannot create lazy packed scalar run receive type");
|
||||
|
||||
auto packed = createReceiveConcatLoop(state, targetClass, *fullPackedType, run.fragmentType, run.messages, loc);
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
Value init =
|
||||
tensor::EmptyOp::create(state.rewriter, loc, fullPackedType->getShape(), fullPackedType->getElementType())
|
||||
.getResult();
|
||||
auto packed = emitIndexedFragmentInsertLoop(
|
||||
state,
|
||||
targetClass,
|
||||
init,
|
||||
static_cast<int64_t>(run.messages.size()),
|
||||
[&](Value index) -> FailureOr<Value> {
|
||||
Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, index, loc);
|
||||
Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, index, loc);
|
||||
Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, index, loc);
|
||||
return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId)
|
||||
.getOutput();
|
||||
},
|
||||
[&](Value index) -> FailureOr<Value> {
|
||||
return scaleIndexByDim0SizeInClass(state, targetClass, index, run.fragmentType.getDimSize(0), loc);
|
||||
},
|
||||
loc);
|
||||
if (failed(packed))
|
||||
return failure();
|
||||
run.packed = *packed;
|
||||
return run.packed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) {
|
||||
for (PackedScalarRunValue& run : packedScalarRuns) {
|
||||
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||
@@ -2491,6 +2109,8 @@ std::optional<Value> AvailableValueStore::lookup(MaterializerState& state, Produ
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
||||
SmallVector<APInt, 8> elements;
|
||||
elements.reserve(values.size());
|
||||
@@ -2983,89 +2603,6 @@ isProjectedOffsetValue(Value value, Value laneArg, ArrayRef<StaticProjectedLoopI
|
||||
|
||||
static std::optional<int64_t> getConstantIndex(OpFoldResult value);
|
||||
|
||||
static unsigned getProjectedFragmentsPerLogicalSlot(ArrayRef<int64_t> loopTripCounts) {
|
||||
unsigned fragmentsPerLogicalSlot = 1;
|
||||
for (int64_t tripCount : loopTripCounts) {
|
||||
assert(tripCount > 0 && "projected loop trip counts must be positive");
|
||||
fragmentsPerLogicalSlot *= static_cast<unsigned>(tripCount);
|
||||
}
|
||||
return fragmentsPerLogicalSlot;
|
||||
}
|
||||
|
||||
LogicalResult verifyProjectedFragmentLayout(Operation* anchor, const ProjectedFragmentLayout& layout) {
|
||||
if (!layout.fragmentType || layout.fragmentShape.empty())
|
||||
return anchor->emitError("projected fragment layout is missing fragment type metadata");
|
||||
if (layout.fragmentShape.size() != static_cast<size_t>(layout.fragmentType.getRank()))
|
||||
return anchor->emitError("projected fragment layout rank does not match fragment type");
|
||||
if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0)
|
||||
return anchor->emitError("projected fragment layout has an invalid fragment count");
|
||||
if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0)
|
||||
return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots");
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<RankedTensorType>
|
||||
getProjectedPayloadType(Operation* anchor, RankedTensorType fragmentType, unsigned payloadFragmentCount) {
|
||||
if (failed(
|
||||
verifyPackableFragmentType(anchor, fragmentType, payloadFragmentCount, "cannot create projected payload type")))
|
||||
return failure();
|
||||
return getPackedBatchTensorType(fragmentType, payloadFragmentCount);
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<int64_t, 16>, 4>
|
||||
buildProjectedFragmentOffsetsByDim(ArrayRef<SmallVector<int64_t, 4>> fragmentOffsets, size_t rank) {
|
||||
SmallVector<SmallVector<int64_t, 16>, 4> fragmentOffsetsByDim(rank);
|
||||
for (ArrayRef<int64_t> offsets : fragmentOffsets) {
|
||||
assert(offsets.size() == rank && "projected offset rank mismatch");
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
fragmentOffsetsByDim[dim].push_back(offsets[dim]);
|
||||
}
|
||||
return fragmentOffsetsByDim;
|
||||
}
|
||||
|
||||
LogicalResult verifyProjectedTransferDescriptor(Operation* anchor, const ProjectedTransferDescriptor& descriptor) {
|
||||
if (failed(verifyProjectedFragmentLayout(anchor, descriptor.layout)))
|
||||
return failure();
|
||||
if (!descriptor.payloadType)
|
||||
return anchor->emitError("projected transfer descriptor is missing payload type");
|
||||
if (descriptor.fragmentOffsets.empty())
|
||||
return anchor->emitError("projected transfer descriptor expected at least one fragment offset");
|
||||
if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size())
|
||||
return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
|
||||
for (ArrayRef<int64_t> dimOffsets : descriptor.fragmentOffsetsByDim)
|
||||
if (dimOffsets.size() != descriptor.fragmentOffsets.size())
|
||||
return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
|
||||
for (ArrayRef<int64_t> offsets : descriptor.fragmentOffsets)
|
||||
if (offsets.size() != descriptor.layout.fragmentShape.size())
|
||||
return anchor->emitError("projected transfer offset rank does not match fragment rank");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verifyProjectedSendDescriptor(Operation* anchor,
|
||||
const ProjectedTransferDescriptor& descriptor,
|
||||
const MessageVector& messages) {
|
||||
if (failed(verifyProjectedTransferDescriptor(anchor, descriptor)))
|
||||
return failure();
|
||||
if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size())
|
||||
return anchor->emitError("projected send descriptor metadata is inconsistent");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult finalizeProjectedTransferDescriptor(Operation* anchor, ProjectedTransferDescriptor& descriptor) {
|
||||
descriptor.fragmentOffsetsByDim =
|
||||
buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size());
|
||||
|
||||
FailureOr<RankedTensorType> payloadType =
|
||||
getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount);
|
||||
if (failed(payloadType))
|
||||
return failure();
|
||||
if (descriptor.payloadType && descriptor.payloadType != *payloadType)
|
||||
return anchor->emitError("projected transfer descriptor payload type does not match projected layout");
|
||||
descriptor.payloadType = *payloadType;
|
||||
|
||||
return verifyProjectedTransferDescriptor(anchor, descriptor);
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> evaluateProjectedOffsetValue(OpFoldResult value,
|
||||
Value laneArg,
|
||||
uint32_t lane,
|
||||
@@ -4819,12 +4356,12 @@ FailureOr<Value> materializeIndexedBatchRunReceive(MaterializerState& state,
|
||||
Value runSlotIndex,
|
||||
Location loc);
|
||||
|
||||
} // namespace
|
||||
|
||||
FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
PackedScalarRunValue& run,
|
||||
Location loc) {
|
||||
assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run");
|
||||
|
||||
SmallVector<ProducerKey, 16> keys = flattenPackedScalarRunKeys(run);
|
||||
if (keys.empty())
|
||||
return failure();
|
||||
@@ -4888,6 +4425,8 @@ FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState&
|
||||
return run.packed;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
ProducerKey key,
|
||||
@@ -5946,119 +5485,6 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
||||
if (state.pendingProjectedHostOutputFragments.empty())
|
||||
return success();
|
||||
|
||||
DenseMap<Value, SmallVector<PendingProjectedHostOutputFragment*, 16>> byOutput;
|
||||
for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments)
|
||||
byOutput[fragment.originalOutput].push_back(&fragment);
|
||||
|
||||
SmallVector<Value, 8> outputs;
|
||||
outputs.reserve(byOutput.size());
|
||||
|
||||
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
|
||||
if (!returnOp)
|
||||
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
|
||||
|
||||
DenseSet<Value> seenOutputs;
|
||||
for (Value returned : returnOp.getOperands()) {
|
||||
if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second)
|
||||
continue;
|
||||
outputs.push_back(returned);
|
||||
}
|
||||
if (outputs.size() != byOutput.size())
|
||||
return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs");
|
||||
|
||||
for (Value originalOutput : outputs) {
|
||||
if (isa_and_present<SpatScheduledCompute, SpatScheduledComputeBatch>(originalOutput.getDefiningOp())) {
|
||||
return state.func.emitError("projected host output assembly must be keyed by the original logical host output, "
|
||||
"not by a materialized scheduled result");
|
||||
}
|
||||
|
||||
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return state.func.emitError("projected host output must have static ranked tensor type");
|
||||
|
||||
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
|
||||
llvm::sort(fragments,
|
||||
[](const PendingProjectedHostOutputFragment* lhs, const PendingProjectedHostOutputFragment* rhs) {
|
||||
if (lhs->sourceClass != rhs->sourceClass)
|
||||
return lhs->sourceClass < rhs->sourceClass;
|
||||
if (lhs->publicationResultIndex != rhs->publicationResultIndex)
|
||||
return lhs->publicationResultIndex < rhs->publicationResultIndex;
|
||||
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
|
||||
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
|
||||
return std::lexicographical_compare(
|
||||
lhs->offsets.begin(), lhs->offsets.end(), rhs->offsets.begin(), rhs->offsets.end());
|
||||
});
|
||||
|
||||
state.rewriter.setInsertionPoint(returnOp);
|
||||
Location loc = fragments.front()->loc;
|
||||
SmallVector<Value, 16> blueprintOperands;
|
||||
SmallVector<int64_t, 16> fragmentOperandIndices;
|
||||
SmallVector<int64_t, 16> fragmentSourceOffsets;
|
||||
SmallVector<int64_t, 64> flatOffsets;
|
||||
SmallVector<int64_t, 64> flatSizes;
|
||||
SmallVector<int64_t, 64> flatStrides;
|
||||
DenseMap<Value, int64_t> operandIndicesByValue;
|
||||
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||
if (fragmentRecord->sourceClass >= state.classes.size())
|
||||
return state.func.emitError("projected host output fragment references an invalid source class");
|
||||
|
||||
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||
if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) {
|
||||
return sourceClass.op->emitError("projected host output fragment references an invalid publication result")
|
||||
<< " sourceClass=" << sourceClass.id << " resultIndex=" << fragmentRecord->publicationResultIndex
|
||||
<< " resultCount=" << sourceClass.op->getNumResults();
|
||||
}
|
||||
|
||||
Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex);
|
||||
|
||||
auto [operandIt, inserted] =
|
||||
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(blueprintOperands.size()));
|
||||
if (inserted)
|
||||
blueprintOperands.push_back(operand);
|
||||
fragmentOperandIndices.push_back(operandIt->second);
|
||||
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
|
||||
llvm::append_range(flatOffsets, fragmentRecord->offsets);
|
||||
llvm::append_range(flatSizes, fragmentRecord->sizes);
|
||||
llvm::append_range(flatStrides, fragmentRecord->strides);
|
||||
|
||||
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
|
||||
if (!operandType || !operandType.hasStaticShape())
|
||||
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
|
||||
}
|
||||
|
||||
if (blueprintOperands.empty())
|
||||
return state.func.emitError("missing projected host output fragments");
|
||||
|
||||
Value input = blueprintOperands.front();
|
||||
ValueRange extraFragments = ValueRange(blueprintOperands).drop_front();
|
||||
auto blueprint = spatial::SpatBlueprintOp::create(state.rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
input,
|
||||
extraFragments,
|
||||
state.rewriter.getStringAttr("nchw"),
|
||||
state.rewriter.getStringAttr("fragmented"),
|
||||
state.rewriter.getDenseI64ArrayAttr(flatOffsets),
|
||||
state.rewriter.getDenseI64ArrayAttr(flatSizes),
|
||||
state.rewriter.getStringAttr("identity"),
|
||||
state.rewriter.getStringAttr("fragment_assembly"),
|
||||
state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices),
|
||||
state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets),
|
||||
state.rewriter.getDenseI64ArrayAttr(flatStrides),
|
||||
state.rewriter.getStringAttr("disjoint"),
|
||||
state.rewriter.getStringAttr("complete"));
|
||||
|
||||
state.hostReplacements[originalOutput] = blueprint.getOutput();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<Value> resolveInputValue(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
Value input,
|
||||
@@ -8191,36 +7617,6 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<Value> createReceiveConcatLoop(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
RankedTensorType concatType,
|
||||
RankedTensorType fragmentType,
|
||||
const MessageVector& messages,
|
||||
Location loc) {
|
||||
assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent");
|
||||
assert(!messages.empty() && "expected at least one receive");
|
||||
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
Value init =
|
||||
tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult();
|
||||
return emitIndexedFragmentInsertLoop(
|
||||
state,
|
||||
targetClass,
|
||||
init,
|
||||
static_cast<int64_t>(messages.size()),
|
||||
[&](Value index) -> FailureOr<Value> {
|
||||
Value channelId = createIndexedChannelId(state, targetClass.op, messages, index, loc);
|
||||
Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, messages, index, loc);
|
||||
Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, messages, index, loc);
|
||||
return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId)
|
||||
.getOutput();
|
||||
},
|
||||
[&](Value index) -> FailureOr<Value> {
|
||||
return scaleIndexByDim0SizeInClass(state, targetClass, index, fragmentType.getDimSize(0), loc);
|
||||
},
|
||||
loc);
|
||||
}
|
||||
|
||||
bool valueMayEvaluateToCore(Value value, int64_t coreId) {
|
||||
if (std::optional<int64_t> constant = getConstantIndexValue(value))
|
||||
return *constant == coreId;
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "MergeMessages.hpp"
|
||||
#include "MergeScheduleKeys.hpp"
|
||||
#include "ProjectedFragments.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
struct MaterializedClass {
|
||||
ClassId id = 0;
|
||||
llvm::SmallVector<CpuId, 8> cpus;
|
||||
mlir::Operation* op = nullptr;
|
||||
mlir::Block* body = nullptr;
|
||||
bool isBatch = false;
|
||||
|
||||
llvm::DenseMap<CpuId, unsigned> cpuToLane;
|
||||
llvm::SmallVector<mlir::Value, 8> weights;
|
||||
llvm::SmallVector<mlir::Value, 8> inputs;
|
||||
llvm::SmallVector<mlir::Value, 4> hostOutputs;
|
||||
llvm::DenseMap<mlir::Value, unsigned> publicationOutputToResultIndex;
|
||||
llvm::DenseMap<mlir::Value, mlir::BlockArgument> weightArgs;
|
||||
llvm::DenseMap<mlir::Value, mlir::BlockArgument> inputArgs;
|
||||
llvm::DenseMap<mlir::Value, unsigned> hostOutputToResultIndex;
|
||||
};
|
||||
|
||||
struct PackedScalarRunSlot {
|
||||
llvm::SmallVector<ProducerKey, 8> keys;
|
||||
};
|
||||
|
||||
enum class PackedScalarRunKind {
|
||||
Materialized,
|
||||
DeferredReceive,
|
||||
DeferredLocalCompute
|
||||
};
|
||||
|
||||
struct PackedScalarRunValue {
|
||||
ClassId targetClass = 0;
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
PackedScalarRunKind kind = PackedScalarRunKind::Materialized;
|
||||
|
||||
mlir::Value packed;
|
||||
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<PackedScalarRunSlot, 8> slots;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
struct IndexedBatchRunValue {
|
||||
ClassId targetClass = 0;
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
mlir::Value packed;
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<PackedScalarRunSlot, 8> slots;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
struct LogicalSlotRange {
|
||||
SlotId start = 0;
|
||||
SlotId count = 0;
|
||||
};
|
||||
|
||||
struct MaterializationRunSlot {
|
||||
llvm::SmallVector<ComputeInstance, 8> peers;
|
||||
};
|
||||
|
||||
using MaterializationRun = llvm::SmallVector<MaterializationRunSlot, 8>;
|
||||
|
||||
struct OutputDestinationGroup {
|
||||
llvm::SmallVector<size_t, 4> resultIndices;
|
||||
llvm::SmallVector<ClassId, 4> destinationClasses;
|
||||
};
|
||||
|
||||
struct BatchRunSendPlan {
|
||||
size_t resultIndex = 0;
|
||||
ClassId destinationClass = 0;
|
||||
MessageVector messages;
|
||||
};
|
||||
|
||||
enum class TensorDemandActionKind {
|
||||
DestinationFanout,
|
||||
SameClassIndexedFragment,
|
||||
TerminalBlueprintPublication,
|
||||
WholeTensorBarrier
|
||||
};
|
||||
|
||||
enum class WholeTensorBarrierReason {
|
||||
FunctionReturnWithoutBlueprint,
|
||||
DenseLogicalConsumer
|
||||
};
|
||||
|
||||
struct TensorDemandAction {
|
||||
TensorDemandActionKind kind = TensorDemandActionKind::DestinationFanout;
|
||||
std::optional<ClassId> destinationClass;
|
||||
std::optional<WholeTensorBarrierReason> barrierReason;
|
||||
};
|
||||
|
||||
struct RunOutputDemand {
|
||||
size_t resultIndex = 0;
|
||||
mlir::Value originalOutput;
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<TensorDemandAction, 4> actions;
|
||||
};
|
||||
|
||||
struct CompactRunPlan {
|
||||
llvm::SmallVector<RunOutputDemand, 4> outputs;
|
||||
};
|
||||
|
||||
enum class BatchInputDemandKind {
|
||||
LaneFragment,
|
||||
ProjectedFragment,
|
||||
WholeTensorBarrier
|
||||
};
|
||||
|
||||
struct BatchInputDemand {
|
||||
BatchInputDemandKind kind = BatchInputDemandKind::LaneFragment;
|
||||
std::optional<ProducerKey> wholeTensorProducer;
|
||||
};
|
||||
|
||||
struct CloneIndexingContext {
|
||||
std::optional<mlir::Value> runSlotIndex;
|
||||
std::optional<mlir::Value> projectionSlotIndex;
|
||||
};
|
||||
|
||||
struct MaterializerState;
|
||||
|
||||
class AvailableValueStore {
|
||||
public:
|
||||
struct ExactBatchFragmentRecord {
|
||||
ProducerKey key;
|
||||
mlir::Value value;
|
||||
};
|
||||
|
||||
void record(ProducerKey key, ClassId classId, mlir::Value value) {
|
||||
exactValues[key][classId] = value;
|
||||
|
||||
auto batch = mlir::dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||
if (!batch || key.instance.laneCount == 0)
|
||||
return;
|
||||
|
||||
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
|
||||
llvm::SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
|
||||
for (ExactBatchFragmentRecord& record : bucket) {
|
||||
if (!(record.key == key))
|
||||
continue;
|
||||
record.value = value;
|
||||
return;
|
||||
}
|
||||
bucket.push_back({key, value});
|
||||
}
|
||||
|
||||
void recordPackedRun(PackedScalarRunValue run) {
|
||||
size_t runIndex = packedScalarRuns.size();
|
||||
packedScalarRuns.push_back(std::move(run));
|
||||
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
|
||||
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
|
||||
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
|
||||
}
|
||||
|
||||
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
|
||||
|
||||
std::optional<mlir::Value> lookupExact(ProducerKey key, ClassId classId) const;
|
||||
std::optional<mlir::Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
|
||||
|
||||
llvm::ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = packedRunsByProducerResultClass.find(key);
|
||||
if (it == packedRunsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
llvm::ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = exactBatchFragmentsByProducerResultClass.find(key);
|
||||
if (it == exactBatchFragmentsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
|
||||
|
||||
private:
|
||||
std::optional<mlir::Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
|
||||
llvm::DenseMap<ProducerKey, llvm::DenseMap<ClassId, mlir::Value>, ProducerKeyInfo> exactValues;
|
||||
llvm::SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
|
||||
llvm::SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
|
||||
llvm::DenseMap<WholeBatchAssemblyLookupKey,
|
||||
llvm::SmallVector<ExactBatchFragmentRecord, 16>,
|
||||
WholeBatchAssemblyLookupKeyInfo>
|
||||
exactBatchFragmentsByProducerResultClass;
|
||||
llvm::DenseMap<WholeBatchAssemblyLookupKey, llvm::SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||
packedRunsByProducerResultClass;
|
||||
};
|
||||
|
||||
struct MaterializerState {
|
||||
mlir::func::FuncOp func;
|
||||
const MergeScheduleResult& schedule;
|
||||
mlir::IRRewriter rewriter;
|
||||
mlir::OperationFolder constantFolder;
|
||||
int64_t& nextChannelId;
|
||||
llvm::SmallVector<MaterializedClass, 8> classes;
|
||||
llvm::DenseMap<CpuId, ClassId> cpuToClass;
|
||||
llvm::DenseMap<CpuId, llvm::SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
|
||||
llvm::DenseMap<ComputeInstance, LogicalSlotRange> scheduledInstanceToLogicalSlots;
|
||||
llvm::DenseMap<ComputeInstance, ComputeInstance> logicalInstanceToScheduledChunk;
|
||||
llvm::DenseSet<ClassSlotKey> materializedLogicalSlots;
|
||||
|
||||
llvm::DenseMap<ProducerKey, llvm::SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||
llvm::DenseMap<SameClassConsumerLookupKey, llvm::SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
|
||||
sameClassConsumerIndex;
|
||||
llvm::DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo>
|
||||
projectedInputMatches;
|
||||
llvm::DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
|
||||
llvm::DenseMap<mlir::Value, bool> liveExternalUseCache;
|
||||
llvm::DenseMap<mlir::Operation*, llvm::SmallVector<mlir::Type, 4>> batchOutputFragmentTypesCache;
|
||||
llvm::DenseMap<ComputeInstance, llvm::SmallVector<mlir::Value, 4>, llvm::DenseMapInfo<ComputeInstance>>
|
||||
computeInstanceOutputsCache;
|
||||
llvm::DenseMap<ProducerKey, llvm::DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo>
|
||||
projectedTransfers;
|
||||
llvm::DenseMap<mlir::Operation*, llvm::DenseMap<ClassId, ProjectedExtractReplacement>>
|
||||
projectedExtractReplacements;
|
||||
AvailableValueStore availableValues;
|
||||
llvm::DenseMap<mlir::Value, mlir::Value> hostReplacements;
|
||||
llvm::DenseMap<mlir::Value, ClassId> hostOutputOwners;
|
||||
llvm::SmallVector<PendingProjectedHostOutputFragment, 32> pendingProjectedHostOutputFragments;
|
||||
llvm::DenseSet<mlir::Operation*> oldComputeOps;
|
||||
|
||||
MaterializerState(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
|
||||
: func(func),
|
||||
schedule(schedule),
|
||||
rewriter(func.getContext()),
|
||||
constantFolder(func.getContext()),
|
||||
nextChannelId(nextChannelId) {}
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -28,6 +28,7 @@
|
||||
#include "Scheduling/ComputeGraph.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
@@ -43,16 +44,6 @@ using namespace onnx_mlir::compact_asm;
|
||||
using SpatCompute = spatial::SpatGraphCompute;
|
||||
using SpatComputeBatch = spatial::SpatGraphComputeBatch;
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
||||
if (failed(checkedCoreId))
|
||||
return std::nullopt;
|
||||
return *checkedCoreId;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
||||
if (!compute->hasOneUse())
|
||||
return false;
|
||||
@@ -213,8 +204,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreId = getComputeCoreId(spatCompute))
|
||||
coreIds.push_back(*coreId);
|
||||
auto coreId = getOptionalScheduledCoreId(spatCompute, "merge compute core id");
|
||||
if (failed(coreId))
|
||||
return;
|
||||
if (*coreId)
|
||||
coreIds.push_back(**coreId);
|
||||
uint64_t computeId = totalComputeOps++;
|
||||
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||
uint64_t maxConcatOperands = 0;
|
||||
@@ -234,8 +228,11 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
|
||||
auto optionalCoreIds = getOptionalScheduledBatchCoreIds(batch, "merge compute_batch core id");
|
||||
if (failed(optionalCoreIds))
|
||||
return;
|
||||
if (*optionalCoreIds)
|
||||
coreIds = std::move(**optionalCoreIds);
|
||||
collectedData.push_back(
|
||||
{nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
|
||||
totalComputeOps += 1;
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
using CpuId = size_t;
|
||||
|
||||
inline mlir::FailureOr<int32_t> getCheckedCoreId(mlir::Operation* anchor, CpuId cpu, llvm::StringRef fieldName) {
|
||||
return pim::checkedI32(static_cast<uint64_t>(cpu), anchor, fieldName);
|
||||
}
|
||||
|
||||
inline mlir::FailureOr<llvm::SmallVector<int32_t, 8>>
|
||||
getCheckedCoreIds(mlir::Operation* anchor, llvm::ArrayRef<CpuId> cpus, llvm::StringRef fieldName) {
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
coreIds.reserve(cpus.size());
|
||||
for (CpuId cpu : cpus) {
|
||||
auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName);
|
||||
if (mlir::failed(checkedCoreId))
|
||||
return mlir::failure();
|
||||
coreIds.push_back(*checkedCoreId);
|
||||
}
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
struct MessageVector {
|
||||
llvm::SmallVector<int64_t, 16> channelIds;
|
||||
llvm::SmallVector<int32_t, 16> sourceCoreIds;
|
||||
llvm::SmallVector<int32_t, 16> targetCoreIds;
|
||||
|
||||
size_t size() const { return channelIds.size(); }
|
||||
bool empty() const { return channelIds.empty(); }
|
||||
|
||||
mlir::LogicalResult verify(mlir::Operation* anchor) const {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return anchor->emitError("message metadata is inconsistent");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) {
|
||||
channelIds.push_back(channelId);
|
||||
sourceCoreIds.push_back(sourceCoreId);
|
||||
targetCoreIds.push_back(targetCoreId);
|
||||
}
|
||||
|
||||
void append(llvm::ArrayRef<int64_t> channels, llvm::ArrayRef<int32_t> sources, llvm::ArrayRef<int32_t> targets) {
|
||||
assert(channels.size() == sources.size() && "channel/source count mismatch");
|
||||
assert(channels.size() == targets.size() && "channel/target count mismatch");
|
||||
llvm::append_range(channelIds, channels);
|
||||
llvm::append_range(sourceCoreIds, sources);
|
||||
llvm::append_range(targetCoreIds, targets);
|
||||
}
|
||||
|
||||
MessageVector slice(size_t offset, size_t count) const {
|
||||
MessageVector result;
|
||||
result.append(llvm::ArrayRef<int64_t>(channelIds).slice(offset, count),
|
||||
llvm::ArrayRef<int32_t>(sourceCoreIds).slice(offset, count),
|
||||
llvm::ArrayRef<int32_t>(targetCoreIds).slice(offset, count));
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,134 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
using ClassId = size_t;
|
||||
using SlotId = size_t;
|
||||
|
||||
struct ProducerKey {
|
||||
ComputeInstance instance;
|
||||
size_t resultIndex = 0;
|
||||
|
||||
bool operator==(const ProducerKey& other) const {
|
||||
return instance == other.instance && resultIndex == other.resultIndex;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProducerKeyInfo {
|
||||
static ProducerKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<ComputeInstance>::getEmptyKey(), std::numeric_limits<size_t>::max()};
|
||||
}
|
||||
|
||||
static ProducerKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<ComputeInstance>::getTombstoneKey(), std::numeric_limits<size_t>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ProducerKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<ComputeInstance>::getHashValue(key.instance), key.resultIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKey {
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const SameClassConsumerLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKeyInfo {
|
||||
static SameClassConsumerLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static SameClassConsumerLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<mlir::Operation*>::getHashValue(key.sourceOp),
|
||||
key.resultIndex,
|
||||
key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKey {
|
||||
mlir::Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const WholeBatchAssemblyLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKeyInfo {
|
||||
static WholeBatchAssemblyLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static WholeBatchAssemblyLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<mlir::Operation*>::getHashValue(key.sourceOp),
|
||||
key.resultIndex,
|
||||
key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
using ClassSlotKey = std::pair<ClassId, SlotId>;
|
||||
|
||||
struct ProjectedBatchInputKey {
|
||||
mlir::Operation* consumerOp = nullptr;
|
||||
unsigned inputIndex = 0;
|
||||
|
||||
bool operator==(const ProjectedBatchInputKey& other) const {
|
||||
return consumerOp == other.consumerOp && inputIndex == other.inputIndex;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProjectedBatchInputKeyInfo {
|
||||
static ProjectedBatchInputKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getEmptyKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static ProjectedBatchInputKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<mlir::Operation*>::getTombstoneKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ProjectedBatchInputKey& key) {
|
||||
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,104 @@
|
||||
#include "ProjectedFragments.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
static mlir::FailureOr<mlir::RankedTensorType> getPackedBatchTensorType(mlir::Type laneType, size_t laneCount) {
|
||||
auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(laneType);
|
||||
if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t, 4> shape(tensorType.getShape());
|
||||
shape[0] *= static_cast<int64_t>(laneCount);
|
||||
return mlir::RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding());
|
||||
}
|
||||
|
||||
unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef<int64_t> loopTripCounts) {
|
||||
unsigned fragmentsPerLogicalSlot = 1;
|
||||
for (int64_t tripCount : loopTripCounts) {
|
||||
assert(tripCount > 0 && "projected loop trip counts must be positive");
|
||||
fragmentsPerLogicalSlot *= static_cast<unsigned>(tripCount);
|
||||
}
|
||||
return fragmentsPerLogicalSlot;
|
||||
}
|
||||
|
||||
mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout) {
|
||||
if (!layout.fragmentType || layout.fragmentShape.empty())
|
||||
return anchor->emitError("projected fragment layout is missing fragment type metadata");
|
||||
if (layout.fragmentShape.size() != static_cast<size_t>(layout.fragmentType.getRank()))
|
||||
return anchor->emitError("projected fragment layout rank does not match fragment type");
|
||||
if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0)
|
||||
return anchor->emitError("projected fragment layout has an invalid fragment count");
|
||||
if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0)
|
||||
return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::FailureOr<mlir::RankedTensorType>
|
||||
getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount) {
|
||||
auto packedType = getPackedBatchTensorType(fragmentType, payloadFragmentCount);
|
||||
if (mlir::failed(packedType)) {
|
||||
anchor->emitError("cannot create projected payload type");
|
||||
return mlir::failure();
|
||||
}
|
||||
return *packedType;
|
||||
}
|
||||
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4>
|
||||
buildProjectedFragmentOffsetsByDim(llvm::ArrayRef<llvm::SmallVector<int64_t, 4>> fragmentOffsets, size_t rank) {
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4> fragmentOffsetsByDim(rank);
|
||||
for (llvm::ArrayRef<int64_t> offsets : fragmentOffsets) {
|
||||
assert(offsets.size() == rank && "projected offset rank mismatch");
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
fragmentOffsetsByDim[dim].push_back(offsets[dim]);
|
||||
}
|
||||
return fragmentOffsetsByDim;
|
||||
}
|
||||
|
||||
mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor,
|
||||
const ProjectedTransferDescriptor& descriptor) {
|
||||
if (mlir::failed(verifyProjectedFragmentLayout(anchor, descriptor.layout)))
|
||||
return mlir::failure();
|
||||
if (!descriptor.payloadType)
|
||||
return anchor->emitError("projected transfer descriptor is missing payload type");
|
||||
if (descriptor.fragmentOffsets.empty())
|
||||
return anchor->emitError("projected transfer descriptor expected at least one fragment offset");
|
||||
if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size())
|
||||
return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
|
||||
for (llvm::ArrayRef<int64_t> dimOffsets : descriptor.fragmentOffsetsByDim)
|
||||
if (dimOffsets.size() != descriptor.fragmentOffsets.size())
|
||||
return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
|
||||
for (llvm::ArrayRef<int64_t> offsets : descriptor.fragmentOffsets)
|
||||
if (offsets.size() != descriptor.layout.fragmentShape.size())
|
||||
return anchor->emitError("projected transfer offset rank does not match fragment rank");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor,
|
||||
const ProjectedTransferDescriptor& descriptor,
|
||||
const MessageVector& messages) {
|
||||
if (mlir::failed(verifyProjectedTransferDescriptor(anchor, descriptor)))
|
||||
return mlir::failure();
|
||||
if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size())
|
||||
return anchor->emitError("projected send descriptor metadata is inconsistent");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor,
|
||||
ProjectedTransferDescriptor& descriptor) {
|
||||
descriptor.fragmentOffsetsByDim =
|
||||
buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size());
|
||||
|
||||
auto payloadType =
|
||||
getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount);
|
||||
if (mlir::failed(payloadType))
|
||||
return mlir::failure();
|
||||
if (descriptor.payloadType && descriptor.payloadType != *payloadType)
|
||||
return anchor->emitError("projected transfer descriptor payload type does not match projected layout");
|
||||
descriptor.payloadType = *payloadType;
|
||||
|
||||
return verifyProjectedTransferDescriptor(anchor, descriptor);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -0,0 +1,87 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "MergeMessages.hpp"
|
||||
#include "MergeScheduleKeys.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
struct ProjectedFragmentLayout {
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<int64_t, 4> fragmentShape;
|
||||
unsigned fragmentsPerLogicalSlot = 1;
|
||||
unsigned payloadFragmentCount = 1;
|
||||
llvm::SmallVector<int64_t, 4> loopLowerBounds;
|
||||
llvm::SmallVector<int64_t, 4> loopSteps;
|
||||
llvm::SmallVector<int64_t, 4> loopTripCounts;
|
||||
};
|
||||
|
||||
struct StaticProjectedLoopInfo {
|
||||
mlir::BlockArgument iv;
|
||||
int64_t lowerBound = 0;
|
||||
int64_t step = 1;
|
||||
int64_t tripCount = 1;
|
||||
};
|
||||
|
||||
struct ProjectedTransferDescriptor {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
mlir::Operation* extractOp = nullptr;
|
||||
ProjectedFragmentLayout layout;
|
||||
mlir::RankedTensorType payloadType;
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 16> fragmentOffsets;
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4> fragmentOffsetsByDim;
|
||||
};
|
||||
|
||||
struct ProjectedExtractReplacement {
|
||||
mlir::Value payload;
|
||||
ProjectedFragmentLayout layout;
|
||||
};
|
||||
|
||||
struct PendingProjectedHostOutputFragment {
|
||||
mlir::Value originalOutput;
|
||||
ClassId sourceClass = 0;
|
||||
ProducerKey producerKey;
|
||||
unsigned publicationResultIndex = 0;
|
||||
int64_t sourceFragmentOrdinal = 0;
|
||||
int64_t sourceElementOffset = 0;
|
||||
llvm::SmallVector<int64_t, 4> offsets;
|
||||
llvm::SmallVector<int64_t, 4> sizes;
|
||||
llvm::SmallVector<int64_t, 4> strides;
|
||||
uint32_t sourceLane = 0;
|
||||
mlir::Location loc;
|
||||
};
|
||||
|
||||
struct AffineProjectedInputSliceMatch {
|
||||
mlir::tensor::ExtractSliceOp extract;
|
||||
mlir::RankedTensorType sourceType;
|
||||
mlir::RankedTensorType fragmentType;
|
||||
llvm::SmallVector<int64_t, 4> fragmentShape;
|
||||
llvm::SmallVector<mlir::OpFoldResult, 4> offsets;
|
||||
llvm::SmallVector<StaticProjectedLoopInfo, 4> loops;
|
||||
};
|
||||
|
||||
unsigned getProjectedFragmentsPerLogicalSlot(llvm::ArrayRef<int64_t> loopTripCounts);
|
||||
mlir::LogicalResult verifyProjectedFragmentLayout(mlir::Operation* anchor, const ProjectedFragmentLayout& layout);
|
||||
mlir::FailureOr<mlir::RankedTensorType>
|
||||
getProjectedPayloadType(mlir::Operation* anchor, mlir::RankedTensorType fragmentType, unsigned payloadFragmentCount);
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t, 16>, 4>
|
||||
buildProjectedFragmentOffsetsByDim(llvm::ArrayRef<llvm::SmallVector<int64_t, 4>> fragmentOffsets, size_t rank);
|
||||
mlir::LogicalResult verifyProjectedTransferDescriptor(mlir::Operation* anchor,
|
||||
const ProjectedTransferDescriptor& descriptor);
|
||||
mlir::LogicalResult verifyProjectedSendDescriptor(mlir::Operation* anchor,
|
||||
const ProjectedTransferDescriptor& descriptor,
|
||||
const MessageVector& messages);
|
||||
mlir::LogicalResult finalizeProjectedTransferDescriptor(mlir::Operation* anchor,
|
||||
ProjectedTransferDescriptor& descriptor);
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -12,7 +12,6 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using CPU = int;
|
||||
|
||||
@@ -23,13 +23,6 @@ function(add_pim_unittest test_name)
|
||||
set_tests_properties(${test_name} PROPERTIES LABELS pim-unittest)
|
||||
endfunction()
|
||||
|
||||
add_pim_unittest(LabeledListTest
|
||||
LabeledListTest.cpp
|
||||
|
||||
LINK_LIBS PRIVATE
|
||||
OMPimCommon
|
||||
)
|
||||
|
||||
add_pim_unittest(PimMemoryLivenessPlannerTest
|
||||
PimMemoryLivenessPlannerTest.cpp
|
||||
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||
|
||||
using onnx_mlir::LabeledList;
|
||||
using onnx_mlir::LabeledListNode;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestNode : public LabeledListNode<TestNode> {
|
||||
explicit TestNode(int id)
|
||||
: id(id) {}
|
||||
|
||||
int id;
|
||||
};
|
||||
|
||||
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
|
||||
auto expectedIt = expectedOrder.begin();
|
||||
for (auto& node : list) {
|
||||
assert(expectedIt != expectedOrder.end());
|
||||
assert(node.id == *expectedIt);
|
||||
++expectedIt;
|
||||
}
|
||||
assert(expectedIt == expectedOrder.end());
|
||||
}
|
||||
|
||||
void assertStrictlyIncreasingLabels(LabeledList<TestNode>& list) {
|
||||
auto it = list.begin();
|
||||
if (it == list.end())
|
||||
return;
|
||||
|
||||
auto previousLabel = it->getOrderLabel();
|
||||
++it;
|
||||
for (; it != list.end(); ++it) {
|
||||
assert(previousLabel < it->getOrderLabel());
|
||||
previousLabel = it->getOrderLabel();
|
||||
}
|
||||
}
|
||||
|
||||
int testLabeledListBasicMutation() {
|
||||
std::cout << "testLabeledListBasicMutation:" << std::endl;
|
||||
|
||||
LabeledList<TestNode> list;
|
||||
TestNode n1(1);
|
||||
TestNode n2(2);
|
||||
TestNode n3(3);
|
||||
TestNode n4(4);
|
||||
TestNode n5(5);
|
||||
|
||||
assert(list.empty());
|
||||
assert(list.front() == nullptr);
|
||||
assert(list.back() == nullptr);
|
||||
assert(!list.contains(&n1));
|
||||
assert(LabeledList<TestNode>::previous(&n1) == nullptr);
|
||||
assert(LabeledList<TestNode>::next(&n1) == nullptr);
|
||||
|
||||
list.pushBack(&n1);
|
||||
list.pushBack(&n3);
|
||||
list.insertAfter(&n1, &n2);
|
||||
list.pushFront(&n4);
|
||||
list.insertBefore(nullptr, &n5);
|
||||
|
||||
assert(list.size() == 5);
|
||||
assert(list.front() == &n4);
|
||||
assert(list.back() == &n5);
|
||||
assert(list.contains(&n2));
|
||||
assertOrder(list, {4, 1, 2, 3, 5});
|
||||
assert(LabeledList<TestNode>::next(&n4) == &n1);
|
||||
assert(LabeledList<TestNode>::previous(&n1) == &n4);
|
||||
assert(LabeledList<TestNode>::next(&n5) == nullptr);
|
||||
assert(list.comesBefore(&n1, &n3));
|
||||
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
|
||||
|
||||
list.moveBefore(&n5, &n2);
|
||||
assertOrder(list, {4, 1, 5, 2, 3});
|
||||
|
||||
list.moveAfter(&n4, &n3);
|
||||
assertOrder(list, {1, 5, 2, 3, 4});
|
||||
|
||||
list.remove(&n2);
|
||||
assert(!n2.isLinked());
|
||||
assert(!list.contains(&n2));
|
||||
assertOrder(list, {1, 5, 3, 4});
|
||||
|
||||
list.clear();
|
||||
assert(list.empty());
|
||||
assert(list.size() == 0);
|
||||
assert(list.front() == nullptr);
|
||||
assert(list.back() == nullptr);
|
||||
assert(!n1.isLinked());
|
||||
assert(!n3.isLinked());
|
||||
assert(!n4.isLinked());
|
||||
assert(!n5.isLinked());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int testLabeledListRelabelingAndNoopMoves() {
|
||||
std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl;
|
||||
|
||||
constexpr int insertedNodeCount = 80;
|
||||
LabeledList<TestNode> list;
|
||||
TestNode head(0);
|
||||
TestNode tail(999);
|
||||
std::vector<TestNode> insertedNodes;
|
||||
insertedNodes.reserve(insertedNodeCount);
|
||||
for (int i = 0; i < insertedNodeCount; ++i)
|
||||
insertedNodes.emplace_back(i + 1);
|
||||
|
||||
list.pushBack(&head);
|
||||
list.pushBack(&tail);
|
||||
for (auto& node : insertedNodes)
|
||||
list.insertAfter(&head, &node);
|
||||
|
||||
assert(list.size() == insertedNodeCount + 2);
|
||||
assert(list.front() == &head);
|
||||
assert(list.back() == &tail);
|
||||
assert(LabeledList<TestNode>::previous(&head) == nullptr);
|
||||
assert(LabeledList<TestNode>::next(&tail) == nullptr);
|
||||
assertStrictlyIncreasingLabels(list);
|
||||
|
||||
auto* firstInserted = LabeledList<TestNode>::next(&head);
|
||||
auto* secondInserted = LabeledList<TestNode>::next(firstInserted);
|
||||
list.moveBefore(firstInserted, secondInserted);
|
||||
list.moveAfter(&head, nullptr);
|
||||
list.moveAfter(&tail, LabeledList<TestNode>::previous(&tail));
|
||||
|
||||
assert(list.front() == &head);
|
||||
assert(list.back() == &tail);
|
||||
assert(firstInserted == &insertedNodes.back());
|
||||
assert(secondInserted == &insertedNodes[insertedNodeCount - 2]);
|
||||
assertStrictlyIncreasingLabels(list);
|
||||
|
||||
int expectedId = insertedNodeCount;
|
||||
auto it = std::next(list.begin());
|
||||
for (; it != list.end() && &*it != &tail; ++it, --expectedId)
|
||||
assert(it->id == expectedId);
|
||||
assert(expectedId == 0);
|
||||
list.clear();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
(void) argc;
|
||||
(void) argv;
|
||||
|
||||
int failures = 0;
|
||||
failures += testLabeledListBasicMutation();
|
||||
failures += testLabeledListRelabelingAndNoopMoves();
|
||||
if (failures != 0) {
|
||||
std::cerr << failures << " test failures\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
Reference in New Issue
Block a user