This commit is contained in:
+250
-27
@@ -306,10 +306,12 @@ struct ProjectedExtractReplacement {
|
||||
struct PendingProjectedHostOutputFragment {
|
||||
Value originalOutput;
|
||||
ClassId sourceClass = 0;
|
||||
ProducerKey producerKey;
|
||||
Value operand;
|
||||
RankedTensorType operandType;
|
||||
RankedTensorType fragmentType;
|
||||
int64_t packedFragmentIndex = -1;
|
||||
int64_t currentLane = -1;
|
||||
SmallVector<int64_t, 4> offsets;
|
||||
SmallVector<int64_t, 4> sizes;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
@@ -1137,6 +1139,59 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Positions `state.rewriter` for the creation of a new materialized op.
///
/// The first block of the function body is scanned in order; the insertion
/// point is placed immediately before the first operation that is a member of
/// `state.oldComputeOps`. When the block contains no such operation, the
/// insertion point falls back to the end of that block.
void setInsertionPointForNewMaterializedOp(MaterializerState& state) {
  Block& bodyBlock = state.func.getBody().front();

  // Locate the earliest old compute op, if there is one.
  Operation* firstOldCompute = nullptr;
  for (Operation& candidate : bodyBlock) {
    if (!state.oldComputeOps.contains(&candidate))
      continue;
    firstOldCompute = &candidate;
    break;
  }

  if (firstOldCompute)
    state.rewriter.setInsertionPoint(firstOldCompute);
  else
    state.rewriter.setInsertionPointToEnd(&bodyBlock);
}
|
||||
|
||||
/// Creates a fresh materialized class whose only job is to own the host output
/// `originalOutput`, and registers it in `state`.
///
/// The new class is bound to the smallest CPU id that has no entry in
/// `state.cpuToClass`. Its op is a `SpatScheduledCompute` with no operands and
/// a single result of `originalOutput`'s type; the body initially just yields
/// a `tensor.empty` placeholder of that type.
///
/// \param state          Materializer state; `classes`, `cpuToClass` and
///                       `hostOutputOwners` are updated on success.
/// \param originalOutput Host output value the new class will produce as its
///                       result #0. Must be a statically shaped ranked tensor.
/// \param loc            Location attached to the created ops.
/// \returns The id of the newly appended class, or failure (with a diagnostic
///          emitted on `state.func`) when the output type is unsupported or the
///          chosen CPU id cannot be encoded as a core-id attribute.
///
/// On success the rewriter insertion point is left right after the new compute
/// op. On failure neither the IR nor the insertion point is modified: all
/// fallible checks run before the op is created, so a failed core-id check can
/// no longer leave a body-less compute op behind in the function.
FailureOr<ClassId> createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) {
  // --- Fallible validation first (no IR mutation yet). ---
  auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
  if (!resultType || !resultType.hasStaticShape())
    return state.func.emitError("projected host assembly class requires a static ranked tensor output");

  // Pick the smallest CPU id that is not already mapped to a class.
  DenseSet<CpuId> usedCpus;
  for (const auto& [cpu, _] : state.cpuToClass)
    usedCpus.insert(cpu);

  CpuId assemblyCpu = 0;
  while (usedCpus.contains(assemblyCpu))
    ++assemblyCpu;

  auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id");
  if (failed(coreIdAttr))
    return failure();

  // --- IR construction (cannot fail past this point). ---
  setInsertionPointForNewMaterializedOp(state);

  auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {});
  compute.getProperties().setOperandSegmentSizes({0, 0});
  compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr);

  // The body only yields a placeholder tensor of the result type for now.
  Block* body = state.rewriter.createBlock(&compute.getBody());
  state.rewriter.setInsertionPointToEnd(body);
  Value placeholder =
      tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
  SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder});
  state.rewriter.setInsertionPointAfter(compute.getOperation());

  // --- Bookkeeping: register the class and make it the owner of the output. ---
  MaterializedClass materializedClass;
  materializedClass.id = state.classes.size();
  materializedClass.cpus.push_back(assemblyCpu);
  materializedClass.op = compute.getOperation();
  materializedClass.body = body;
  materializedClass.hostOutputToResultIndex[originalOutput] = 0;
  materializedClass.hostOutputs.push_back(originalOutput);

  state.cpuToClass[assemblyCpu] = materializedClass.id;
  state.hostOutputOwners[originalOutput] = materializedClass.id;
  state.classes.push_back(std::move(materializedClass));
  return state.classes.back().id;
}
|
||||
|
||||
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
|
||||
auto it = materializedClass.weightArgs.find(weight);
|
||||
if (it != materializedClass.weightArgs.end())
|
||||
@@ -1897,6 +1952,14 @@ FailureOr<SmallVector<OpFoldResult, 4>> buildProjectedFragmentOffsetsInClass(Mat
|
||||
return fragmentOffsets;
|
||||
}
|
||||
|
||||
/// Converts a list of static integers into index attributes.
///
/// \param builder Builder used to create each index attribute.
/// \param values  Integers to wrap.
/// \returns One `OpFoldResult` per entry of `values`, in the same order, each
///          holding `builder.getIndexAttr(value)`.
SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  const size_t count = values.size();

  SmallVector<OpFoldResult, 4> indexAttrs;
  indexAttrs.reserve(count);
  for (size_t position = 0; position != count; ++position)
    indexAttrs.push_back(builder.getIndexAttr(values[position]));

  return indexAttrs;
}
|
||||
|
||||
Value createDim0InsertSlice(
|
||||
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
|
||||
auto fragmentType = cast<RankedTensorType>(fragment.getType());
|
||||
@@ -3639,6 +3702,9 @@ LogicalResult appendSend(MaterializerState& state,
|
||||
|
||||
if (sourceClass.isBatch) {
|
||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||
if (messages.size() != sourceClass.cpus.size())
|
||||
return sourceClass.op->emitError("batch send expects exactly one message per materialized lane")
|
||||
<< " messageCount=" << messages.size() << " laneCount=" << sourceClass.cpus.size();
|
||||
|
||||
Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc);
|
||||
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc);
|
||||
@@ -3686,6 +3752,11 @@ Value appendReceive(
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
|
||||
if (targetClass.isBatch) {
|
||||
if (messages.size() != targetClass.cpus.size()) {
|
||||
targetClass.op->emitOpError("batch receive expects exactly one message per materialized lane")
|
||||
<< " messageCount=" << messages.size() << " laneCount=" << targetClass.cpus.size();
|
||||
return Value();
|
||||
}
|
||||
Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc);
|
||||
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc);
|
||||
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc);
|
||||
@@ -5481,10 +5552,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
|
||||
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
|
||||
originalOutput,
|
||||
sourceClass.id,
|
||||
ProducerKey {peer, resultIndex},
|
||||
packed,
|
||||
cast<RankedTensorType>(packed.getType()),
|
||||
fragmentType,
|
||||
static_cast<int64_t>(runIndex),
|
||||
static_cast<int64_t>(runIndex),
|
||||
SmallVector<int64_t, 4>(*offsets),
|
||||
SmallVector<int64_t, 4>(*sizes),
|
||||
SmallVector<int64_t, 4>(*strides),
|
||||
@@ -5572,10 +5645,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
||||
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
|
||||
originalOutput,
|
||||
sourceClass.id,
|
||||
key,
|
||||
packed,
|
||||
packedType,
|
||||
fragmentType,
|
||||
operandIsDim0Packed ? static_cast<int64_t>(fragmentIndex) : -1,
|
||||
static_cast<int64_t>(fragmentIndex),
|
||||
SmallVector<int64_t, 4>(*offsets),
|
||||
SmallVector<int64_t, 4>(*sizes),
|
||||
SmallVector<int64_t, 4>(*strides),
|
||||
@@ -5611,16 +5686,6 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
||||
}
|
||||
|
||||
MaterializedClass* ownerClass = &state.classes[ownerIt->second];
|
||||
if (ownerClass->isBatch) {
|
||||
auto scalarOwnerIt = llvm::find_if(state.classes, [](const MaterializedClass& candidate) {
|
||||
return !candidate.isBatch;
|
||||
});
|
||||
if (scalarOwnerIt == state.classes.end())
|
||||
return ownerClass->op->emitError(
|
||||
"projected host output finalization requires a scalar assembly class when the preferred host owner is batch");
|
||||
ownerClass = &*scalarOwnerIt;
|
||||
state.hostOutputOwners[originalOutput] = ownerClass->id;
|
||||
}
|
||||
|
||||
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
@@ -5646,6 +5711,119 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
||||
if (allFromSameSourceClass) {
|
||||
ownerClass = &state.classes[fragments.front()->sourceClass];
|
||||
state.hostOutputOwners[originalOutput] = ownerClass->id;
|
||||
} else {
|
||||
if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput))
|
||||
goto owner_selected;
|
||||
FailureOr<ClassId> createdOwner =
|
||||
createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc);
|
||||
if (failed(createdOwner))
|
||||
return failure();
|
||||
ownerClass = &state.classes[*createdOwner];
|
||||
}
|
||||
owner_selected:
|
||||
|
||||
if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) {
|
||||
auto sourceBatch = dyn_cast<SpatComputeBatch>(fragments.front()->producerKey.instance.op);
|
||||
auto batch = dyn_cast<SpatScheduledComputeBatch>(ownerClass->op);
|
||||
auto inParallelOp = dyn_cast_or_null<SpatInParallelOp>(ownerClass->body->getTerminator());
|
||||
auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput);
|
||||
if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end())
|
||||
return ownerClass->op->emitError("missing batch host assembly state for projected host output");
|
||||
FailureOr<tensor::ParallelInsertSliceOp> sourceProjection =
|
||||
getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex);
|
||||
std::optional<BlockArgument> sourceLaneArg = sourceBatch.getLaneArgument();
|
||||
if (failed(sourceProjection) || !sourceLaneArg)
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly requires the source batch projection metadata");
|
||||
|
||||
auto outputArg = batch.getOutputArgument(resultIt->second);
|
||||
auto laneArg = batch.getLaneArgument();
|
||||
if (!outputArg || !laneArg)
|
||||
return ownerClass->op->emitError("missing compute_batch output block argument for projected host output");
|
||||
|
||||
if (fragments.size() != ownerClass->cpus.size())
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly expects exactly one fragment per materialized lane");
|
||||
|
||||
SmallVector<PendingProjectedHostOutputFragment*, 8> fragmentsByLane(ownerClass->cpus.size(), nullptr);
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||
int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane;
|
||||
if (currentLane < 0 || currentLane >= static_cast<int64_t>(fragmentsByLane.size()))
|
||||
return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds");
|
||||
if (fragmentsByLane[currentLane])
|
||||
return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane");
|
||||
fragmentsByLane[currentLane] = fragmentRecord;
|
||||
}
|
||||
|
||||
if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; }))
|
||||
return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes");
|
||||
|
||||
FailureOr<SmallVector<int64_t, 4>> firstSizes =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
|
||||
FailureOr<SmallVector<int64_t, 4>> firstStrides =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
|
||||
if (failed(firstSizes) || failed(firstStrides))
|
||||
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
|
||||
SmallVector<int64_t, 4> referenceSizes(*firstSizes);
|
||||
SmallVector<int64_t, 4> referenceStrides(*firstStrides);
|
||||
Value laneOperand;
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
|
||||
FailureOr<SmallVector<int64_t, 4>> fragmentSizes =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||
FailureOr<SmallVector<int64_t, 4>> fragmentStrides =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||
if (failed(fragmentSizes) || failed(fragmentStrides))
|
||||
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
|
||||
if (SmallVector<int64_t, 4>(*fragmentSizes) != referenceSizes
|
||||
|| SmallVector<int64_t, 4>(*fragmentStrides) != referenceStrides)
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly expects a uniform fragment shape and strides");
|
||||
|
||||
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||
Value operand;
|
||||
if (std::optional<Value> availableValue =
|
||||
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
|
||||
operand = *availableValue;
|
||||
} else {
|
||||
operand = fragmentRecord->operand;
|
||||
}
|
||||
if (!isValueLegalInMaterializedClassBody(operand, *ownerClass))
|
||||
return ownerClass->op->emitError(
|
||||
"projected batch host output assembly requires source-local fragment operands");
|
||||
if (laneOperand && laneOperand != operand)
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly expects one shared lane-local fragment producer");
|
||||
laneOperand = operand;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult, 4> mixedOffsets;
|
||||
mixedOffsets.reserve(referenceSizes.size());
|
||||
for (size_t dim = 0; dim < referenceSizes.size(); ++dim) {
|
||||
SmallVector<int64_t, 8> offsetsByLane;
|
||||
offsetsByLane.reserve(fragmentsByLane.size());
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
|
||||
FailureOr<SmallVector<int64_t, 4>> fragmentOffsets =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||
if (failed(fragmentOffsets))
|
||||
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets");
|
||||
offsetsByLane.push_back((*fragmentOffsets)[dim]);
|
||||
}
|
||||
mixedOffsets.push_back(allEqual(offsetsByLane)
|
||||
? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front()))
|
||||
: OpFoldResult(createLaneIndexedIndexValue(
|
||||
state, *ownerClass, ArrayRef<int64_t>(offsetsByLane), fragments.front()->loc)));
|
||||
}
|
||||
|
||||
state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second);
|
||||
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
tensor::ParallelInsertSliceOp::create(state.rewriter,
|
||||
fragments.front()->loc,
|
||||
laneOperand,
|
||||
*outputArg,
|
||||
mixedOffsets,
|
||||
getStaticIndexAttrs(state.rewriter, referenceSizes),
|
||||
getStaticIndexAttrs(state.rewriter, referenceStrides));
|
||||
continue;
|
||||
}
|
||||
|
||||
state.rewriter.setInsertionPoint(ownerClass->body->getTerminator());
|
||||
@@ -5656,28 +5834,73 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
||||
SmallVector<int64_t, 64> flatSizes;
|
||||
SmallVector<int64_t, 64> flatStrides;
|
||||
DenseMap<Value, int64_t> operandIndicesByValue;
|
||||
DenseSet<ClassId> emittedBatchForwarding;
|
||||
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||
Value operand = fragmentRecord->operand;
|
||||
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||
Value operand;
|
||||
|
||||
if (std::optional<Value> availableValue =
|
||||
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
|
||||
operand = *availableValue;
|
||||
} else if (fragmentRecord->sourceClass == sourceClass.id) {
|
||||
operand = fragmentRecord->operand;
|
||||
} else {
|
||||
return sourceClass.op->emitError(
|
||||
"projected host output fragment assembly is missing source-visible fragment operands before finalization");
|
||||
}
|
||||
|
||||
if (fragmentRecord->sourceClass != ownerClass->id) {
|
||||
if (sourceClass.isBatch || ownerClass->isBatch)
|
||||
return sourceClass.op->emitError(
|
||||
"projected host output fragment assembly requires scalarized cross-class operands before finalization");
|
||||
MessageVector messages;
|
||||
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
|
||||
sourceClass.cpus.front(),
|
||||
"projected host output source core id");
|
||||
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
|
||||
ownerClass->cpus.front(),
|
||||
"projected host output target core id");
|
||||
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
|
||||
return failure();
|
||||
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
||||
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
|
||||
return failure();
|
||||
operand = appendReceive(state, *ownerClass, fragmentRecord->operandType, messages, fragmentRecord->loc);
|
||||
if (sourceClass.isBatch && !ownerClass->isBatch) {
|
||||
if (!emittedBatchForwarding.insert(sourceClass.id).second) {
|
||||
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
|
||||
if (!localized)
|
||||
return ownerClass->op->emitError(
|
||||
"projected host output fragment assembly is missing forwarded batch fragments");
|
||||
operand = *localized;
|
||||
} else {
|
||||
SmallVector<ProducerKey, 8> forwardedKeys;
|
||||
forwardedKeys.reserve(sourceClass.cpus.size());
|
||||
Value forwardedPayload = fragmentRecord->operand;
|
||||
for (PendingProjectedHostOutputFragment* candidate : fragments) {
|
||||
if (candidate->sourceClass != sourceClass.id)
|
||||
continue;
|
||||
if (candidate->operand != forwardedPayload)
|
||||
return ownerClass->op->emitError(
|
||||
"projected host output batch forwarding expects one shared batch payload per source class");
|
||||
forwardedKeys.push_back(candidate->producerKey);
|
||||
}
|
||||
llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) {
|
||||
return lhs.instance.laneStart < rhs.instance.laneStart;
|
||||
});
|
||||
if (failed(emitClassToClassCommunication(
|
||||
state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc)))
|
||||
return failure();
|
||||
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
|
||||
if (!localized)
|
||||
return ownerClass->op->emitError(
|
||||
"projected host output fragment assembly failed to recover forwarded batch fragment");
|
||||
operand = *localized;
|
||||
}
|
||||
} else {
|
||||
MessageVector messages;
|
||||
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
|
||||
sourceClass.cpus.front(),
|
||||
"projected host output source core id");
|
||||
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
|
||||
ownerClass->cpus.front(),
|
||||
"projected host output target core id");
|
||||
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
|
||||
return failure();
|
||||
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
||||
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
|
||||
return failure();
|
||||
operand = appendReceive(state,
|
||||
*ownerClass,
|
||||
cast<RankedTensorType>(operand.getType()),
|
||||
messages,
|
||||
fragmentRecord->loc);
|
||||
}
|
||||
} else if (!ownerClass->isBatch) {
|
||||
FailureOr<Value> localOperand = materializeTensorValueForMaterializedClassUse(
|
||||
state,
|
||||
|
||||
Reference in New Issue
Block a user