cose
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-25 18:57:12 +02:00
parent be0bcc9dcc
commit 568fd90542
6 changed files with 647 additions and 179 deletions
@@ -306,10 +306,12 @@ struct ProjectedExtractReplacement {
struct PendingProjectedHostOutputFragment {
Value originalOutput;
ClassId sourceClass = 0;
ProducerKey producerKey;
Value operand;
RankedTensorType operandType;
RankedTensorType fragmentType;
int64_t packedFragmentIndex = -1;
int64_t currentLane = -1;
SmallVector<int64_t, 4> offsets;
SmallVector<int64_t, 4> sizes;
SmallVector<int64_t, 4> strides;
@@ -1137,6 +1139,59 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) {
return success();
}
void setInsertionPointForNewMaterializedOp(MaterializerState& state) {
Block& funcBlock = state.func.getBody().front();
for (Operation& op : funcBlock) {
if (state.oldComputeOps.contains(&op)) {
state.rewriter.setInsertionPoint(&op);
return;
}
}
state.rewriter.setInsertionPointToEnd(&funcBlock);
}
FailureOr<ClassId> createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) {
DenseSet<CpuId> usedCpus;
for (const auto& [cpu, _] : state.cpuToClass)
usedCpus.insert(cpu);
CpuId assemblyCpu = 0;
while (usedCpus.contains(assemblyCpu))
++assemblyCpu;
setInsertionPointForNewMaterializedOp(state);
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
if (!resultType || !resultType.hasStaticShape())
return state.func.emitError("projected host assembly class requires a static ranked tensor output");
auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {});
compute.getProperties().setOperandSegmentSizes({0, 0});
auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id");
if (failed(coreIdAttr))
return failure();
compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr);
Block* body = state.rewriter.createBlock(&compute.getBody());
state.rewriter.setInsertionPointToEnd(body);
Value placeholder =
tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder});
state.rewriter.setInsertionPointAfter(compute.getOperation());
MaterializedClass materializedClass;
materializedClass.id = state.classes.size();
materializedClass.cpus.push_back(assemblyCpu);
materializedClass.op = compute.getOperation();
materializedClass.body = body;
materializedClass.hostOutputToResultIndex[originalOutput] = 0;
materializedClass.hostOutputs.push_back(originalOutput);
state.cpuToClass[assemblyCpu] = materializedClass.id;
state.hostOutputOwners[originalOutput] = materializedClass.id;
state.classes.push_back(std::move(materializedClass));
return state.classes.back().id;
}
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
auto it = materializedClass.weightArgs.find(weight);
if (it != materializedClass.weightArgs.end())
@@ -1897,6 +1952,14 @@ FailureOr<SmallVector<OpFoldResult, 4>> buildProjectedFragmentOffsetsInClass(Mat
return fragmentOffsets;
}
SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
SmallVector<OpFoldResult, 4> attrs;
attrs.reserve(values.size());
for (int64_t value : values)
attrs.push_back(builder.getIndexAttr(value));
return attrs;
}
Value createDim0InsertSlice(
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
auto fragmentType = cast<RankedTensorType>(fragment.getType());
@@ -3639,6 +3702,9 @@ LogicalResult appendSend(MaterializerState& state,
if (sourceClass.isBatch) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
if (messages.size() != sourceClass.cpus.size())
return sourceClass.op->emitError("batch send expects exactly one message per materialized lane")
<< " messageCount=" << messages.size() << " laneCount=" << sourceClass.cpus.size();
Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc);
@@ -3686,6 +3752,11 @@ Value appendReceive(
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
if (targetClass.isBatch) {
if (messages.size() != targetClass.cpus.size()) {
targetClass.op->emitOpError("batch receive expects exactly one message per materialized lane")
<< " messageCount=" << messages.size() << " laneCount=" << targetClass.cpus.size();
return Value();
}
Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc);
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc);
@@ -5481,10 +5552,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
originalOutput,
sourceClass.id,
ProducerKey {peer, resultIndex},
packed,
cast<RankedTensorType>(packed.getType()),
fragmentType,
static_cast<int64_t>(runIndex),
static_cast<int64_t>(runIndex),
SmallVector<int64_t, 4>(*offsets),
SmallVector<int64_t, 4>(*sizes),
SmallVector<int64_t, 4>(*strides),
@@ -5572,10 +5645,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
originalOutput,
sourceClass.id,
key,
packed,
packedType,
fragmentType,
operandIsDim0Packed ? static_cast<int64_t>(fragmentIndex) : -1,
static_cast<int64_t>(fragmentIndex),
SmallVector<int64_t, 4>(*offsets),
SmallVector<int64_t, 4>(*sizes),
SmallVector<int64_t, 4>(*strides),
@@ -5611,16 +5686,6 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
}
MaterializedClass* ownerClass = &state.classes[ownerIt->second];
if (ownerClass->isBatch) {
auto scalarOwnerIt = llvm::find_if(state.classes, [](const MaterializedClass& candidate) {
return !candidate.isBatch;
});
if (scalarOwnerIt == state.classes.end())
return ownerClass->op->emitError(
"projected host output finalization requires a scalar assembly class when the preferred host owner is batch");
ownerClass = &*scalarOwnerIt;
state.hostOutputOwners[originalOutput] = ownerClass->id;
}
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
if (!resultType || !resultType.hasStaticShape())
@@ -5646,6 +5711,119 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
if (allFromSameSourceClass) {
ownerClass = &state.classes[fragments.front()->sourceClass];
state.hostOutputOwners[originalOutput] = ownerClass->id;
} else {
if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput))
goto owner_selected;
FailureOr<ClassId> createdOwner =
createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc);
if (failed(createdOwner))
return failure();
ownerClass = &state.classes[*createdOwner];
}
owner_selected:
if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) {
auto sourceBatch = dyn_cast<SpatComputeBatch>(fragments.front()->producerKey.instance.op);
auto batch = dyn_cast<SpatScheduledComputeBatch>(ownerClass->op);
auto inParallelOp = dyn_cast_or_null<SpatInParallelOp>(ownerClass->body->getTerminator());
auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput);
if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end())
return ownerClass->op->emitError("missing batch host assembly state for projected host output");
FailureOr<tensor::ParallelInsertSliceOp> sourceProjection =
getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex);
std::optional<BlockArgument> sourceLaneArg = sourceBatch.getLaneArgument();
if (failed(sourceProjection) || !sourceLaneArg)
return ownerClass->op->emitError(
"direct batch host output assembly requires the source batch projection metadata");
auto outputArg = batch.getOutputArgument(resultIt->second);
auto laneArg = batch.getLaneArgument();
if (!outputArg || !laneArg)
return ownerClass->op->emitError("missing compute_batch output block argument for projected host output");
if (fragments.size() != ownerClass->cpus.size())
return ownerClass->op->emitError(
"direct batch host output assembly expects exactly one fragment per materialized lane");
SmallVector<PendingProjectedHostOutputFragment*, 8> fragmentsByLane(ownerClass->cpus.size(), nullptr);
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane;
if (currentLane < 0 || currentLane >= static_cast<int64_t>(fragmentsByLane.size()))
return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds");
if (fragmentsByLane[currentLane])
return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane");
fragmentsByLane[currentLane] = fragmentRecord;
}
if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; }))
return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes");
FailureOr<SmallVector<int64_t, 4>> firstSizes =
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
FailureOr<SmallVector<int64_t, 4>> firstStrides =
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
if (failed(firstSizes) || failed(firstStrides))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
SmallVector<int64_t, 4> referenceSizes(*firstSizes);
SmallVector<int64_t, 4> referenceStrides(*firstStrides);
Value laneOperand;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
FailureOr<SmallVector<int64_t, 4>> fragmentSizes =
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane);
FailureOr<SmallVector<int64_t, 4>> fragmentStrides =
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane);
if (failed(fragmentSizes) || failed(fragmentStrides))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
if (SmallVector<int64_t, 4>(*fragmentSizes) != referenceSizes
|| SmallVector<int64_t, 4>(*fragmentStrides) != referenceStrides)
return ownerClass->op->emitError(
"direct batch host output assembly expects a uniform fragment shape and strides");
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
Value operand;
if (std::optional<Value> availableValue =
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
operand = *availableValue;
} else {
operand = fragmentRecord->operand;
}
if (!isValueLegalInMaterializedClassBody(operand, *ownerClass))
return ownerClass->op->emitError(
"projected batch host output assembly requires source-local fragment operands");
if (laneOperand && laneOperand != operand)
return ownerClass->op->emitError(
"direct batch host output assembly expects one shared lane-local fragment producer");
laneOperand = operand;
}
SmallVector<OpFoldResult, 4> mixedOffsets;
mixedOffsets.reserve(referenceSizes.size());
for (size_t dim = 0; dim < referenceSizes.size(); ++dim) {
SmallVector<int64_t, 8> offsetsByLane;
offsetsByLane.reserve(fragmentsByLane.size());
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
FailureOr<SmallVector<int64_t, 4>> fragmentOffsets =
evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane);
if (failed(fragmentOffsets))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets");
offsetsByLane.push_back((*fragmentOffsets)[dim]);
}
mixedOffsets.push_back(allEqual(offsetsByLane)
? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front()))
: OpFoldResult(createLaneIndexedIndexValue(
state, *ownerClass, ArrayRef<int64_t>(offsetsByLane), fragments.front()->loc)));
}
state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second);
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
tensor::ParallelInsertSliceOp::create(state.rewriter,
fragments.front()->loc,
laneOperand,
*outputArg,
mixedOffsets,
getStaticIndexAttrs(state.rewriter, referenceSizes),
getStaticIndexAttrs(state.rewriter, referenceStrides));
continue;
}
state.rewriter.setInsertionPoint(ownerClass->body->getTerminator());
@@ -5656,28 +5834,73 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
SmallVector<int64_t, 64> flatSizes;
SmallVector<int64_t, 64> flatStrides;
DenseMap<Value, int64_t> operandIndicesByValue;
DenseSet<ClassId> emittedBatchForwarding;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
Value operand = fragmentRecord->operand;
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
Value operand;
if (std::optional<Value> availableValue =
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
operand = *availableValue;
} else if (fragmentRecord->sourceClass == sourceClass.id) {
operand = fragmentRecord->operand;
} else {
return sourceClass.op->emitError(
"projected host output fragment assembly is missing source-visible fragment operands before finalization");
}
if (fragmentRecord->sourceClass != ownerClass->id) {
if (sourceClass.isBatch || ownerClass->isBatch)
return sourceClass.op->emitError(
"projected host output fragment assembly requires scalarized cross-class operands before finalization");
MessageVector messages;
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
sourceClass.cpus.front(),
"projected host output source core id");
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
ownerClass->cpus.front(),
"projected host output target core id");
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
return failure();
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
return failure();
operand = appendReceive(state, *ownerClass, fragmentRecord->operandType, messages, fragmentRecord->loc);
if (sourceClass.isBatch && !ownerClass->isBatch) {
if (!emittedBatchForwarding.insert(sourceClass.id).second) {
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
if (!localized)
return ownerClass->op->emitError(
"projected host output fragment assembly is missing forwarded batch fragments");
operand = *localized;
} else {
SmallVector<ProducerKey, 8> forwardedKeys;
forwardedKeys.reserve(sourceClass.cpus.size());
Value forwardedPayload = fragmentRecord->operand;
for (PendingProjectedHostOutputFragment* candidate : fragments) {
if (candidate->sourceClass != sourceClass.id)
continue;
if (candidate->operand != forwardedPayload)
return ownerClass->op->emitError(
"projected host output batch forwarding expects one shared batch payload per source class");
forwardedKeys.push_back(candidate->producerKey);
}
llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) {
return lhs.instance.laneStart < rhs.instance.laneStart;
});
if (failed(emitClassToClassCommunication(
state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc)))
return failure();
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
if (!localized)
return ownerClass->op->emitError(
"projected host output fragment assembly failed to recover forwarded batch fragment");
operand = *localized;
}
} else {
MessageVector messages;
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
sourceClass.cpus.front(),
"projected host output source core id");
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
ownerClass->cpus.front(),
"projected host output target core id");
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
return failure();
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
return failure();
operand = appendReceive(state,
*ownerClass,
cast<RankedTensorType>(operand.getType()),
messages,
fragmentRecord->loc);
}
} else if (!ownerClass->isBatch) {
FailureOr<Value> localOperand = materializeTensorValueForMaterializedClassUse(
state,