#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"

#include <cassert>
#include <iterator>
#include <utility>

#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"

using namespace llvm;
using namespace mlir;

namespace onnx_mlir {

/// Returns the byte size of `value`'s shaped type as a checked i32 attribute.
/// Fails when either pim::getChecked* helper fails (both receive `anchor`,
/// presumably to attach a diagnostic -- TODO confirm in CheckedArithmetic.hpp).
FailureOr<IntegerAttr> getTensorSizeInBytesAttr(Builder& builder, Operation* anchor, mlir::Value value) {
  auto byteSize =
      pim::getCheckedShapedTypeSizeInBytes(cast<ShapedType>(value.getType()), anchor, "tensor byte size");
  if (failed(byteSize))
    return failure();
  return pim::getCheckedI32Attr(builder, anchor, *byteSize, "tensor byte size");
}

/// Returns the user of `value` that appears first in its block.
///
/// Preconditions: `value` has at least one user, and all users share one
/// parent block (Operation::isBeforeInBlock only orders ops within a block).
Operation* getEarliestUserWithinBlock(mlir::Value value) {
  auto users = value.getUsers();
  assert(!users.empty() && "value must have at least one user");
  Operation* earliestUser = *users.begin();
  for (Operation* user : users)
    if (user->isBeforeInBlock(earliestUser))
      earliestUser = user;
  return earliestUser;
}

/// Returns the operands of `operation` in ascending order of use count.
/// Operands with the same use count keep their original operand order, so the
/// result -- and any choice made from it -- is deterministic.
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
  auto operandsAndUses =
      map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
        return {operand, static_cast<size_t>(std::distance(operand.use_begin(), operand.use_end()))};
      });
  // llvm::sort is unstable (and shuffles its input in EXPENSIVE_CHECKS builds),
  // which would make the tie order -- and thus the chosen output -- vary.
  llvm::stable_sort(operandsAndUses, [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });
  return map_to_vector(operandsAndUses, [](const auto& operandAndUses) { return operandAndUses.first; });
}

/// Returns true if `value` may still be used after `operation`.
///
/// A user in a different block than `operation` is conservatively treated as
/// "later"; otherwise a user counts if it follows `operation` in their shared
/// block. `operation` itself is never a later user.
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
  auto* block = operation->getBlock();
  return llvm::any_of(value.getUsers(), [&](Operation* user) {
    return user->getBlock() != block || operation->isBeforeInBlock(user);
  });
}

/// Chooses a tensor that can hold the single result of `operation`.
///
/// It prefers an operand that meets both conditions:
///   - its type is exactly the result type;
///   - it has no user after `operation`, so it can be overwritten.
/// Candidates are tried from fewest to most uses. If no operand qualifies, a
/// `tensor.empty` with the result's shape and element type is created right
/// before `operation`. This changes the rewriter's insertion point.
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
  assert(operation->getNumResults() == 1 && "Only support operations with a single result");
  mlir::Value result = operation->getResult(0);
  auto resultType = result.getType();
  assert(isa<ShapedType>(resultType) && "Only support result ShapedType as result type");

  SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
  auto reusableOperand = llvm::find_if(operands, [&](mlir::Value operand) {
    return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
  });
  if (reusableOperand != operands.end())
    return *reusableOperand;

  auto resultShapedType = cast<ShapedType>(resultType);
  rewriter.setInsertionPoint(operation);
  return tensor::EmptyOp::create(
      rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}

/// Checks that the fragment-assembly metadata of `reconciliator` is consistent:
///   - there is one (operand index, source offset) pair per fragment;
///   - the flat offset, size and stride arrays hold `resultRank` entries per
///     fragment;
///   - every operand index addresses one of `operandCount` operands;
///   - every source offset is nonnegative.
/// Emits an op error on `reconciliator` and fails at the first violation.
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
                                               int64_t resultRank,
                                               size_t operandCount,
                                               ArrayRef<int64_t> operandIndices,
                                               ArrayRef<int64_t> sourceOffsets,
                                               ArrayRef<int64_t> flatOffsets,
                                               ArrayRef<int64_t> flatSizes,
                                               ArrayRef<int64_t> flatStrides) {
  if (operandIndices.size() != sourceOffsets.size())
    return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
  if (flatOffsets.size() != flatSizes.size())
    return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
  if (flatStrides.size() != flatOffsets.size())
    return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
  if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
    return reconciliator.emitOpError(
        "fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");

  for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
    if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
      return reconciliator.emitOpError("fragment assembly operand index is out of range");
    if (sourceOffsets[fragmentIndex] < 0)
      return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
  }
  return success();
}

/// Converts the row-major flat element index `flatIndex` into one index per
/// dimension of `shape`. Zero-extent dimensions are left at index 0 instead of
/// being divided by, which would be undefined behaviour. An empty tensor has
/// no element whose position needs to be recovered.
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> indices(shape.size(), 0);
  for (size_t dim = shape.size(); dim-- > 0;) {
    const int64_t extent = shape[dim];
    if (extent == 0)
      continue;
    indices[dim] = flatIndex % extent;
    flatIndex /= extent;
  }
  return indices;
}

