190 lines
8.1 KiB
C++
190 lines
8.1 KiB
C++
#include "mlir/IR/ValueRange.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
|
|
#include <cassert>
|
|
|
|
#include "Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
|
|
|
using namespace llvm;
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Builds an attribute holding the byte size of the shaped value `value`.
///
/// The type of `value` is cast to `ShapedType` unconditionally, so passing a
/// non-shaped value trips the cast's assertion. The size computation and the
/// attribute creation are both delegated to the checked `pim::` helpers, which
/// receive `anchor` (presumably for diagnostics - confirm against
/// CheckedArithmetic.hpp).
///
/// \returns the attribute produced by `pim::getCheckedI32Attr`, or failure if
///          the byte size could not be computed.
FailureOr<IntegerAttr> getTensorSizeInBytesAttr(Builder& builder, Operation* anchor, mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  auto sizeInBytes = pim::getCheckedShapedTypeSizeInBytes(shapedType, anchor, "tensor byte size");
  if (succeeded(sizeInBytes))
    return pim::getCheckedI32Attr(builder, anchor, *sizeInBytes, "tensor byte size");
  return failure();
}
|
|
|
|
/// Returns the user of `value` that comes first in block order.
///
/// `value` must have at least one user (asserted). Users are compared with
/// `Operation::isBeforeInBlock`, so they are expected to live in one block.
Operation* getEarliestUserWithinBlock(mlir::Value value) {
  auto userRange = value.getUsers();

  assert(!userRange.empty());

  Operation* first = *userRange.begin();
  for (Operation* candidate : userRange) {
    if (!candidate->isBeforeInBlock(first))
      continue;
    first = candidate;
  }

  return first;
}
|
|
|
|
/// Returns the operands of `operation` ordered by ascending number of uses.
///
/// Each operand is paired with the length of its use list, the pairs are
/// sorted on that count, and the operand values are returned in sorted order.
/// A value that appears as several operands appears several times in the
/// result.
///
/// The sort is stable: operands with the same use count keep their original
/// operand order, so the result is deterministic across runs and platforms.
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
  auto operandsAndUses =
      map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
        // Counting uses walks the whole use list of the operand.
        return {operand, std::distance(operand.use_begin(), operand.use_end())};
      });
  // An unstable sort would leave the order of equally-used operands
  // unspecified, which makes downstream choices non-reproducible.
  llvm::stable_sort(operandsAndUses, [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });
  return map_to_vector(operandsAndUses, [](const auto& operandAndUse) { return operandAndUse.first; });
}
|
|
|
|
/// Reports whether `value` may still be needed after `operation`.
///
/// Returns true as soon as a user of `value` is found that either lives in a
/// different block than `operation` or is positioned after `operation` in the
/// same block. Returns false when no such user exists (including when `value`
/// has no users at all).
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
  return llvm::any_of(value.getUsers(), [operation](Operation* user) {
    // A user in another block is treated as "later" without further checks.
    return user->getBlock() != operation->getBlock() || operation->isBeforeInBlock(user);
  });
}
|
|
|
|
/// Picks a tensor that can hold the single result of `operation`.
///
/// The operands are visited in the order produced by
/// `getOpOperandsSortedByUses`. The first operand whose type equals the result
/// type and for which `hasLaterUserInBlock(operand, operation)` is false is
/// returned. If no operand qualifies, the rewriter's insertion point is set to
/// `operation` (it is not restored afterwards) and a fresh `tensor.empty` with
/// the result's shape and element type is created and returned.
///
/// `operation` must have exactly one result and that result must be a
/// `ShapedType`; both conditions are asserted.
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
  assert("Only support operations with a single result" && operation->getNumResults() == 1);
  mlir::Value result = operation->getResult(0);
  auto resultType = result.getType();
  assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));

  SmallVector<mlir::Value> candidates = getOpOperandsSortedByUses(operation);
  for (mlir::Value candidate : candidates) {
    if (candidate.getType() != resultType)
      continue;
    if (hasLaterUserInBlock(candidate, operation))
      continue;
    return candidate;
  }

  // No operand can be reused: materialize an empty tensor right before the op.
  auto shapedResultType = cast<ShapedType>(resultType);
  rewriter.setInsertionPoint(operation);
  return tensor::EmptyOp::create(
      rewriter, operation->getLoc(), shapedResultType.getShape(), shapedResultType.getElementType());
}
|
|
|
|
/// Checks the structural consistency of fragment assembly metadata.
///
/// The following conditions are verified in order, and the first violation is
/// reported as an op error on `reconciliator`:
///   1. `operandIndices` and `sourceOffsets` have the same length.
///   2. `flatOffsets` and `flatSizes` have the same length.
///   3. `flatStrides` and `flatOffsets` have the same length.
///   4. `flatOffsets` holds `operandIndices.size() * resultRank` entries.
///   5. Every operand index lies in `[0, operandCount)` and every source
///      offset is nonnegative.
///
/// \returns success if all checks pass, otherwise the emitted diagnostic.
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
                                               int64_t resultRank,
                                               size_t operandCount,
                                               ArrayRef<int64_t> operandIndices,
                                               ArrayRef<int64_t> sourceOffsets,
                                               ArrayRef<int64_t> flatOffsets,
                                               ArrayRef<int64_t> flatSizes,
                                               ArrayRef<int64_t> flatStrides) {
  const size_t fragmentCount = operandIndices.size();
  const size_t flatEntryCount = flatOffsets.size();

  if (fragmentCount != sourceOffsets.size())
    return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
  if (flatEntryCount != flatSizes.size())
    return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
  if (flatStrides.size() != flatEntryCount)
    return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
  if (flatEntryCount != fragmentCount * static_cast<size_t>(resultRank))
    return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");

  const int64_t operandLimit = static_cast<int64_t>(operandCount);
  for (size_t fragment = 0; fragment < fragmentCount; ++fragment) {
    const int64_t operandIndex = operandIndices[fragment];
    const bool indexInRange = operandIndex >= 0 && operandIndex < operandLimit;
    if (!indexInRange)
      return reconciliator.emitOpError("fragment assembly operand index is out of range");
    // Indexing is safe: both arrays were verified to have equal length above.
    if (sourceOffsets[fragment] < 0)
      return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
  }

  return success();
}
|
|
|
|
/// Converts a flat row-major element index into per-dimension indices.
///
/// Dimensions are processed from last to first: each coordinate is the
/// remainder of the running index by that dimension's extent, and the running
/// index is then divided by the extent. Any part of `flatIndex` that exceeds
/// the total element count is silently dropped after the first dimension.
///
/// A dimension with extent 0 gets coordinate 0 and leaves the running index
/// untouched, which avoids the division by zero such an extent would
/// otherwise cause.
static SmallVector<int64_t, 4> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
  SmallVector<int64_t, 4> indices(shape.size(), 0);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
    const int64_t extent = shape[dim];
    if (extent == 0)
      continue; // `% 0` and `/ 0` are undefined; the coordinate stays 0.
    indices[dim] = flatIndex % extent;
    flatIndex /= extent;
  }
  return indices;
}
|
|
|
|
/// Translates a flat source element offset into per-dimension slice offsets
/// for a fragment of shape `fragmentShape` taken from `sourceType`.
///
/// The following conditions are checked in order; the first violation emits an
/// op error on `anchor`, prefixed with `fieldName`, and returns failure:
///   - `sourceType` has a static shape;
///   - `sourceElementOffset` is nonnegative;
///   - the fragment rank equals the source rank;
///   - every fragment extent is nonnegative;
///   - `sourceElementOffset` plus the fragment element count does not exceed
///     the source element count;
///   - in every dimension, the expanded offset plus the fragment extent stays
///     within the source extent.
///
/// All bound checks are written so that they cannot overflow `int64_t`, even
/// for very large fragment extents.
///
/// \returns the row-major expansion of `sourceElementOffset` over the source
///          shape, or failure after emitting a diagnostic.
FailureOr<SmallVector<int64_t, 4>>
getStaticSliceOffsetsForElementOffset(Operation* anchor,
                                      ShapedType sourceType,
                                      ArrayRef<int64_t> fragmentShape,
                                      int64_t sourceElementOffset,
                                      StringRef fieldName) {
  // Emits "<fieldName><reason>" on the anchor and yields a failure result.
  auto emitFailure = [&](const char* reason) {
    anchor->emitOpError() << fieldName << reason;
    return failure();
  };

  if (!sourceType.hasStaticShape())
    return emitFailure(" requires a static source shape");
  if (sourceElementOffset < 0)
    return emitFailure(" requires a nonnegative source element offset");
  if (sourceType.getRank() != static_cast<int64_t>(fragmentShape.size()))
    return emitFailure(" requires fragment rank to match source rank");

  const int64_t sourceElementCount = sourceType.getNumElements();

  // Validate every extent first so a negative size is always reported in
  // preference to a bounds violation.
  bool fragmentIsEmpty = false;
  for (int64_t extent : fragmentShape) {
    if (extent < 0)
      return emitFailure(" requires nonnegative fragment sizes");
    fragmentIsEmpty |= (extent == 0);
  }

  // Multiply the extents without ever exceeding `sourceElementCount`. With all
  // extents >= 1 the running product never shrinks, so once it would pass the
  // source size the fragment cannot fit and the product is not needed.
  int64_t fragmentElementCount = fragmentIsEmpty ? 0 : 1;
  if (!fragmentIsEmpty) {
    for (int64_t extent : fragmentShape) {
      if (fragmentElementCount > sourceElementCount / extent)
        return emitFailure(" exceeds the source tensor bounds");
      fragmentElementCount *= extent;
    }
  }
  // Same test as `offset + count > total`, rearranged so it cannot overflow;
  // `fragmentElementCount <= sourceElementCount` holds at this point.
  if (sourceElementOffset > sourceElementCount - fragmentElementCount)
    return emitFailure(" exceeds the source tensor bounds");

  // A source with a zero extent has no elements to index; use all-zero offsets
  // instead of dividing by that extent while expanding the flat index.
  SmallVector<int64_t, 4> sliceOffsets = sourceElementCount == 0
                                             ? SmallVector<int64_t, 4>(fragmentShape.size(), 0)
                                             : expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
  for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
    // `sliceOffsets[dim]` never exceeds the dimension size, so the subtraction
    // is safe.
    if (fragmentShape[dim] > sourceType.getDimSize(dim) - sliceOffsets[dim])
      return emitFailure(" does not describe a valid unit-stride slice");
  }
  return sliceOffsets;
}
|
|
|
|
/// Invokes `callback` once for every chunk of a destination slice that spans
/// only fully covered trailing dimensions.
///
/// The slice is described by `baseOffsets` and `sizes` inside a destination of
/// shape `destShape`; all three are indexed by the same dimension numbers. The
/// trailing dimensions whose slice size equals the destination extent form the
/// "suffix"; one chunk covers the whole suffix. The remaining leading
/// dimensions form the "prefix" and are enumerated one index at a time, with
/// the last prefix dimension varying fastest. The innermost dimension that is
/// only partially covered belongs to the prefix, so it is visited index by
/// index rather than merged into a chunk.
///
/// `callback` receives:
///   - the per-dimension destination offsets of the chunk (`baseOffsets` with
///     the current prefix indices added);
///   - the running element offset, i.e. chunk ordinal times chunk size;
///   - the number of elements in the chunk.
///
/// A rank-0 slice produces a single one-element chunk. Iteration stops at the
/// first callback failure, which is propagated to the caller.
LogicalResult
forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
                                  ArrayRef<int64_t> baseOffsets,
                                  ArrayRef<int64_t> sizes,
                                  llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
  const int64_t rank = static_cast<int64_t>(sizes.size());

  // Grow the suffix from the innermost dimension outwards while the slice
  // covers the full destination extent. Starting at `rank` (an empty suffix)
  // keeps every index in bounds, including for rank 0.
  int64_t suffixStart = rank;
  while (suffixStart > 0 && sizes[suffixStart - 1] == destShape[suffixStart - 1])
    --suffixStart;

  int64_t chunkElements = 1;
  for (int64_t dim = suffixStart; dim < rank; ++dim)
    chunkElements *= sizes[dim];

  SmallVector<int64_t, 4> prefixExtents(sizes.begin(), sizes.begin() + suffixStart);
  // The whole slice is a single chunk when there is no prefix to enumerate.
  if (prefixExtents.empty())
    return callback(baseOffsets, 0, chunkElements);

  SmallVector<int64_t, 4> current(prefixExtents.size(), 0);
  int64_t sourceChunkOrdinal = 0;

  // Recursively enumerates prefix indices; `dim` is the prefix dimension being
  // assigned. The lambda receives itself because a lambda cannot name itself.
  auto visit = [&](auto&& visit, int64_t dim) -> LogicalResult {
    if (dim == static_cast<int64_t>(prefixExtents.size())) {
      SmallVector<int64_t, 4> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
      for (int64_t prefixDim = 0; prefixDim < static_cast<int64_t>(current.size()); ++prefixDim)
        chunkOffsets[prefixDim] += current[prefixDim];
      if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements)))
        return failure();
      ++sourceChunkOrdinal;
      return success();
    }

    for (int64_t index = 0; index < prefixExtents[dim]; ++index) {
      current[dim] = index;
      if (failed(visit(visit, dim + 1)))
        return failure();
    }
    return success();
  };

  return visit(visit, 0);
}
|
|
|
|
} // namespace onnx_mlir
|