roba
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-06-26 13:02:38 +02:00
parent 568fd90542
commit 984f362623
13 changed files with 797 additions and 347 deletions
+1
View File
@@ -245,6 +245,7 @@ def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
StrAttr:$indexMap,
OptionalAttr<StrAttr>:$mode,
OptionalAttr<DenseI64ArrayAttr>:$fragmentOperandIndices,
OptionalAttr<DenseI64ArrayAttr>:$fragmentSourceOffsets,
OptionalAttr<DenseI64ArrayAttr>:$fragmentStrides,
OptionalAttr<StrAttr>:$conflictPolicy,
OptionalAttr<StrAttr>:$coveragePolicy
+31 -24
View File
@@ -483,21 +483,28 @@ LogicalResult SpatReconciliatorOp::verify() {
return failure();
if (!getFragments().empty())
return emitError("legacy reconciliator does not accept extra fragment operands");
if (getFragmentStridesAttr() || getConflictPolicyAttr() || getCoveragePolicyAttr())
if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr()
|| getCoveragePolicyAttr())
return emitError("legacy reconciliator does not accept fragment assembly attributes");
return success();
}
auto stridesAttr = getFragmentStridesAttr();
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr();
if (!operandIndicesAttr)
return emitError("fragment assembly reconciliator requires fragment operand indices");
if (!sourceOffsetsAttr)
return emitError("fragment assembly reconciliator requires fragment source offsets");
if (!stridesAttr)
return emitError("fragment assembly reconciliator requires fragment strides");
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
ArrayRef<int64_t> sourceOffsets = sourceOffsetsAttr.asArrayRef();
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
if (strides.size() != offsets.size())
return emitError("fragment stride and offset arrays must have the same length");
if (sourceOffsets.size() != operandIndices.size())
return emitError("fragment source offset count must match fragment operand index count");
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
return emitError("fragment assembly reconciliator requires conflict and coverage policies");
if (getConflictPolicy() != "disjoint")
@@ -519,11 +526,21 @@ LogicalResult SpatReconciliatorOp::verify() {
SmallVector<std::pair<SmallVector<int64_t, 4>, SmallVector<int64_t, 4>>, 8> slices;
slices.reserve(static_cast<size_t>(fragmentCount));
SmallVector<SmallVector<SmallVector<int64_t, 4>, 4>, 8> sizesByOperand(static_cast<size_t>(operandCount));
SmallVector<int64_t, 8> fragmentCountsByOperand(static_cast<size_t>(operandCount), 0);
auto expandFlatElementIndex = [](int64_t flatIndex, ArrayRef<int64_t> shape) {
SmallVector<int64_t, 4> indices(shape.size(), 0);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
indices[dim] = flatIndex % shape[dim];
flatIndex /= shape[dim];
}
return indices;
};
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= operandCount)
return emitError("fragment assembly operand index is out of range");
if (sourceOffsets[fragmentIndex] < 0)
return emitError("fragment assembly source offsets must be nonnegative");
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
if (!operandType || !operandType.hasStaticShape())
@@ -541,7 +558,17 @@ LogicalResult SpatReconciliatorOp::verify() {
fragmentSizes.push_back(sizes[flatIndex]);
}
sizesByOperand[static_cast<size_t>(operandIndex)].push_back(fragmentSizes);
++fragmentCountsByOperand[static_cast<size_t>(operandIndex)];
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim)
fragmentElements *= fragmentSizes[dim];
if (sourceOffsets[fragmentIndex] + fragmentElements > operandType.getNumElements())
return emitError("fragment assembly source offset exceeds the operand bounds");
SmallVector<int64_t, 4> sourceSliceOffsets =
expandFlatElementIndex(sourceOffsets[fragmentIndex], operandType.getShape());
for (int64_t dim = 0; dim < rank; ++dim)
if (sourceSliceOffsets[dim] + fragmentSizes[dim] > operandType.getDimSize(dim))
return emitError("fragment assembly source offset must describe a valid unit-stride slice");
for (const auto& [existingOffsets, existingSizes] : slices) {
bool overlaps = true;
@@ -562,28 +589,8 @@ LogicalResult SpatReconciliatorOp::verify() {
}
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
if (sizesByOperand[static_cast<size_t>(operandIndex)].empty())
if (fragmentCountsByOperand[static_cast<size_t>(operandIndex)] == 0)
return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment");
auto operandType = cast<RankedTensorType>(operands[operandIndex].getType());
ArrayRef<int64_t> operandShape = operandType.getShape();
auto& fragmentShapes = sizesByOperand[static_cast<size_t>(operandIndex)];
if (fragmentShapes.size() == 1) {
if (!llvm::equal(operandShape, fragmentShapes.front()))
return emitError("single-fragment reconciliator operand shape must match declared fragment size");
continue;
}
ArrayRef<int64_t> fragmentShape = fragmentShapes.front();
for (ArrayRef<int64_t> otherShape : fragmentShapes)
if (!llvm::equal(fragmentShape, otherShape))
return emitError("packed reconciliator operand requires equal fragment sizes per operand");
if (llvm::equal(operandShape, fragmentShape))
continue;
if (!llvm::equal(operandShape.drop_front(), fragmentShape.drop_front()))
return emitError("packed reconciliator operand must match fragment shape on non-packed dimensions");
if (operandShape.front() != static_cast<int64_t>(fragmentShapes.size()) * fragmentShape.front())
return emitError("packed reconciliator operand first dimension must equal fragment_count * fragment_size");
}
if (getCoveragePolicy() == "complete") {
@@ -194,6 +194,7 @@ struct MaterializedClass {
SmallVector<Value, 8> weights;
SmallVector<Value, 8> inputs;
SmallVector<Value, 4> hostOutputs;
DenseMap<Value, unsigned> publicationOutputToResultIndex;
DenseMap<Value, BlockArgument> weightArgs;
DenseMap<Value, BlockArgument> inputArgs;
DenseMap<Value, unsigned> hostOutputToResultIndex;
@@ -307,11 +308,9 @@ struct PendingProjectedHostOutputFragment {
Value originalOutput;
ClassId sourceClass = 0;
ProducerKey producerKey;
Value operand;
RankedTensorType operandType;
RankedTensorType fragmentType;
int64_t packedFragmentIndex = -1;
int64_t currentLane = -1;
Value publicationValue;
int64_t sourceFragmentOrdinal = 0;
int64_t sourceElementOffset = 0;
SmallVector<int64_t, 4> offsets;
SmallVector<int64_t, 4> sizes;
SmallVector<int64_t, 4> strides;
@@ -379,6 +378,9 @@ LogicalResult localizeCapturesInClonedOp(MaterializerState& state,
Operation& clonedOp,
IRMapping* mapper = nullptr);
LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass);
void createDim0ParallelInsertSlice(
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset);
Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc);
bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch,
const AffineProjectedInputSliceMatch& match,
ProducerKey producer,
@@ -627,6 +629,31 @@ ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, Co
return logicalInstance;
}
FailureOr<unsigned>
getPublicationLaneForProducerKey(MaterializerState& state, const MaterializedClass& sourceClass, ProducerKey key) {
if (!sourceClass.isBatch)
return 0;
ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, key.instance);
auto cpuIt = state.schedule.computeToCpuMap.find(scheduledProducer);
if (cpuIt == state.schedule.computeToCpuMap.end()) {
sourceClass.op->emitError("projected packed host publication could not resolve the producer CPU for a publication lane")
<< " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount
<< " resultIndex=" << key.resultIndex;
return failure();
}
auto laneIt = sourceClass.cpuToLane.find(cpuIt->second);
if (laneIt == sourceClass.cpuToLane.end()) {
sourceClass.op->emitError("projected packed host publication could not map a producer key to a publication lane")
<< " cpu=" << cpuIt->second << " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount
<< " resultIndex=" << key.resultIndex;
return failure();
}
return laneIt->second;
}
SmallVector<ProducerKey, 4>
collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
// Destination collection works in the materializer's logical one-lane key domain.
@@ -1043,6 +1070,7 @@ LogicalResult collectHostOutputs(MaterializerState& state) {
for (MaterializedClass& materializedClass : state.classes) {
materializedClass.hostOutputs.clear();
materializedClass.hostOutputToResultIndex.clear();
materializedClass.publicationOutputToResultIndex.clear();
}
state.hostOutputOwners.clear();
@@ -1150,48 +1178,6 @@ void setInsertionPointForNewMaterializedOp(MaterializerState& state) {
state.rewriter.setInsertionPointToEnd(&funcBlock);
}
FailureOr<ClassId> createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) {
DenseSet<CpuId> usedCpus;
for (const auto& [cpu, _] : state.cpuToClass)
usedCpus.insert(cpu);
CpuId assemblyCpu = 0;
while (usedCpus.contains(assemblyCpu))
++assemblyCpu;
setInsertionPointForNewMaterializedOp(state);
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
if (!resultType || !resultType.hasStaticShape())
return state.func.emitError("projected host assembly class requires a static ranked tensor output");
auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {});
compute.getProperties().setOperandSegmentSizes({0, 0});
auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id");
if (failed(coreIdAttr))
return failure();
compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr);
Block* body = state.rewriter.createBlock(&compute.getBody());
state.rewriter.setInsertionPointToEnd(body);
Value placeholder =
tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder});
state.rewriter.setInsertionPointAfter(compute.getOperation());
MaterializedClass materializedClass;
materializedClass.id = state.classes.size();
materializedClass.cpus.push_back(assemblyCpu);
materializedClass.op = compute.getOperation();
materializedClass.body = body;
materializedClass.hostOutputToResultIndex[originalOutput] = 0;
materializedClass.hostOutputs.push_back(originalOutput);
state.cpuToClass[assemblyCpu] = materializedClass.id;
state.hostOutputOwners[originalOutput] = materializedClass.id;
state.classes.push_back(std::move(materializedClass));
return state.classes.back().id;
}
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
auto it = materializedClass.weightArgs.find(weight);
if (it != materializedClass.weightArgs.end())
@@ -1226,13 +1212,125 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
materializedClass.inputArgs[input] = std::get<1>(*arg);
return std::get<1>(*arg);
}
if (auto compute = dyn_cast<SpatScheduledComputeBatch>(materializedClass.op)) {
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
assert(arg && "expected compute_batch body while inserting an input argument");
materializedClass.inputArgs[input] = std::get<1>(*arg);
return std::get<1>(*arg);
auto compute = cast<SpatScheduledComputeBatch>(materializedClass.op);
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
assert(arg && "expected compute_batch body while inserting an input argument");
materializedClass.inputArgs[input] = std::get<1>(*arg);
return std::get<1>(*arg);
}
void refreshPendingProjectedHostOutputPublicationValues(MaterializerState& state,
Operation* oldOwner,
Operation* newOwner) {
if (!oldOwner || oldOwner == newOwner)
return;
for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) {
auto publicationResult = dyn_cast_or_null<OpResult>(fragment.publicationValue);
if (!publicationResult || publicationResult.getOwner() != oldOwner)
publicationResult = OpResult();
else
fragment.publicationValue = newOwner->getResult(publicationResult.getResultNumber());
if (auto originalResult = dyn_cast_or_null<OpResult>(fragment.originalOutput); originalResult
&& originalResult.getOwner() == oldOwner) {
fragment.originalOutput = newOwner->getResult(originalResult.getResultNumber());
}
if (fragment.producerKey.instance.op == oldOwner)
fragment.producerKey.instance.op = newOwner;
}
llvm_unreachable("Cannot reach here");
}
FailureOr<Value> appendScalarPublicationResult(MaterializerState& state,
MaterializedClass& materializedClass,
Value payload,
Location loc) {
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
if (existing != materializedClass.publicationOutputToResultIndex.end())
return materializedClass.op->getResult(existing->second);
auto compute = dyn_cast<SpatScheduledCompute>(materializedClass.op);
if (!compute)
return materializedClass.op->emitError("scalar publication result requires spat.scheduled_compute owner");
auto payloadType = dyn_cast<RankedTensorType>(payload.getType());
if (!payloadType || !payloadType.hasStaticShape())
return materializedClass.op->emitError("scalar publication result requires static ranked tensor payload");
FailureOr<std::tuple<OpResult, SpatScheduledCompute>> inserted =
compute.insertOutput(state.rewriter, compute.getNumResults(), payloadType, loc);
if (failed(inserted))
return materializedClass.op->emitError("failed to append scalar publication result");
Operation* oldOp = materializedClass.op;
auto [result, newCompute] = *inserted;
materializedClass.op = newCompute.getOperation();
materializedClass.body = &newCompute.getBody().front();
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
auto yieldOp = dyn_cast<SpatYieldOp>(materializedClass.body->getTerminator());
if (!yieldOp)
return materializedClass.op->emitError("expected spat.yield terminator while appending scalar publication result");
state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->insertOperands(yieldOp.getNumOperands(), payload); });
return result;
}
FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
MaterializedClass& materializedClass,
Value payload,
Location loc) {
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
if (existing != materializedClass.publicationOutputToResultIndex.end())
return materializedClass.op->getResult(existing->second);
auto batch = dyn_cast<SpatScheduledComputeBatch>(materializedClass.op);
if (!batch)
return materializedClass.op->emitError("batch publication result requires spat.scheduled_compute_batch owner");
auto payloadType = dyn_cast<RankedTensorType>(payload.getType());
if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0)
return materializedClass.op->emitError(
"batch publication result requires a static ranked tensor payload with rank > 0");
SmallVector<int64_t, 4> publishedShape(payloadType.getShape());
publishedShape[0] *= static_cast<int64_t>(materializedClass.cpus.size());
auto publishedType =
RankedTensorType::get(publishedShape, payloadType.getElementType(), payloadType.getEncoding());
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>> inserted =
batch.insertOutput(state.rewriter, batch.getNumResults(), publishedType, loc);
if (failed(inserted))
return materializedClass.op->emitError("failed to append batch publication result");
Operation* oldOp = materializedClass.op;
auto [result, outputArg, newBatch] = *inserted;
materializedClass.op = newBatch.getOperation();
materializedClass.body = &newBatch.getBody().front();
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
auto inParallelOp = dyn_cast<SpatInParallelOp>(materializedClass.body->getTerminator());
auto laneArg = newBatch.getLaneArgument();
if (!laneArg)
return materializedClass.op->emitError("batch publication result requires a lane argument");
if (!inParallelOp) {
auto yieldOp = dyn_cast<SpatYieldOp>(materializedClass.body->getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 0)
return materializedClass.op->emitError(
"batch publication result requires either spat.in_parallel or an empty spat.yield terminator");
state.rewriter.setInsertionPoint(yieldOp);
inParallelOp = SpatInParallelOp::create(state.rewriter, loc);
state.rewriter.eraseOp(yieldOp);
}
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
Value firstOffset =
scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc);
createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset);
return result;
}
// -----------------------------------------------------------------------------
@@ -5520,6 +5618,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
return failure();
}
FailureOr<Value> publicationResult = appendScalarPublicationResult(state, sourceClass, packed, loc);
if (failed(publicationResult))
return failure();
int64_t fragmentElementCount = fragmentType.getNumElements();
for (auto [runIndex, slot] : llvm::enumerate(run)) {
if (slot.peers.size() != 1) {
sourceClass.op->emitError("projected scalar host output publication expects scalar one-peer run slots");
@@ -5553,11 +5657,9 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
originalOutput,
sourceClass.id,
ProducerKey {peer, resultIndex},
packed,
cast<RankedTensorType>(packed.getType()),
fragmentType,
static_cast<int64_t>(runIndex),
*publicationResult,
static_cast<int64_t>(runIndex),
static_cast<int64_t>(runIndex) * fragmentElementCount,
SmallVector<int64_t, 4>(*offsets),
SmallVector<int64_t, 4>(*sizes),
SmallVector<int64_t, 4>(*strides),
@@ -5609,7 +5711,10 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
if (fragmentType == originalOutput.getType())
return false;
bool operandIsDim0Packed = false;
FailureOr<Value> publicationResult = appendBatchPublicationResult(state, sourceClass, packed, loc);
if (failed(publicationResult))
return failure();
if (packedType != fragmentType) {
if (packedType.getRank() == 0 || packedType.getDimSize(0) % static_cast<int64_t>(keys.size()) != 0)
return sourceClass.op->emitError(
@@ -5622,9 +5727,16 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
return sourceClass.op->emitError(
"projected packed host publication fragment shape does not match projected slice size")
<< " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size();
operandIsDim0Packed = true;
}
int64_t payloadElementCount = packedType.getNumElements();
int64_t fragmentElementCount = fragmentType.getNumElements();
int64_t fragmentsPerPublishedPayload = payloadElementCount / fragmentElementCount;
if (fragmentsPerPublishedPayload <= 0 || static_cast<int64_t>(keys.size()) % fragmentsPerPublishedPayload != 0)
return sourceClass.op->emitOpError(
"projected packed host publication requires a deterministic publication packing layout")
<< " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size();
for (auto [fragmentIndex, key] : llvm::enumerate(keys)) {
if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != keys.front().resultIndex || key.instance.laneCount != 1)
return sourceClass.op->emitError("projected packed host publication requires one-lane keys from one producer result");
@@ -5642,15 +5754,19 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
return sourceClass.op->emitError(
"projected packed host publication requires one operand to map to a consistent fragment shape");
FailureOr<unsigned> publishedLaneIndex = getPublicationLaneForProducerKey(state, sourceClass, key);
if (failed(publishedLaneIndex))
return failure();
int64_t localFragmentOffsetWithinPublishedPayload =
(static_cast<int64_t>(fragmentIndex) % fragmentsPerPublishedPayload) * fragmentElementCount;
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
originalOutput,
sourceClass.id,
key,
packed,
packedType,
fragmentType,
operandIsDim0Packed ? static_cast<int64_t>(fragmentIndex) : -1,
*publicationResult,
static_cast<int64_t>(fragmentIndex),
static_cast<int64_t>(*publishedLaneIndex) * payloadElementCount + localFragmentOffsetWithinPublishedPayload,
SmallVector<int64_t, 4>(*offsets),
SmallVector<int64_t, 4>(*sizes),
SmallVector<int64_t, 4>(*strides),
@@ -5678,24 +5794,23 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
< reinterpret_cast<uintptr_t>(rhs.getAsOpaquePointer());
});
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
if (!returnOp)
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
for (Value originalOutput : outputs) {
auto ownerIt = state.hostOutputOwners.find(originalOutput);
if (ownerIt == state.hostOutputOwners.end()) {
Operation* anchor = originalOutput.getDefiningOp() ? originalOutput.getDefiningOp() : state.func.getOperation();
return anchor->emitError("missing host owner for projected host output fragments");
}
MaterializedClass* ownerClass = &state.classes[ownerIt->second];
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
if (!resultType || !resultType.hasStaticShape())
return ownerClass->op->emitError("projected host output must have static ranked tensor type");
return state.func.emitError("projected host output must have static ranked tensor type");
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs,
const PendingProjectedHostOutputFragment* rhs) {
if (lhs->sourceLane != rhs->sourceLane)
return lhs->sourceLane < rhs->sourceLane;
if (lhs->publicationValue != rhs->publicationValue)
return reinterpret_cast<uintptr_t>(lhs->publicationValue.getAsOpaquePointer())
< reinterpret_cast<uintptr_t>(rhs->publicationValue.getAsOpaquePointer());
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
if (lhs->sourceClass != rhs->sourceClass)
return lhs->sourceClass < rhs->sourceClass;
return std::lexicographical_compare(lhs->offsets.begin(),
@@ -5704,240 +5819,36 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
rhs->offsets.end());
});
bool allFromSameSourceClass =
llvm::all_of(fragments, [&](const PendingProjectedHostOutputFragment* fragment) {
return fragment->sourceClass == fragments.front()->sourceClass;
});
if (allFromSameSourceClass) {
ownerClass = &state.classes[fragments.front()->sourceClass];
state.hostOutputOwners[originalOutput] = ownerClass->id;
} else {
if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput))
goto owner_selected;
FailureOr<ClassId> createdOwner =
createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc);
if (failed(createdOwner))
return failure();
ownerClass = &state.classes[*createdOwner];
}
owner_selected:
if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) {
auto sourceBatch = dyn_cast<SpatComputeBatch>(fragments.front()->producerKey.instance.op);
auto batch = dyn_cast<SpatScheduledComputeBatch>(ownerClass->op);
auto inParallelOp = dyn_cast_or_null<SpatInParallelOp>(ownerClass->body->getTerminator());
auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput);
if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end())
return ownerClass->op->emitError("missing batch host assembly state for projected host output");
FailureOr<tensor::ParallelInsertSliceOp> sourceProjection =
getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex);
std::optional<BlockArgument> sourceLaneArg = sourceBatch.getLaneArgument();
if (failed(sourceProjection) || !sourceLaneArg)
return ownerClass->op->emitError(
"direct batch host output assembly requires the source batch projection metadata");
auto outputArg = batch.getOutputArgument(resultIt->second);
auto laneArg = batch.getLaneArgument();
if (!outputArg || !laneArg)
return ownerClass->op->emitError("missing compute_batch output block argument for projected host output");
if (fragments.size() != ownerClass->cpus.size())
return ownerClass->op->emitError(
"direct batch host output assembly expects exactly one fragment per materialized lane");
SmallVector<PendingProjectedHostOutputFragment*, 8> fragmentsByLane(ownerClass->cpus.size(), nullptr);
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane;
if (currentLane < 0 || currentLane >= static_cast<int64_t>(fragmentsByLane.size()))
return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds");
if (fragmentsByLane[currentLane])
return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane");
fragmentsByLane[currentLane] = fragmentRecord;
}
if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; }))
return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes");
FailureOr<SmallVector<int64_t, 4>> firstSizes =
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
FailureOr<SmallVector<int64_t, 4>> firstStrides =
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
if (failed(firstSizes) || failed(firstStrides))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
SmallVector<int64_t, 4> referenceSizes(*firstSizes);
SmallVector<int64_t, 4> referenceStrides(*firstStrides);
Value laneOperand;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
FailureOr<SmallVector<int64_t, 4>> fragmentSizes =
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane);
FailureOr<SmallVector<int64_t, 4>> fragmentStrides =
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane);
if (failed(fragmentSizes) || failed(fragmentStrides))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
if (SmallVector<int64_t, 4>(*fragmentSizes) != referenceSizes
|| SmallVector<int64_t, 4>(*fragmentStrides) != referenceStrides)
return ownerClass->op->emitError(
"direct batch host output assembly expects a uniform fragment shape and strides");
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
Value operand;
if (std::optional<Value> availableValue =
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
operand = *availableValue;
} else {
operand = fragmentRecord->operand;
}
if (!isValueLegalInMaterializedClassBody(operand, *ownerClass))
return ownerClass->op->emitError(
"projected batch host output assembly requires source-local fragment operands");
if (laneOperand && laneOperand != operand)
return ownerClass->op->emitError(
"direct batch host output assembly expects one shared lane-local fragment producer");
laneOperand = operand;
}
SmallVector<OpFoldResult, 4> mixedOffsets;
mixedOffsets.reserve(referenceSizes.size());
for (size_t dim = 0; dim < referenceSizes.size(); ++dim) {
SmallVector<int64_t, 8> offsetsByLane;
offsetsByLane.reserve(fragmentsByLane.size());
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
FailureOr<SmallVector<int64_t, 4>> fragmentOffsets =
evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane);
if (failed(fragmentOffsets))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets");
offsetsByLane.push_back((*fragmentOffsets)[dim]);
}
mixedOffsets.push_back(allEqual(offsetsByLane)
? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front()))
: OpFoldResult(createLaneIndexedIndexValue(
state, *ownerClass, ArrayRef<int64_t>(offsetsByLane), fragments.front()->loc)));
}
state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second);
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
tensor::ParallelInsertSliceOp::create(state.rewriter,
fragments.front()->loc,
laneOperand,
*outputArg,
mixedOffsets,
getStaticIndexAttrs(state.rewriter, referenceSizes),
getStaticIndexAttrs(state.rewriter, referenceStrides));
continue;
}
state.rewriter.setInsertionPoint(ownerClass->body->getTerminator());
state.rewriter.setInsertionPoint(returnOp);
Location loc = fragments.front()->loc;
SmallVector<Value, 16> reconciliatorOperands;
SmallVector<int64_t, 16> fragmentOperandIndices;
SmallVector<int64_t, 16> fragmentSourceOffsets;
SmallVector<int64_t, 64> flatOffsets;
SmallVector<int64_t, 64> flatSizes;
SmallVector<int64_t, 64> flatStrides;
DenseMap<Value, int64_t> operandIndicesByValue;
DenseSet<ClassId> emittedBatchForwarding;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
Value operand;
if (std::optional<Value> availableValue =
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
operand = *availableValue;
} else if (fragmentRecord->sourceClass == sourceClass.id) {
operand = fragmentRecord->operand;
} else {
return sourceClass.op->emitError(
"projected host output fragment assembly is missing source-visible fragment operands before finalization");
}
if (fragmentRecord->sourceClass != ownerClass->id) {
if (sourceClass.isBatch && !ownerClass->isBatch) {
if (!emittedBatchForwarding.insert(sourceClass.id).second) {
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
if (!localized)
return ownerClass->op->emitError(
"projected host output fragment assembly is missing forwarded batch fragments");
operand = *localized;
} else {
SmallVector<ProducerKey, 8> forwardedKeys;
forwardedKeys.reserve(sourceClass.cpus.size());
Value forwardedPayload = fragmentRecord->operand;
for (PendingProjectedHostOutputFragment* candidate : fragments) {
if (candidate->sourceClass != sourceClass.id)
continue;
if (candidate->operand != forwardedPayload)
return ownerClass->op->emitError(
"projected host output batch forwarding expects one shared batch payload per source class");
forwardedKeys.push_back(candidate->producerKey);
}
llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) {
return lhs.instance.laneStart < rhs.instance.laneStart;
});
if (failed(emitClassToClassCommunication(
state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc)))
return failure();
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
if (!localized)
return ownerClass->op->emitError(
"projected host output fragment assembly failed to recover forwarded batch fragment");
operand = *localized;
}
} else {
MessageVector messages;
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
sourceClass.cpus.front(),
"projected host output source core id");
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
ownerClass->cpus.front(),
"projected host output target core id");
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
return failure();
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
return failure();
operand = appendReceive(state,
*ownerClass,
cast<RankedTensorType>(operand.getType()),
messages,
fragmentRecord->loc);
}
} else if (!ownerClass->isBatch) {
FailureOr<Value> localOperand = materializeTensorValueForMaterializedClassUse(
state,
*ownerClass,
operand,
ownerClass->op,
"projected host output assembly tried to reuse a non-local fragment tensor");
if (failed(localOperand))
return failure();
operand = *localOperand;
}
Value operand = fragmentRecord->publicationValue;
auto [operandIt, inserted] =
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(reconciliatorOperands.size()));
if (inserted)
reconciliatorOperands.push_back(operand);
fragmentOperandIndices.push_back(operandIt->second);
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
llvm::append_range(flatOffsets, fragmentRecord->offsets);
llvm::append_range(flatSizes, fragmentRecord->sizes);
llvm::append_range(flatStrides, fragmentRecord->strides);
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
if (!operandType || !operandType.hasStaticShape())
return ownerClass->op->emitError("projected host output assembly requires static ranked tensor operands");
if (fragmentRecord->packedFragmentIndex >= 0) {
int64_t fragmentSize0 = fragmentRecord->fragmentType.getDimSize(0);
if (fragmentSize0 <= 0 || operandType.getRank() == 0)
return ownerClass->op->emitError("packed projected host output assembly requires ranked fragment operands");
int64_t start = fragmentRecord->packedFragmentIndex * fragmentSize0;
int64_t end = start + fragmentSize0;
if (start < 0 || end > operandType.getDimSize(0))
return ownerClass->op->emitError("packed projected host output fragment index is out of bounds");
}
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
}
if (reconciliatorOperands.empty())
return ownerClass->op->emitError("missing projected host output fragments");
return state.func.emitError("missing projected host output fragments");
Value input = reconciliatorOperands.front();
ValueRange extraFragments = ValueRange(reconciliatorOperands).drop_front();
@@ -5954,12 +5865,12 @@ owner_selected:
state.rewriter.getStringAttr("identity"),
state.rewriter.getStringAttr("fragment_assembly"),
state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices),
state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets),
state.rewriter.getDenseI64ArrayAttr(flatStrides),
state.rewriter.getStringAttr("disjoint"),
state.rewriter.getStringAttr("complete"));
if (failed(setHostOutputValue(state, *ownerClass, originalOutput, reconciliator.getOutput())))
return failure();
state.hostReplacements[originalOutput] = reconciliator.getOutput();
}
return success();