This commit is contained in:
@@ -146,9 +146,7 @@ struct ProjectedBatchInputKeyInfo {
|
||||
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
struct ProjectedTransferDescriptor {
|
||||
@@ -201,6 +199,7 @@ struct MaterializerState {
|
||||
DenseSet<ClassSlotKey> materializedSlots;
|
||||
|
||||
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> sameClassConsumers;
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
|
||||
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
|
||||
AvailableValueStore availableValues;
|
||||
@@ -1187,8 +1186,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
||||
return consumer.op->emitError("schedule materialization found an input produced by an unscheduled compute");
|
||||
|
||||
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
|
||||
if (sourceClass == targetClass)
|
||||
if (sourceClass == targetClass) {
|
||||
state.sameClassConsumers[producerKey].insert(targetClass);
|
||||
continue;
|
||||
}
|
||||
|
||||
appendDestinationClass(state, producerKey, targetClass);
|
||||
}
|
||||
@@ -1368,9 +1369,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
|
||||
SmallVector<int64_t, 4> payloadShape(pendingDescriptor.fragmentType.getShape());
|
||||
payloadShape[0] *= static_cast<int64_t>(fragmentsPerLane);
|
||||
RankedTensorType payloadType = RankedTensorType::get(payloadShape,
|
||||
pendingDescriptor.fragmentType.getElementType(),
|
||||
pendingDescriptor.fragmentType.getEncoding());
|
||||
RankedTensorType payloadType = RankedTensorType::get(
|
||||
payloadShape, pendingDescriptor.fragmentType.getElementType(), pendingDescriptor.fragmentType.getEncoding());
|
||||
|
||||
ProjectedTransferDescriptor descriptor;
|
||||
descriptor.inputKey = pendingDescriptor.inputKey;
|
||||
@@ -1506,11 +1506,9 @@ Value buildProjectedPackedPayload(MaterializerState& state,
|
||||
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
|
||||
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
|
||||
|
||||
Value sourceOffset =
|
||||
createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
|
||||
Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
|
||||
|
||||
Value fragment =
|
||||
createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
|
||||
Value fragment = createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
|
||||
|
||||
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
|
||||
scf::YieldOp::create(state.rewriter, loc, next);
|
||||
@@ -1773,8 +1771,8 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
||||
appendProjectedScalarSendLoop(
|
||||
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
|
||||
Value received = appendReceive(
|
||||
state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
Value received =
|
||||
appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
|
||||
state.projectedExtractReplacements[descriptor.extractOp][destinationClass] =
|
||||
ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane};
|
||||
@@ -2512,8 +2510,10 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
|
||||
return appendInput(state, targetClass, input);
|
||||
}
|
||||
|
||||
bool hasProjectedInputReplacement(
|
||||
MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) {
|
||||
bool hasProjectedInputReplacement(MaterializerState& state,
|
||||
SpatComputeBatch batch,
|
||||
unsigned inputIndex,
|
||||
ClassId classId) {
|
||||
std::optional<tensor::ExtractSliceOp> extract = matchSimpleLaneProjectedInput(batch, inputIndex);
|
||||
if (!extract)
|
||||
return false;
|
||||
@@ -2965,8 +2965,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
||||
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
||||
if (std::optional<ProjectedExtractReplacement> replacement =
|
||||
lookupProjectedExtractReplacement(state, targetClass, extract)) {
|
||||
FailureOr<Value> projected = materializeProjectedExtractReplacement(
|
||||
state, targetClass, extract, *replacement, projectionSlotIndex);
|
||||
FailureOr<Value> projected =
|
||||
materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, projectionSlotIndex);
|
||||
if (failed(projected))
|
||||
return failure();
|
||||
|
||||
@@ -3288,26 +3288,8 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state,
|
||||
}
|
||||
|
||||
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
|
||||
for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) {
|
||||
auto cpuIt = state.schedule.computeToCpuMap.find(consumer);
|
||||
if (cpuIt == state.schedule.computeToCpuMap.end())
|
||||
continue;
|
||||
|
||||
if (state.cpuToClass.lookup(cpuIt->second) != classId)
|
||||
continue;
|
||||
|
||||
for (Value input : getComputeInstanceInputs(consumer)) {
|
||||
std::optional<ProducerKey> producer = getProducerKey(input, &consumer);
|
||||
if (!producer)
|
||||
continue;
|
||||
|
||||
for (ProducerKey expanded : expandWholeBatchProducerKey(*producer))
|
||||
if (expanded == producerKey)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
auto it = state.sameClassConsumers.find(producerKey);
|
||||
return it != state.sameClassConsumers.end() && it->second.contains(classId);
|
||||
}
|
||||
|
||||
bool canCompactBatchClassRun(MaterializerState& state,
|
||||
|
||||
Reference in New Issue
Block a user