roba
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-06-26 13:02:38 +02:00
parent 568fd90542
commit 984f362623
13 changed files with 797 additions and 347 deletions
+12 -11
View File
@@ -47,11 +47,13 @@ struct LowerFragmentAssemblyReconciliatorPattern
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = op.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
if (!operandIndicesAttr || !fragmentStridesAttr)
if (!operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr)
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
@@ -59,23 +61,21 @@ struct LowerFragmentAssemblyReconciliatorPattern
SmallVector<Value> fragmentOperands {adaptor.getInput()};
llvm::append_range(fragmentOperands, adaptor.getFragments());
if (failed(validateFragmentAssemblyMetadata(
op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides)))
return failure();
Value currentOutput =
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return op.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return op.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
Value source = fragmentOperands[operandIndex];
@@ -83,20 +83,21 @@ struct LowerFragmentAssemblyReconciliatorPattern
if (!sourceType || !sourceType.hasStaticShape())
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
SmallVector<int64_t, 4> extractOffsets(rank, 0);
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
op, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
if (failed(extractOffsets))
return failure();
fragment = tensor::ExtractSliceOp::create(rewriter,
op.getLoc(),
source,
getStaticIndexAttrs(rewriter, extractOffsets),
getStaticIndexAttrs(rewriter, *extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank));
}