This commit is contained in:
@@ -146,9 +146,7 @@ struct ProjectedBatchInputKeyInfo {
|
|||||||
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) {
|
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; }
|
||||||
return lhs == rhs;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ProjectedTransferDescriptor {
|
struct ProjectedTransferDescriptor {
|
||||||
@@ -201,6 +199,7 @@ struct MaterializerState {
|
|||||||
DenseSet<ClassSlotKey> materializedSlots;
|
DenseSet<ClassSlotKey> materializedSlots;
|
||||||
|
|
||||||
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||||
|
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> sameClassConsumers;
|
||||||
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
|
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
|
||||||
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
|
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
|
||||||
AvailableValueStore availableValues;
|
AvailableValueStore availableValues;
|
||||||
@@ -1187,8 +1186,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
|||||||
return consumer.op->emitError("schedule materialization found an input produced by an unscheduled compute");
|
return consumer.op->emitError("schedule materialization found an input produced by an unscheduled compute");
|
||||||
|
|
||||||
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
|
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
|
||||||
if (sourceClass == targetClass)
|
if (sourceClass == targetClass) {
|
||||||
|
state.sameClassConsumers[producerKey].insert(targetClass);
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
appendDestinationClass(state, producerKey, targetClass);
|
appendDestinationClass(state, producerKey, targetClass);
|
||||||
}
|
}
|
||||||
@@ -1342,7 +1343,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto& producerEntry : pending) {
|
for (auto& producerEntry : pending) {
|
||||||
ProducerKey producer = producerEntry.first;
|
ProducerKey producer = producerEntry.first;
|
||||||
for (auto& classEntry : producerEntry.second) {
|
for (auto& classEntry : producerEntry.second) {
|
||||||
ClassId targetClassId = classEntry.first;
|
ClassId targetClassId = classEntry.first;
|
||||||
PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second;
|
PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second;
|
||||||
@@ -1368,9 +1369,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
|||||||
|
|
||||||
SmallVector<int64_t, 4> payloadShape(pendingDescriptor.fragmentType.getShape());
|
SmallVector<int64_t, 4> payloadShape(pendingDescriptor.fragmentType.getShape());
|
||||||
payloadShape[0] *= static_cast<int64_t>(fragmentsPerLane);
|
payloadShape[0] *= static_cast<int64_t>(fragmentsPerLane);
|
||||||
RankedTensorType payloadType = RankedTensorType::get(payloadShape,
|
RankedTensorType payloadType = RankedTensorType::get(
|
||||||
pendingDescriptor.fragmentType.getElementType(),
|
payloadShape, pendingDescriptor.fragmentType.getElementType(), pendingDescriptor.fragmentType.getEncoding());
|
||||||
pendingDescriptor.fragmentType.getEncoding());
|
|
||||||
|
|
||||||
ProjectedTransferDescriptor descriptor;
|
ProjectedTransferDescriptor descriptor;
|
||||||
descriptor.inputKey = pendingDescriptor.inputKey;
|
descriptor.inputKey = pendingDescriptor.inputKey;
|
||||||
@@ -1506,11 +1506,9 @@ Value buildProjectedPackedPayload(MaterializerState& state,
|
|||||||
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
|
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
|
||||||
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
|
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
|
||||||
|
|
||||||
Value sourceOffset =
|
Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
|
||||||
createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
|
|
||||||
|
|
||||||
Value fragment =
|
Value fragment = createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
|
||||||
createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
|
|
||||||
|
|
||||||
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
|
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
|
||||||
scf::YieldOp::create(state.rewriter, loc, next);
|
scf::YieldOp::create(state.rewriter, loc, next);
|
||||||
@@ -1773,8 +1771,8 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
|||||||
appendProjectedScalarSendLoop(
|
appendProjectedScalarSendLoop(
|
||||||
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc);
|
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
Value received = appendReceive(
|
Value received =
|
||||||
state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
state.projectedExtractReplacements[descriptor.extractOp][destinationClass] =
|
state.projectedExtractReplacements[descriptor.extractOp][destinationClass] =
|
||||||
ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane};
|
ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane};
|
||||||
@@ -2512,8 +2510,10 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
|
|||||||
return appendInput(state, targetClass, input);
|
return appendInput(state, targetClass, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasProjectedInputReplacement(
|
bool hasProjectedInputReplacement(MaterializerState& state,
|
||||||
MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) {
|
SpatComputeBatch batch,
|
||||||
|
unsigned inputIndex,
|
||||||
|
ClassId classId) {
|
||||||
std::optional<tensor::ExtractSliceOp> extract = matchSimpleLaneProjectedInput(batch, inputIndex);
|
std::optional<tensor::ExtractSliceOp> extract = matchSimpleLaneProjectedInput(batch, inputIndex);
|
||||||
if (!extract)
|
if (!extract)
|
||||||
return false;
|
return false;
|
||||||
@@ -2965,8 +2965,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
|||||||
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
||||||
if (std::optional<ProjectedExtractReplacement> replacement =
|
if (std::optional<ProjectedExtractReplacement> replacement =
|
||||||
lookupProjectedExtractReplacement(state, targetClass, extract)) {
|
lookupProjectedExtractReplacement(state, targetClass, extract)) {
|
||||||
FailureOr<Value> projected = materializeProjectedExtractReplacement(
|
FailureOr<Value> projected =
|
||||||
state, targetClass, extract, *replacement, projectionSlotIndex);
|
materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, projectionSlotIndex);
|
||||||
if (failed(projected))
|
if (failed(projected))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -3288,26 +3288,8 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
|
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
|
||||||
for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) {
|
auto it = state.sameClassConsumers.find(producerKey);
|
||||||
auto cpuIt = state.schedule.computeToCpuMap.find(consumer);
|
return it != state.sameClassConsumers.end() && it->second.contains(classId);
|
||||||
if (cpuIt == state.schedule.computeToCpuMap.end())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (state.cpuToClass.lookup(cpuIt->second) != classId)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (Value input : getComputeInstanceInputs(consumer)) {
|
|
||||||
std::optional<ProducerKey> producer = getProducerKey(input, &consumer);
|
|
||||||
if (!producer)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (ProducerKey expanded : expandWholeBatchProducerKey(*producer))
|
|
||||||
if (expanded == producerKey)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool canCompactBatchClassRun(MaterializerState& state,
|
bool canCompactBatchClassRun(MaterializerState& state,
|
||||||
|
|||||||
Reference in New Issue
Block a user