speed fix with a simple cache
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-24 10:52:28 +02:00
parent f595cc6ffd
commit 48ca6bd28d
@@ -146,9 +146,7 @@ struct ProjectedBatchInputKeyInfo {
return llvm::hash_combine(key.consumerOp, key.inputIndex);
}
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) {
return lhs == rhs;
}
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; }
};
struct ProjectedTransferDescriptor {
@@ -201,6 +199,7 @@ struct MaterializerState {
DenseSet<ClassSlotKey> materializedSlots;
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> sameClassConsumers;
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
AvailableValueStore availableValues;
@@ -1187,8 +1186,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
return consumer.op->emitError("schedule materialization found an input produced by an unscheduled compute");
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
if (sourceClass == targetClass)
if (sourceClass == targetClass) {
state.sameClassConsumers[producerKey].insert(targetClass);
continue;
}
appendDestinationClass(state, producerKey, targetClass);
}
@@ -1342,7 +1343,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
}
for (auto& producerEntry : pending) {
ProducerKey producer = producerEntry.first;
ProducerKey producer = producerEntry.first;
for (auto& classEntry : producerEntry.second) {
ClassId targetClassId = classEntry.first;
PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second;
@@ -1368,9 +1369,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
SmallVector<int64_t, 4> payloadShape(pendingDescriptor.fragmentType.getShape());
payloadShape[0] *= static_cast<int64_t>(fragmentsPerLane);
RankedTensorType payloadType = RankedTensorType::get(payloadShape,
pendingDescriptor.fragmentType.getElementType(),
pendingDescriptor.fragmentType.getEncoding());
RankedTensorType payloadType = RankedTensorType::get(
payloadShape, pendingDescriptor.fragmentType.getElementType(), pendingDescriptor.fragmentType.getEncoding());
ProjectedTransferDescriptor descriptor;
descriptor.inputKey = pendingDescriptor.inputKey;
@@ -1506,11 +1506,9 @@ Value buildProjectedPackedPayload(MaterializerState& state,
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
Value sourceOffset =
createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
Value fragment =
createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
Value fragment = createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
scf::YieldOp::create(state.rewriter, loc, next);
@@ -1773,8 +1771,8 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
appendProjectedScalarSendLoop(
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendReceive(
state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
state.projectedExtractReplacements[descriptor.extractOp][destinationClass] =
ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane};
@@ -2512,8 +2510,10 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
return appendInput(state, targetClass, input);
}
bool hasProjectedInputReplacement(
MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) {
bool hasProjectedInputReplacement(MaterializerState& state,
SpatComputeBatch batch,
unsigned inputIndex,
ClassId classId) {
std::optional<tensor::ExtractSliceOp> extract = matchSimpleLaneProjectedInput(batch, inputIndex);
if (!extract)
return false;
@@ -2965,8 +2965,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
if (std::optional<ProjectedExtractReplacement> replacement =
lookupProjectedExtractReplacement(state, targetClass, extract)) {
FailureOr<Value> projected = materializeProjectedExtractReplacement(
state, targetClass, extract, *replacement, projectionSlotIndex);
FailureOr<Value> projected =
materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, projectionSlotIndex);
if (failed(projected))
return failure();
@@ -3288,26 +3288,8 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state,
}
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) {
auto cpuIt = state.schedule.computeToCpuMap.find(consumer);
if (cpuIt == state.schedule.computeToCpuMap.end())
continue;
if (state.cpuToClass.lookup(cpuIt->second) != classId)
continue;
for (Value input : getComputeInstanceInputs(consumer)) {
std::optional<ProducerKey> producer = getProducerKey(input, &consumer);
if (!producer)
continue;
for (ProducerKey expanded : expandWholeBatchProducerKey(*producer))
if (expanded == producerKey)
return true;
}
}
return false;
auto it = state.sameClassConsumers.find(producerKey);
return it != state.sameClassConsumers.end() && it->second.contains(classId);
}
bool canCompactBatchClassRun(MaterializerState& state,