Files
Raptor/src/PIM/Conversion/SpatialToPim/Common.cpp
T
NiccoloN 984f362623
Validate Operations / validate-operations (push) Waiting to run
roba
2026-06-26 13:02:38 +02:00

190 lines
8.1 KiB
C++

#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
/// Builds an attribute carrying the byte size of the shaped value `value`.
///
/// The size comes from pim::getCheckedShapedTypeSizeInBytes and is then wrapped
/// by pim::getCheckedI32Attr; both helpers receive `anchor` and the label
/// "tensor byte size". Returns failure() if the size computation fails,
/// otherwise whatever pim::getCheckedI32Attr yields.
///
/// `value` must have a ShapedType; the cast enforces this.
FailureOr<IntegerAttr> getTensorSizeInBytesAttr(Builder& builder, Operation* anchor, mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  auto sizeInBytes = pim::getCheckedShapedTypeSizeInBytes(shapedType, anchor, "tensor byte size");
  if (succeeded(sizeInBytes))
    return pim::getCheckedI32Attr(builder, anchor, *sizeInBytes, "tensor byte size");
  return failure();
}
/// Returns the user of `value` that no other user precedes according to
/// Operation::isBeforeInBlock.
///
/// `value` must have at least one user (asserted).
/// NOTE(review): every user is compared with isBeforeInBlock, so this
/// presumably expects all users to live in the same block — confirm with
/// callers.
Operation* getEarliestUserWithinBlock(mlir::Value value) {
  auto users = value.getUsers();
  assert(!users.empty());

  auto userIt = users.begin();
  Operation* first = *userIt;
  // Start from the first user itself so the comparison sequence matches a
  // plain scan over all users.
  for (const auto userEnd = users.end(); userIt != userEnd; ++userIt) {
    Operation* candidate = *userIt;
    first = candidate->isBeforeInBlock(first) ? candidate : first;
  }
  return first;
}
/// Returns the operands of `operation` ordered by ascending number of uses.
///
/// Each operand occurrence is kept (an operand that appears twice is returned
/// twice). Operands with the same use count keep their original operand order:
/// the sort is stable so that the result does not depend on the standard
/// library's sort implementation or on LLVM's EXPENSIVE_CHECKS shuffling,
/// which keeps downstream operand selection deterministic.
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
  using OperandUseCount = std::pair<mlir::Value, size_t>;

  // Counting uses walks the use list, so compute each count once up front.
  auto rankedOperands =
      map_to_vector(operation->getOperands(), [](mlir::Value operand) -> OperandUseCount {
        return {operand, static_cast<size_t>(std::distance(operand.use_begin(), operand.use_end()))};
      });

  llvm::stable_sort(rankedOperands, [](const OperandUseCount& lhs, const OperandUseCount& rhs) {
    return lhs.second < rhs.second;
  });

  return map_to_vector(rankedOperands, [](const OperandUseCount& entry) { return entry.first; });
}
/// Reports whether `value` has a user that is not known to come before
/// `operation`.
///
/// A user counts as "later" when it lives in a different block than
/// `operation`, or when `operation` precedes it in their shared block.
/// `operation` itself and users that precede it do not count.
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
  mlir::Block* homeBlock = operation->getBlock();
  return llvm::any_of(value.getUsers(), [&](Operation* user) {
    return user->getBlock() != homeBlock || operation->isBeforeInBlock(user);
  });
}
/// Picks a value that can serve as the output tensor for `operation`.
///
/// Operands are examined in the order produced by getOpOperandsSortedByUses.
/// The first operand whose type equals the result type and for which
/// hasLaterUserInBlock is false is returned. If no operand qualifies, the
/// rewriter's insertion point is moved to just before `operation` and a new
/// tensor::EmptyOp with the result's shape and element type is created there.
///
/// Preconditions (asserted): `operation` has exactly one result and that
/// result has a ShapedType.
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
  assert("Only support operations with a single result" && operation->getNumResults() == 1);
  auto resultType = operation->getResult(0).getType();
  assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));

  for (mlir::Value candidate : getOpOperandsSortedByUses(operation)) {
    if (candidate.getType() != resultType)
      continue;
    if (hasLaterUserInBlock(candidate, operation))
      continue;
    return candidate;
  }

  // No operand qualifies: allocate a fresh tensor right before the operation.
  auto shapedResultType = cast<ShapedType>(resultType);
  rewriter.setInsertionPoint(operation);
  return tensor::EmptyOp::create(
      rewriter, operation->getLoc(), shapedResultType.getShape(), shapedResultType.getElementType());
}
/// Checks that the fragment-assembly metadata arrays are mutually consistent.
///
/// Verified, in this order (the first violation is reported on
/// `reconciliator` and returned as failure):
///   1. `operandIndices` and `sourceOffsets` have the same length;
///   2. `flatOffsets` and `flatSizes` have the same length;
///   3. `flatStrides` and `flatOffsets` have the same length;
///   4. `flatOffsets` holds exactly `resultRank` entries per fragment;
///   5. every operand index lies in [0, operandCount) and every source offset
///      is nonnegative.
///
/// The values stored in `flatOffsets`, `flatSizes` and `flatStrides` are not
/// inspected here. `resultRank` is converted to size_t without a sign check,
/// so a negative rank wraps around in check 4.
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
                                               int64_t resultRank,
                                               size_t operandCount,
                                               ArrayRef<int64_t> operandIndices,
                                               ArrayRef<int64_t> sourceOffsets,
                                               ArrayRef<int64_t> flatOffsets,
                                               ArrayRef<int64_t> flatSizes,
                                               ArrayRef<int64_t> flatStrides) {
  const size_t fragmentCount = operandIndices.size();
  const size_t flatLength = flatOffsets.size();

  if (fragmentCount != sourceOffsets.size())
    return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
  if (flatLength != flatSizes.size())
    return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
  if (flatStrides.size() != flatLength)
    return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
  if (flatLength != fragmentCount * static_cast<size_t>(resultRank))
    return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");

  const int64_t operandLimit = static_cast<int64_t>(operandCount);
  for (size_t fragment = 0; fragment < fragmentCount; ++fragment) {
    const int64_t operandIndex = operandIndices[fragment];
    const bool indexInRange = operandIndex >= 0 && operandIndex < operandLimit;
    if (!indexInRange)
      return reconciliator.emitOpError("fragment assembly operand index is out of range");
    if (sourceOffsets[fragment] < 0)
      return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
  }
  return success();
}
/// Converts a row-major flat element index into one index per dimension of
/// `shape` (the last dimension varies fastest).
///
/// Dimensions are peeled off from the innermost outwards: each index is the
/// remainder of the running flat index by that dimension's extent, and the
/// quotient carries over to the next outer dimension.
///
/// A dimension with extent 0 has no valid index; its entry is left at 0 and
/// the running index is carried over unchanged, which avoids a division by
/// zero for empty shapes.
static SmallVector<int64_t, 4> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
  SmallVector<int64_t, 4> indices(shape.size(), 0);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
    const int64_t extent = shape[dim];
    if (extent == 0)
      continue; // `% 0` and `/ 0` are undefined; keep the index at 0.
    indices[dim] = flatIndex % extent;
    flatIndex /= extent;
  }
  return indices;
}
/// Translates a flat element offset into per-dimension slice offsets for a
/// fragment of shape `fragmentShape` taken from a tensor of type `sourceType`.
///
/// On any violation an op error prefixed with `fieldName` is emitted on
/// `anchor` and failure is returned. Checked, in order:
///   - `sourceType` has a static shape;
///   - `sourceElementOffset` is nonnegative;
///   - the fragment rank equals the source rank;
///   - every fragment extent is nonnegative;
///   - offset + fragment element count does not exceed the source element
///     count;
///   - in every dimension, slice offset + fragment extent stays within the
///     source extent.
///
/// All bound checks are written so that they cannot overflow int64_t, even for
/// very large fragment extents.
FailureOr<SmallVector<int64_t, 4>>
getStaticSliceOffsetsForElementOffset(Operation* anchor,
                                      ShapedType sourceType,
                                      ArrayRef<int64_t> fragmentShape,
                                      int64_t sourceElementOffset,
                                      StringRef fieldName) {
  // Emits "<fieldName><reason>" on the anchor and yields failure.
  auto reject = [&](StringRef reason) {
    anchor->emitOpError() << fieldName << reason;
    return failure();
  };

  if (!sourceType.hasStaticShape())
    return reject(" requires a static source shape");
  if (sourceElementOffset < 0)
    return reject(" requires a nonnegative source element offset");
  const int64_t rank = sourceType.getRank();
  if (rank != static_cast<int64_t>(fragmentShape.size()))
    return reject(" requires fragment rank to match source rank");

  const int64_t sourceElementCount = sourceType.getNumElements();

  // Compute the fragment element count without ever forming a product larger
  // than the source element count. Once the product would exceed it we only
  // remember that fact; a later zero extent collapses the product back to 0.
  int64_t fragmentElementCount = 1;
  bool fragmentExceedsSource = false;
  for (int64_t extent : fragmentShape) {
    if (extent < 0)
      return reject(" requires nonnegative fragment sizes");
    if (extent == 0) {
      fragmentElementCount = 0;
      fragmentExceedsSource = false;
      continue;
    }
    if (fragmentExceedsSource)
      continue;
    if (fragmentElementCount > sourceElementCount / extent)
      fragmentExceedsSource = true;
    else
      fragmentElementCount *= extent;
  }

  // Equivalent to `offset + count > total`, rearranged to avoid overflow; both
  // the offset and the total are nonnegative here.
  if (fragmentExceedsSource || fragmentElementCount > sourceElementCount - sourceElementOffset)
    return reject(" exceeds the source tensor bounds");

  SmallVector<int64_t, 4> sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
  for (int64_t dim = 0; dim < rank; ++dim) {
    // Equivalent to `sliceOffset + extent > dimSize`, rearranged to avoid
    // overflow for huge extents.
    if (fragmentShape[dim] > sourceType.getDimSize(dim) - sliceOffsets[dim])
      return reject(" does not describe a valid unit-stride slice");
  }
  return sliceOffsets;
}
/// Invokes `callback` once per chunk of a region of extent `sizes`, placed at
/// `baseOffsets` inside a destination of shape `destShape`.
///
/// The trailing dimensions in which the region spans the whole destination
/// (`sizes[d] == destShape[d]`) form the chunk; the remaining leading
/// dimensions are iterated in row-major order. The callback receives:
///   - the destination offsets of the chunk (`baseOffsets` with the current
///     leading-dimension indices added),
///   - the chunk's flat element offset within the region (chunk ordinal times
///     chunk size),
///   - the number of elements in the chunk.
/// Iteration stops at, and returns, the first callback failure.
///
/// A rank-0 region is visited as a single one-element chunk.
///
/// NOTE(review): the first dimension that is only partially covered is always
/// iterated index by index rather than folded into the chunk; confirm this
/// granularity is intended.
LogicalResult
forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
                                  ArrayRef<int64_t> baseOffsets,
                                  ArrayRef<int64_t> sizes,
                                  llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
  assert("destShape must cover every dimension of sizes" && destShape.size() >= sizes.size());
  const int64_t rank = static_cast<int64_t>(sizes.size());

  // Find the first dimension of the maximal trailing run of fully covered
  // dimensions. Starting at `rank` (an empty run) keeps rank 0 in bounds.
  int64_t suffixStart = rank;
  while (suffixStart > 0 && sizes[suffixStart - 1] == destShape[suffixStart - 1])
    --suffixStart;

  int64_t chunkElements = 1;
  for (int64_t extent : sizes.drop_front(static_cast<size_t>(suffixStart)))
    chunkElements *= extent;

  const ArrayRef<int64_t> prefixExtents = sizes.take_front(static_cast<size_t>(suffixStart));
  const int64_t prefixRank = static_cast<int64_t>(prefixExtents.size());
  assert("baseOffsets must cover every iterated dimension" && baseOffsets.size() >= prefixExtents.size());

  // Every dimension is fully covered (or rank is 0): the region is one chunk.
  if (prefixExtents.empty())
    return callback(baseOffsets, 0, chunkElements);

  SmallVector<int64_t, 4> prefixIndex(prefixExtents.size(), 0);
  int64_t chunkOrdinal = 0;

  // Recursively enumerates all index tuples over the prefix dimensions, with
  // dimension 0 outermost.
  auto visit = [&](auto&& self, int64_t dim) -> LogicalResult {
    if (dim == prefixRank) {
      SmallVector<int64_t, 4> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
      for (int64_t prefixDim = 0; prefixDim < prefixRank; ++prefixDim)
        chunkOffsets[prefixDim] += prefixIndex[prefixDim];
      if (failed(callback(chunkOffsets, chunkOrdinal * chunkElements, chunkElements)))
        return failure();
      ++chunkOrdinal;
      return success();
    }
    for (int64_t index = 0; index < prefixExtents[dim]; ++index) {
      prefixIndex[dim] = index;
      if (failed(self(self, dim + 1)))
        return failure();
    }
    return success();
  };
  return visit(visit, 0);
}
} // namespace onnx_mlir