fix high memory usage in IR

This commit is contained in:
NiccoloN
2026-05-24 10:41:47 +02:00
parent c734f1b37e
commit f595cc6ffd
@@ -124,6 +124,49 @@ struct BatchRunSendPlan {
SmallVector<int32_t, 16> targetCoreIds;
};
struct ProjectedBatchInputKey {
Operation* consumerOp = nullptr;
unsigned inputIndex = 0;
bool operator==(const ProjectedBatchInputKey& other) const {
return consumerOp == other.consumerOp && inputIndex == other.inputIndex;
}
};
struct ProjectedBatchInputKeyInfo {
static ProjectedBatchInputKey getEmptyKey() {
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<unsigned>::max()};
}
static ProjectedBatchInputKey getTombstoneKey() {
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<unsigned>::max()};
}
static unsigned getHashValue(const ProjectedBatchInputKey& key) {
return llvm::hash_combine(key.consumerOp, key.inputIndex);
}
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) {
return lhs == rhs;
}
};
struct ProjectedTransferDescriptor {
ProjectedBatchInputKey inputKey;
Operation* extractOp = nullptr;
RankedTensorType fragmentType;
RankedTensorType payloadType;
unsigned fragmentsPerLane = 1;
SmallVector<int64_t, 16> laneMajorSourceDim0Offsets;
};
struct ProjectedExtractReplacement {
Value payload;
RankedTensorType fragmentType;
unsigned fragmentsPerLane = 1;
};
struct MaterializerState;
class AvailableValueStore {
@@ -158,6 +201,8 @@ struct MaterializerState {
DenseSet<ClassSlotKey> materializedSlots;
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
AvailableValueStore availableValues;
DenseMap<Value, Value> hostReplacements;
DenseSet<Operation*> oldComputeOps;
@@ -1153,6 +1198,197 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
return success();
}
bool isValueOffset(OpFoldResult offset, Value expected) {
auto value = dyn_cast<Value>(offset);
return value && value == expected;
}
bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
auto attr = dyn_cast<Attribute>(value);
if (!attr)
return false;
auto intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getInt() == expected;
}
std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) {
std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
std::optional<BlockArgument> laneArg = batch.getLaneArgument();
if (!inputArg || !laneArg)
return std::nullopt;
if (!inputArg->hasOneUse())
return std::nullopt;
Operation* user = *inputArg->getUsers().begin();
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
if (!extract || extract.getSource() != *inputArg)
return std::nullopt;
auto inputType = dyn_cast<RankedTensorType>(inputArg->getType());
auto fragmentType = dyn_cast<RankedTensorType>(extract.getResult().getType());
if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape())
return std::nullopt;
if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank())
return std::nullopt;
SmallVector<OpFoldResult, 4> offsets = extract.getMixedOffsets();
SmallVector<OpFoldResult, 4> sizes = extract.getMixedSizes();
SmallVector<OpFoldResult, 4> strides = extract.getMixedStrides();
if (offsets.size() != static_cast<size_t>(inputType.getRank())
|| sizes.size() != static_cast<size_t>(inputType.getRank())
|| strides.size() != static_cast<size_t>(inputType.getRank()))
return std::nullopt;
if (!isValueOffset(offsets.front(), *laneArg))
return std::nullopt;
if (!isStaticIndexAttr(sizes.front(), 1) || !isStaticIndexAttr(strides.front(), 1))
return std::nullopt;
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
if (!isStaticIndexAttr(offsets[dim], 0))
return std::nullopt;
if (!isStaticIndexAttr(sizes[dim], inputType.getDimSize(dim)))
return std::nullopt;
if (!isStaticIndexAttr(strides[dim], 1))
return std::nullopt;
}
if (fragmentType.getDimSize(0) != 1)
return std::nullopt;
for (int64_t dim = 1; dim < inputType.getRank(); ++dim)
if (fragmentType.getDimSize(dim) != inputType.getDimSize(dim))
return std::nullopt;
return extract;
}
LogicalResult collectProjectedTransfers(MaterializerState& state) {
struct PendingProjectedTransferDescriptor {
ProjectedBatchInputKey inputKey;
Operation* extractOp = nullptr;
RankedTensorType fragmentType;
SmallVector<SmallVector<int64_t, 4>, 8> offsetsByLane;
bool invalid = false;
};
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) {
auto batch = dyn_cast<SpatComputeBatch>(consumer.op);
if (!batch || consumer.laneCount != 1)
continue;
auto cpuIt = state.schedule.computeToCpuMap.find(consumer);
if (cpuIt == state.schedule.computeToCpuMap.end())
return consumer.op->emitError("projected transfer collection expected scheduled consumer");
ClassId targetClassId = state.cpuToClass.lookup(cpuIt->second);
MaterializedClass& targetClass = state.classes[targetClassId];
if (!targetClass.isBatch)
continue;
auto targetLaneIt = targetClass.cpuToLane.find(cpuIt->second);
if (targetLaneIt == targetClass.cpuToLane.end())
return consumer.op->emitError("projected transfer collection could not recover target lane");
unsigned targetLane = targetLaneIt->second;
for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
std::optional<ProducerKey> producer = getProducerKey(input, &consumer);
if (!producer)
continue;
auto producerCpuIt = state.schedule.computeToCpuMap.find(producer->instance);
if (producerCpuIt == state.schedule.computeToCpuMap.end())
continue;
ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second);
if (sourceClassId == targetClassId)
continue;
std::optional<tensor::ExtractSliceOp> extract =
matchSimpleLaneProjectedInput(batch, static_cast<unsigned>(inputIndex));
if (!extract)
continue;
auto fragmentType = cast<RankedTensorType>((*extract).getResult().getType());
PendingProjectedTransferDescriptor& descriptor = pending[*producer][targetClassId];
if (descriptor.offsetsByLane.empty()) {
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
descriptor.extractOp = extract->getOperation();
descriptor.fragmentType = fragmentType;
descriptor.offsetsByLane.resize(targetClass.cpus.size());
}
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation()
|| descriptor.fragmentType != fragmentType) {
descriptor.invalid = true;
continue;
}
if (targetLane >= descriptor.offsetsByLane.size()) {
descriptor.invalid = true;
continue;
}
descriptor.offsetsByLane[targetLane].push_back(static_cast<int64_t>(consumer.laneStart));
}
}
for (auto& producerEntry : pending) {
ProducerKey producer = producerEntry.first;
for (auto& classEntry : producerEntry.second) {
ClassId targetClassId = classEntry.first;
PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second;
if (pendingDescriptor.invalid)
continue;
if (pendingDescriptor.offsetsByLane.empty())
continue;
unsigned fragmentsPerLane = pendingDescriptor.offsetsByLane.front().size();
if (fragmentsPerLane == 0)
continue;
bool uniform = true;
for (ArrayRef<int64_t> laneOffsets : pendingDescriptor.offsetsByLane) {
if (laneOffsets.size() != fragmentsPerLane) {
uniform = false;
break;
}
}
if (!uniform)
continue;
SmallVector<int64_t, 4> payloadShape(pendingDescriptor.fragmentType.getShape());
payloadShape[0] *= static_cast<int64_t>(fragmentsPerLane);
RankedTensorType payloadType = RankedTensorType::get(payloadShape,
pendingDescriptor.fragmentType.getElementType(),
pendingDescriptor.fragmentType.getEncoding());
ProjectedTransferDescriptor descriptor;
descriptor.inputKey = pendingDescriptor.inputKey;
descriptor.extractOp = pendingDescriptor.extractOp;
descriptor.fragmentType = pendingDescriptor.fragmentType;
descriptor.payloadType = payloadType;
descriptor.fragmentsPerLane = fragmentsPerLane;
descriptor.laneMajorSourceDim0Offsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane);
for (ArrayRef<int64_t> laneOffsets : pendingDescriptor.offsetsByLane)
llvm::append_range(descriptor.laneMajorSourceDim0Offsets, laneOffsets);
state.projectedTransfers[producer][targetClassId] = std::move(descriptor);
}
}
return success();
}
SmallVector<ProducerKey, 8> getOutputKeysForPeers(ArrayRef<ComputeInstance> peers, size_t resultIndex) {
SmallVector<ProducerKey, 8> keys;
keys.reserve(peers.size());
@@ -1237,6 +1473,111 @@ void appendScalarSendLoop(MaterializerState& state,
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
}
Value buildProjectedPackedPayload(MaterializerState& state,
Operation* anchor,
Value fullPayload,
const ProjectedTransferDescriptor& descriptor,
Value laneIndex,
Location loc) {
assert(descriptor.fragmentsPerLane > 1 && "use direct fragment path for single-fragment projection");
Value init = tensor::EmptyOp::create(
state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType())
.getResult();
Value lowerBound = createIndexConstant(state, anchor, 0);
Value upperBound = createIndexConstant(state, anchor, descriptor.fragmentsPerLane);
Value step = createIndexConstant(state, anchor, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init});
Block* body = loop.getBody();
if (!body->empty())
if (auto yield = dyn_cast<scf::YieldOp>(body->back()))
state.rewriter.eraseOp(yield);
OpBuilder::InsertionGuard guard(state.rewriter);
state.rewriter.setInsertionPointToEnd(body);
Value fragmentIndex = loop.getInductionVar();
Value acc = body->getArgument(1);
Value fragmentsPerLane = createIndexConstant(state, anchor, descriptor.fragmentsPerLane);
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
Value sourceOffset =
createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
Value fragment =
createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
scf::YieldOp::create(state.rewriter, loc, next);
return loop.getResult(0);
}
void appendProjectedScalarSendLoop(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
const ProjectedTransferDescriptor& descriptor,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
assert(!sourceClass.isBatch && "projected scalar send expects scalar source class");
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorSourceDim0Offsets.size()
&& "projected send lane count mismatch");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
if (channelIds.size() == 1) {
Value channelId = createIndexConstant(state, sourceClass.op, channelIds.front());
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds.front());
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds.front());
Value laneIndex = createIndexConstant(state, sourceClass.op, 0);
Value sendPayload;
if (descriptor.fragmentsPerLane == 1) {
Value offset = createIndexConstant(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front());
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
}
else {
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc);
}
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
return;
}
Value lowerBound = createIndexConstant(state, sourceClass.op, 0);
Value upperBound = createIndexConstant(state, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = createIndexConstant(state, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
OpBuilder::InsertionGuard guard(state.rewriter);
state.rewriter.setInsertionPointToStart(loop.getBody());
Value index = loop.getInductionVar();
Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc);
Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc);
Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc);
Value sendPayload;
if (descriptor.fragmentsPerLane == 1) {
Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets, index, loc);
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
}
else {
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc);
}
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
}
void appendSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
@@ -1394,6 +1735,7 @@ SmallVector<ClassId, 4> collectDestinationClassesForKeys(MaterializerState& stat
SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState& state,
MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys,
ArrayRef<ClassId> destinationClasses,
Value payload,
Location loc) {
@@ -1401,25 +1743,44 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
size_t messageCount = 0;
for (ClassId destinationClass : destinationClasses) {
if (destinationClass == sourceClass.id)
continue;
MaterializedClass& targetClass = state.classes[destinationClass];
messageCount += targetClass.isBatch ? targetClass.cpus.size() : 1;
}
SmallVector<int64_t, 8> allChannelIds;
SmallVector<int32_t, 8> allSourceCoreIds;
SmallVector<int32_t, 8> allTargetCoreIds;
allChannelIds.reserve(messageCount);
allSourceCoreIds.reserve(messageCount);
allTargetCoreIds.reserve(messageCount);
SmallVector<ScalarSourceReceivePlan, 4> receivePlans;
receivePlans.reserve(destinationClasses.size());
const auto tryEmitProjected = [&](ClassId destinationClass,
const SmallVector<int64_t, 8>& channelIds,
const SmallVector<int32_t, 8>& sourceCoreIds,
const SmallVector<int32_t, 8>& targetCoreIds) -> bool {
if (keys.size() != 1)
return false;
MaterializedClass& targetClass = state.classes[destinationClass];
if (!targetClass.isBatch)
return false;
auto producerIt = state.projectedTransfers.find(keys.front());
if (producerIt == state.projectedTransfers.end())
return false;
auto descriptorIt = producerIt->second.find(destinationClass);
if (descriptorIt == producerIt->second.end())
return false;
const ProjectedTransferDescriptor& descriptor = descriptorIt->second;
if (descriptor.laneMajorSourceDim0Offsets.size()
!= targetClass.cpus.size() * static_cast<size_t>(descriptor.fragmentsPerLane))
return false;
appendProjectedScalarSendLoop(
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendReceive(
state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
state.projectedExtractReplacements[descriptor.extractOp][destinationClass] =
ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane};
return true;
};
for (ClassId destinationClass : destinationClasses) {
if (destinationClass == sourceClass.id)
continue;
@@ -1435,10 +1796,6 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
plan.channelIds.push_back(channelId);
plan.sourceCoreIds.push_back(sourceCpu);
plan.targetCoreIds.push_back(targetCpu);
allChannelIds.push_back(channelId);
allSourceCoreIds.push_back(sourceCpu);
allTargetCoreIds.push_back(targetCpu);
};
if (!targetClass.isBatch)
@@ -1447,12 +1804,13 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
for (CpuId targetCpu : targetClass.cpus)
appendMessage(static_cast<int32_t>(targetCpu));
if (tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds))
continue;
appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc);
receivePlans.push_back(std::move(plan));
}
if (!allChannelIds.empty())
appendSend(state, sourceClass, payload, allChannelIds, allSourceCoreIds, allTargetCoreIds, loc);
return receivePlans;
}
@@ -1465,7 +1823,7 @@ LogicalResult emitScalarSourceCommunication(
SmallVector<ClassId, 4> destinationClasses = collectDestinationClassesForKeys(state, keys);
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
emitScalarSourceSends(state, sourceClass, destinationClasses, payload, loc);
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc);
for (const ScalarSourceReceivePlan& plan : receivePlans) {
MaterializedClass& targetClass = state.classes[plan.targetClass];
@@ -1716,7 +2074,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
MaterializedClass& targetClass,
const ComputeInstance& instance,
Value laneValue,
ArrayRef<size_t> resultIndices);
ArrayRef<size_t> resultIndices,
std::optional<Value> projectionSlotIndex = std::nullopt);
FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState& state,
MaterializedClass& targetClass,
@@ -1766,7 +2125,7 @@ FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState&
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices);
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
if (failed(produced))
return failure();
@@ -1833,7 +2192,7 @@ FailureOr<Value> insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices);
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
if (failed(produced))
return failure();
@@ -2153,6 +2512,19 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
return appendInput(state, targetClass, input);
}
bool hasProjectedInputReplacement(
MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) {
std::optional<tensor::ExtractSliceOp> extract = matchSimpleLaneProjectedInput(batch, inputIndex);
if (!extract)
return false;
auto replacementIt = state.projectedExtractReplacements.find(extract->getOperation());
if (replacementIt == state.projectedExtractReplacements.end())
return false;
return replacementIt->second.find(classId) != replacementIt->second.end();
}
void mapWeights(MaterializerState& state,
MaterializedClass& targetClass,
const ComputeInstance& instance,
@@ -2195,6 +2567,9 @@ LogicalResult mapInputs(MaterializerState& state,
auto batch = cast<SpatComputeBatch>(op);
for (auto [index, input] : llvm::enumerate(batch.getInputs())) {
if (hasProjectedInputReplacement(state, batch, static_cast<unsigned>(index), targetClass.id))
continue;
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
if (failed(mapped))
return batch.emitOpError("failed to resolve materialized compute_batch input");
@@ -2262,6 +2637,35 @@ SmallVector<Type, 4> collectBatchOutputFragmentTypes(SpatComputeBatch batch) {
return types;
}
std::optional<ProjectedExtractReplacement> lookupProjectedExtractReplacement(MaterializerState& state,
MaterializedClass& targetClass,
tensor::ExtractSliceOp extract) {
auto replacementIt = state.projectedExtractReplacements.find(extract.getOperation());
if (replacementIt == state.projectedExtractReplacements.end())
return std::nullopt;
auto classIt = replacementIt->second.find(targetClass.id);
if (classIt == replacementIt->second.end())
return std::nullopt;
return classIt->second;
}
FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state,
MaterializedClass& targetClass,
tensor::ExtractSliceOp extract,
const ProjectedExtractReplacement& replacement,
std::optional<Value> projectionSlotIndex) {
if (replacement.fragmentsPerLane == 1)
return replacement.payload;
if (!projectionSlotIndex)
return targetClass.op->emitError("packed projected extract replacement requires a projection slot index");
return createDim0ExtractSlice(
state, extract.getLoc(), replacement.payload, *projectionSlotIndex, replacement.fragmentType.getDimSize(0));
}
FailureOr<SmallVector<Value, 4>>
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
assert(!peers.empty() && "expected at least one peer instance");
@@ -2301,6 +2705,19 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
for (Operation& op : sourceBlock.without_terminator()) {
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
if (std::optional<ProjectedExtractReplacement> replacement =
lookupProjectedExtractReplacement(state, targetClass, extract)) {
FailureOr<Value> projected =
materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, std::nullopt);
if (failed(projected))
return failure();
mapper.map(extract.getResult(), *projected);
continue;
}
}
Operation* cloned = state.rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(oldResult, newResult);
@@ -2503,7 +2920,7 @@ LogicalResult emitPackedRunFanout(MaterializerState& state,
assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class");
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
emitScalarSourceSends(state, sourceClass, destinationClasses, packed, loc);
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc);
for (const ScalarSourceReceivePlan& plan : receivePlans) {
MaterializedClass& targetClass = state.classes[plan.targetClass];
@@ -2522,7 +2939,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
MaterializedClass& targetClass,
const ComputeInstance& instance,
Value laneValue,
ArrayRef<size_t> resultIndices) {
ArrayRef<size_t> resultIndices,
std::optional<Value> projectionSlotIndex) {
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
if (!batch)
return failure();
@@ -2544,6 +2962,19 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
for (Operation& op : sourceBlock.without_terminator()) {
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
if (std::optional<ProjectedExtractReplacement> replacement =
lookupProjectedExtractReplacement(state, targetClass, extract)) {
FailureOr<Value> projected = materializeProjectedExtractReplacement(
state, targetClass, extract, *replacement, projectionSlotIndex);
if (failed(projected))
return failure();
mapper.map(extract.getResult(), *projected);
continue;
}
}
Operation* cloned = state.rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(oldResult, newResult);
@@ -2637,7 +3068,7 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices);
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex);
if (failed(produced))
return failure();
@@ -3127,7 +3558,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc);
FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices);
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex);
if (failed(produced))
return failure();
@@ -3286,6 +3717,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
createEmptyMaterializedOps(state);
if (failed(collectProducerDestinations(state)))
return failure();
if (failed(collectProjectedTransfers(state)))
return failure();
for (const ComputeInstance& instance : schedule.dominanceOrderCompute)
if (failed(materializeInstanceSlot(state, instance)))