roba
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-06-26 13:02:38 +02:00
parent 568fd90542
commit 984f362623
13 changed files with 797 additions and 347 deletions
+114
View File
@@ -5,6 +5,7 @@
#include <cassert>
#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace llvm;
@@ -72,4 +73,117 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
int64_t resultRank,
size_t operandCount,
ArrayRef<int64_t> operandIndices,
ArrayRef<int64_t> sourceOffsets,
ArrayRef<int64_t> flatOffsets,
ArrayRef<int64_t> flatSizes,
ArrayRef<int64_t> flatStrides) {
if (operandIndices.size() != sourceOffsets.size())
return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
if (flatOffsets.size() != flatSizes.size())
return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
if (flatStrides.size() != flatOffsets.size())
return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
return reconciliator.emitOpError("fragment assembly operand index is out of range");
if (sourceOffsets[fragmentIndex] < 0)
return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
}
return success();
}
static SmallVector<int64_t, 4> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
SmallVector<int64_t, 4> indices(shape.size(), 0);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
indices[dim] = flatIndex % shape[dim];
flatIndex /= shape[dim];
}
return indices;
}
FailureOr<SmallVector<int64_t, 4>>
getStaticSliceOffsetsForElementOffset(Operation* anchor,
ShapedType sourceType,
ArrayRef<int64_t> fragmentShape,
int64_t sourceElementOffset,
StringRef fieldName) {
if (!sourceType.hasStaticShape())
return (anchor->emitOpError() << fieldName << " requires a static source shape"), failure();
if (sourceElementOffset < 0)
return (anchor->emitOpError() << fieldName << " requires a nonnegative source element offset"), failure();
if (sourceType.getRank() != static_cast<int64_t>(fragmentShape.size()))
return (anchor->emitOpError() << fieldName << " requires fragment rank to match source rank"), failure();
int64_t sourceElementCount = sourceType.getNumElements();
int64_t fragmentElementCount = 1;
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
if (fragmentShape[dim] < 0)
return (anchor->emitOpError() << fieldName << " requires nonnegative fragment sizes"), failure();
fragmentElementCount *= fragmentShape[dim];
}
if (sourceElementOffset + fragmentElementCount > sourceElementCount)
return (anchor->emitOpError() << fieldName << " exceeds the source tensor bounds"), failure();
SmallVector<int64_t, 4> sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
if (sliceOffsets[dim] + fragmentShape[dim] > sourceType.getDimSize(dim))
return (anchor->emitOpError() << fieldName << " does not describe a valid unit-stride slice"), failure();
}
return sliceOffsets;
}
LogicalResult
forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
ArrayRef<int64_t> baseOffsets,
ArrayRef<int64_t> sizes,
llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
int64_t rank = static_cast<int64_t>(sizes.size());
int64_t suffixStart = rank - 1;
while (suffixStart > 0 && sizes[suffixStart] == destShape[suffixStart])
--suffixStart;
if (sizes[suffixStart] == destShape[suffixStart] && suffixStart == 0)
suffixStart = 0;
else
++suffixStart;
int64_t chunkElements = 1;
for (int64_t dim = suffixStart; dim < rank; ++dim)
chunkElements *= sizes[dim];
SmallVector<int64_t, 4> prefixExtents(sizes.begin(), sizes.begin() + suffixStart);
SmallVector<int64_t, 4> current(prefixExtents.size(), 0);
int64_t sourceChunkOrdinal = 0;
auto visit = [&](auto&& visit, int64_t dim) -> LogicalResult {
if (dim == static_cast<int64_t>(prefixExtents.size())) {
SmallVector<int64_t, 4> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
for (int64_t prefixDim = 0; prefixDim < static_cast<int64_t>(current.size()); ++prefixDim)
chunkOffsets[prefixDim] += current[prefixDim];
if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements)))
return failure();
++sourceChunkOrdinal;
return success();
}
for (int64_t index = 0; index < prefixExtents[dim]; ++index) {
current[dim] = index;
if (failed(visit(visit, dim + 1)))
return failure();
}
return success();
};
if (prefixExtents.empty())
return callback(baseOffsets, 0, chunkElements);
return visit(visit, 0);
}
} // namespace onnx_mlir