/// Converts a fragment of shape `fragmentShape` that starts at row-major
/// element `sourceElementOffset` of `sourceType` into the offsets of the
/// corresponding unit-stride slice of the source.
///
/// Emits an op error on `anchor`, prefixed with `fieldName`, and fails when:
///   - the source shape is dynamic;
///   - the offset is negative;
///   - the ranks differ;
///   - a fragment size is negative;
///   - the fragment runs past the last source element;
///   - the slice leaves the source along some dimension.
/// All arithmetic is overflow-free.
FailureOr<SmallVector<int64_t>> getStaticSliceOffsetsForElementOffset(Operation* anchor,
                                                                      ShapedType sourceType,
                                                                      ArrayRef<int64_t> fragmentShape,
                                                                      int64_t sourceElementOffset,
                                                                      StringRef fieldName) {
  auto fail = [&](StringRef reason) -> FailureOr<SmallVector<int64_t>> {
    anchor->emitOpError() << fieldName << reason;
    return failure();
  };

  if (!sourceType.hasStaticShape())
    return fail(" requires a static source shape");
  if (sourceElementOffset < 0)
    return fail(" requires a nonnegative source element offset");
  const int64_t rank = sourceType.getRank();
  if (rank != static_cast<int64_t>(fragmentShape.size()))
    return fail(" requires fragment rank to match source rank");
  if (llvm::any_of(fragmentShape, [](int64_t extent) { return extent < 0; }))
    return fail(" requires nonnegative fragment sizes");

  // Count fragment elements without signed overflow. A zero extent makes the
  // fragment empty. Otherwise every extent is >= 1, so once the running
  // product exceeds the source element count the fragment cannot fit.
  const int64_t sourceElementCount = sourceType.getNumElements();
  int64_t fragmentElementCount = 0;
  if (!llvm::is_contained(fragmentShape, 0)) {
    fragmentElementCount = 1;
    for (int64_t extent : fragmentShape) {
      if (fragmentElementCount > sourceElementCount / extent)
        return fail(" exceeds the source tensor bounds");
      fragmentElementCount *= extent;
    }
  }
  // Here fragmentElementCount <= sourceElementCount, so this is
  // `offset + count > sourceElementCount` without the overflowing addition.
  if (sourceElementOffset > sourceElementCount - fragmentElementCount)
    return fail(" exceeds the source tensor bounds");

  SmallVector<int64_t> sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
  for (int64_t dim = 0; dim < rank; ++dim) {
    // sliceOffsets[dim] never exceeds the dimension size, so the subtraction
    // cannot overflow.
    if (fragmentShape[dim] > sourceType.getDimSize(dim) - sliceOffsets[dim])
      return fail(" does not describe a valid unit-stride slice");
  }
  return sliceOffsets;
}

/// Splits the destination slice at `baseOffsets` with extents `sizes`, inside
/// a tensor of shape `destShape`, into row-major runs.
///
/// `callback` is called once per run with three arguments:
///   - the destination offsets of the run's first element;
///   - the element offset of the run inside the densely packed source
///     fragment;
///   - the run length in elements.
///
/// A run spans the longest suffix of dimensions the slice fully covers. The
/// leading dimensions are walked in row-major order. Returns the first
/// callback failure, if any.
LogicalResult forEachContiguousDestinationChunk(
    ArrayRef<int64_t> destShape,
    ArrayRef<int64_t> baseOffsets,
    ArrayRef<int64_t> sizes,
    llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
  assert(destShape.size() == sizes.size() && baseOffsets.size() == sizes.size() &&
         "destination shape, offsets, and sizes must have the same rank");
  const int64_t rank = static_cast<int64_t>(sizes.size());

  // A rank-0 slice is a single element. This must be handled before any
  // indexing, because there is no last dimension to inspect.
  if (rank == 0)
    return callback(baseOffsets, 0, 1);

  // Every dimension >= suffixStart is fully covered by the slice.
  int64_t suffixStart = rank;
  while (suffixStart > 0 && sizes[suffixStart - 1] == destShape[suffixStart - 1])
    --suffixStart;

  int64_t chunkElements = 1;
  for (int64_t dim = suffixStart; dim < rank; ++dim)
    chunkElements *= sizes[dim];

  // Leading dimensions that must be walked one chunk at a time.
  ArrayRef<int64_t> prefixExtents = sizes.take_front(suffixStart);
  if (prefixExtents.empty())
    return callback(baseOffsets, 0, chunkElements);
  if (llvm::is_contained(prefixExtents, 0))
    return success();

  // Row-major counter over the prefix dimensions, stored directly as
  // destination offsets; the trailing offsets stay at their base values.
  SmallVector<int64_t> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
  for (int64_t sourceChunkOrdinal = 0;; ++sourceChunkOrdinal) {
    if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements)))
      return failure();

    int64_t dim = suffixStart - 1;
    while (dim >= 0 && ++chunkOffsets[dim] == baseOffsets[dim] + prefixExtents[dim]) {
      chunkOffsets[dim] = baseOffsets[dim];
      --dim;
    }
    if (dim < 0)
      return success();
  }
}

} // namespace onnx_mlir