fix high memory usage in IR
This commit is contained in:
+464
-31
@@ -124,6 +124,49 @@ struct BatchRunSendPlan {
|
||||
SmallVector<int32_t, 16> targetCoreIds;
|
||||
};
|
||||
|
||||
struct ProjectedBatchInputKey {
|
||||
Operation* consumerOp = nullptr;
|
||||
unsigned inputIndex = 0;
|
||||
|
||||
bool operator==(const ProjectedBatchInputKey& other) const {
|
||||
return consumerOp == other.consumerOp && inputIndex == other.inputIndex;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProjectedBatchInputKeyInfo {
|
||||
static ProjectedBatchInputKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static ProjectedBatchInputKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ProjectedBatchInputKey& key) {
|
||||
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
||||
}
|
||||
|
||||
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
struct ProjectedTransferDescriptor {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
Operation* extractOp = nullptr;
|
||||
|
||||
RankedTensorType fragmentType;
|
||||
RankedTensorType payloadType;
|
||||
unsigned fragmentsPerLane = 1;
|
||||
SmallVector<int64_t, 16> laneMajorSourceDim0Offsets;
|
||||
};
|
||||
|
||||
struct ProjectedExtractReplacement {
|
||||
Value payload;
|
||||
RankedTensorType fragmentType;
|
||||
unsigned fragmentsPerLane = 1;
|
||||
};
|
||||
|
||||
struct MaterializerState;
|
||||
|
||||
class AvailableValueStore {
|
||||
@@ -158,6 +201,8 @@ struct MaterializerState {
|
||||
DenseSet<ClassSlotKey> materializedSlots;
|
||||
|
||||
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
|
||||
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
|
||||
AvailableValueStore availableValues;
|
||||
DenseMap<Value, Value> hostReplacements;
|
||||
DenseSet<Operation*> oldComputeOps;
|
||||
@@ -1153,6 +1198,197 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
bool isValueOffset(OpFoldResult offset, Value expected) {
|
||||
auto value = dyn_cast<Value>(offset);
|
||||
return value && value == expected;
|
||||
}
|
||||
|
||||
bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
|
||||
auto attr = dyn_cast<Attribute>(value);
|
||||
if (!attr)
|
||||
return false;
|
||||
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
return intAttr && intAttr.getInt() == expected;
|
||||
}
|
||||
|
||||
std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) {
|
||||
std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
|
||||
std::optional<BlockArgument> laneArg = batch.getLaneArgument();
|
||||
if (!inputArg || !laneArg)
|
||||
return std::nullopt;
|
||||
|
||||
if (!inputArg->hasOneUse())
|
||||
return std::nullopt;
|
||||
|
||||
Operation* user = *inputArg->getUsers().begin();
|
||||
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
||||
if (!extract || extract.getSource() != *inputArg)
|
||||
return std::nullopt;
|
||||
|
||||
auto inputType = dyn_cast<RankedTensorType>(inputArg->getType());
|
||||
auto fragmentType = dyn_cast<RankedTensorType>(extract.getResult().getType());
|
||||
if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape())
|
||||
return std::nullopt;
|
||||
|
||||
if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank())
|
||||
return std::nullopt;
|
||||
|
||||
SmallVector<OpFoldResult, 4> offsets = extract.getMixedOffsets();
|
||||
SmallVector<OpFoldResult, 4> sizes = extract.getMixedSizes();
|
||||
SmallVector<OpFoldResult, 4> strides = extract.getMixedStrides();
|
||||
|
||||
if (offsets.size() != static_cast<size_t>(inputType.getRank())
|
||||
|| sizes.size() != static_cast<size_t>(inputType.getRank())
|
||||
|| strides.size() != static_cast<size_t>(inputType.getRank()))
|
||||
return std::nullopt;
|
||||
|
||||
if (!isValueOffset(offsets.front(), *laneArg))
|
||||
return std::nullopt;
|
||||
if (!isStaticIndexAttr(sizes.front(), 1) || !isStaticIndexAttr(strides.front(), 1))
|
||||
return std::nullopt;
|
||||
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
if (!isStaticIndexAttr(offsets[dim], 0))
|
||||
return std::nullopt;
|
||||
if (!isStaticIndexAttr(sizes[dim], inputType.getDimSize(dim)))
|
||||
return std::nullopt;
|
||||
if (!isStaticIndexAttr(strides[dim], 1))
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (fragmentType.getDimSize(0) != 1)
|
||||
return std::nullopt;
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim)
|
||||
if (fragmentType.getDimSize(dim) != inputType.getDimSize(dim))
|
||||
return std::nullopt;
|
||||
|
||||
return extract;
|
||||
}
|
||||
|
||||
LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
struct PendingProjectedTransferDescriptor {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
Operation* extractOp = nullptr;
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<SmallVector<int64_t, 4>, 8> offsetsByLane;
|
||||
bool invalid = false;
|
||||
};
|
||||
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
|
||||
|
||||
for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(consumer.op);
|
||||
if (!batch || consumer.laneCount != 1)
|
||||
continue;
|
||||
|
||||
auto cpuIt = state.schedule.computeToCpuMap.find(consumer);
|
||||
if (cpuIt == state.schedule.computeToCpuMap.end())
|
||||
return consumer.op->emitError("projected transfer collection expected scheduled consumer");
|
||||
|
||||
ClassId targetClassId = state.cpuToClass.lookup(cpuIt->second);
|
||||
MaterializedClass& targetClass = state.classes[targetClassId];
|
||||
if (!targetClass.isBatch)
|
||||
continue;
|
||||
|
||||
auto targetLaneIt = targetClass.cpuToLane.find(cpuIt->second);
|
||||
if (targetLaneIt == targetClass.cpuToLane.end())
|
||||
return consumer.op->emitError("projected transfer collection could not recover target lane");
|
||||
|
||||
unsigned targetLane = targetLaneIt->second;
|
||||
|
||||
for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
|
||||
std::optional<ProducerKey> producer = getProducerKey(input, &consumer);
|
||||
if (!producer)
|
||||
continue;
|
||||
|
||||
auto producerCpuIt = state.schedule.computeToCpuMap.find(producer->instance);
|
||||
if (producerCpuIt == state.schedule.computeToCpuMap.end())
|
||||
continue;
|
||||
|
||||
ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second);
|
||||
if (sourceClassId == targetClassId)
|
||||
continue;
|
||||
|
||||
std::optional<tensor::ExtractSliceOp> extract =
|
||||
matchSimpleLaneProjectedInput(batch, static_cast<unsigned>(inputIndex));
|
||||
if (!extract)
|
||||
continue;
|
||||
|
||||
auto fragmentType = cast<RankedTensorType>((*extract).getResult().getType());
|
||||
|
||||
PendingProjectedTransferDescriptor& descriptor = pending[*producer][targetClassId];
|
||||
if (descriptor.offsetsByLane.empty()) {
|
||||
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
descriptor.extractOp = extract->getOperation();
|
||||
descriptor.fragmentType = fragmentType;
|
||||
descriptor.offsetsByLane.resize(targetClass.cpus.size());
|
||||
}
|
||||
|
||||
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation()
|
||||
|| descriptor.fragmentType != fragmentType) {
|
||||
descriptor.invalid = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (targetLane >= descriptor.offsetsByLane.size()) {
|
||||
descriptor.invalid = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
descriptor.offsetsByLane[targetLane].push_back(static_cast<int64_t>(consumer.laneStart));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& producerEntry : pending) {
|
||||
ProducerKey producer = producerEntry.first;
|
||||
for (auto& classEntry : producerEntry.second) {
|
||||
ClassId targetClassId = classEntry.first;
|
||||
PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second;
|
||||
|
||||
if (pendingDescriptor.invalid)
|
||||
continue;
|
||||
if (pendingDescriptor.offsetsByLane.empty())
|
||||
continue;
|
||||
|
||||
unsigned fragmentsPerLane = pendingDescriptor.offsetsByLane.front().size();
|
||||
if (fragmentsPerLane == 0)
|
||||
continue;
|
||||
|
||||
bool uniform = true;
|
||||
for (ArrayRef<int64_t> laneOffsets : pendingDescriptor.offsetsByLane) {
|
||||
if (laneOffsets.size() != fragmentsPerLane) {
|
||||
uniform = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!uniform)
|
||||
continue;
|
||||
|
||||
SmallVector<int64_t, 4> payloadShape(pendingDescriptor.fragmentType.getShape());
|
||||
payloadShape[0] *= static_cast<int64_t>(fragmentsPerLane);
|
||||
RankedTensorType payloadType = RankedTensorType::get(payloadShape,
|
||||
pendingDescriptor.fragmentType.getElementType(),
|
||||
pendingDescriptor.fragmentType.getEncoding());
|
||||
|
||||
ProjectedTransferDescriptor descriptor;
|
||||
descriptor.inputKey = pendingDescriptor.inputKey;
|
||||
descriptor.extractOp = pendingDescriptor.extractOp;
|
||||
descriptor.fragmentType = pendingDescriptor.fragmentType;
|
||||
descriptor.payloadType = payloadType;
|
||||
descriptor.fragmentsPerLane = fragmentsPerLane;
|
||||
descriptor.laneMajorSourceDim0Offsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane);
|
||||
for (ArrayRef<int64_t> laneOffsets : pendingDescriptor.offsetsByLane)
|
||||
llvm::append_range(descriptor.laneMajorSourceDim0Offsets, laneOffsets);
|
||||
|
||||
state.projectedTransfers[producer][targetClassId] = std::move(descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<ProducerKey, 8> getOutputKeysForPeers(ArrayRef<ComputeInstance> peers, size_t resultIndex) {
|
||||
SmallVector<ProducerKey, 8> keys;
|
||||
keys.reserve(peers.size());
|
||||
@@ -1237,6 +1473,111 @@ void appendScalarSendLoop(MaterializerState& state,
|
||||
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
|
||||
}
|
||||
|
||||
Value buildProjectedPackedPayload(MaterializerState& state,
|
||||
Operation* anchor,
|
||||
Value fullPayload,
|
||||
const ProjectedTransferDescriptor& descriptor,
|
||||
Value laneIndex,
|
||||
Location loc) {
|
||||
assert(descriptor.fragmentsPerLane > 1 && "use direct fragment path for single-fragment projection");
|
||||
|
||||
Value init = tensor::EmptyOp::create(
|
||||
state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType())
|
||||
.getResult();
|
||||
|
||||
Value lowerBound = createIndexConstant(state, anchor, 0);
|
||||
Value upperBound = createIndexConstant(state, anchor, descriptor.fragmentsPerLane);
|
||||
Value step = createIndexConstant(state, anchor, 1);
|
||||
|
||||
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init});
|
||||
|
||||
Block* body = loop.getBody();
|
||||
if (!body->empty())
|
||||
if (auto yield = dyn_cast<scf::YieldOp>(body->back()))
|
||||
state.rewriter.eraseOp(yield);
|
||||
|
||||
OpBuilder::InsertionGuard guard(state.rewriter);
|
||||
state.rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value fragmentIndex = loop.getInductionVar();
|
||||
Value acc = body->getArgument(1);
|
||||
|
||||
Value fragmentsPerLane = createIndexConstant(state, anchor, descriptor.fragmentsPerLane);
|
||||
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
|
||||
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
|
||||
|
||||
Value sourceOffset =
|
||||
createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
|
||||
|
||||
Value fragment =
|
||||
createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
|
||||
|
||||
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
|
||||
scf::YieldOp::create(state.rewriter, loc, next);
|
||||
|
||||
return loop.getResult(0);
|
||||
}
|
||||
|
||||
void appendProjectedScalarSendLoop(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
Value payload,
|
||||
const ProjectedTransferDescriptor& descriptor,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
Location loc) {
|
||||
assert(!sourceClass.isBatch && "projected scalar send expects scalar source class");
|
||||
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
|
||||
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
||||
assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorSourceDim0Offsets.size()
|
||||
&& "projected send lane count mismatch");
|
||||
|
||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||
|
||||
if (channelIds.size() == 1) {
|
||||
Value channelId = createIndexConstant(state, sourceClass.op, channelIds.front());
|
||||
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds.front());
|
||||
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds.front());
|
||||
Value laneIndex = createIndexConstant(state, sourceClass.op, 0);
|
||||
Value sendPayload;
|
||||
if (descriptor.fragmentsPerLane == 1) {
|
||||
Value offset = createIndexConstant(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front());
|
||||
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
|
||||
}
|
||||
else {
|
||||
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc);
|
||||
}
|
||||
|
||||
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
|
||||
return;
|
||||
}
|
||||
|
||||
Value lowerBound = createIndexConstant(state, sourceClass.op, 0);
|
||||
Value upperBound = createIndexConstant(state, sourceClass.op, static_cast<int64_t>(channelIds.size()));
|
||||
Value step = createIndexConstant(state, sourceClass.op, 1);
|
||||
|
||||
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
|
||||
|
||||
OpBuilder::InsertionGuard guard(state.rewriter);
|
||||
state.rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value index = loop.getInductionVar();
|
||||
Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc);
|
||||
Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc);
|
||||
Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc);
|
||||
|
||||
Value sendPayload;
|
||||
if (descriptor.fragmentsPerLane == 1) {
|
||||
Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets, index, loc);
|
||||
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
|
||||
}
|
||||
else {
|
||||
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc);
|
||||
}
|
||||
|
||||
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
|
||||
}
|
||||
|
||||
void appendSend(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
Value payload,
|
||||
@@ -1394,6 +1735,7 @@ SmallVector<ClassId, 4> collectDestinationClassesForKeys(MaterializerState& stat
|
||||
|
||||
SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
ArrayRef<ClassId> destinationClasses,
|
||||
Value payload,
|
||||
Location loc) {
|
||||
@@ -1401,25 +1743,44 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
||||
|
||||
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
||||
|
||||
size_t messageCount = 0;
|
||||
for (ClassId destinationClass : destinationClasses) {
|
||||
if (destinationClass == sourceClass.id)
|
||||
continue;
|
||||
|
||||
MaterializedClass& targetClass = state.classes[destinationClass];
|
||||
messageCount += targetClass.isBatch ? targetClass.cpus.size() : 1;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 8> allChannelIds;
|
||||
SmallVector<int32_t, 8> allSourceCoreIds;
|
||||
SmallVector<int32_t, 8> allTargetCoreIds;
|
||||
allChannelIds.reserve(messageCount);
|
||||
allSourceCoreIds.reserve(messageCount);
|
||||
allTargetCoreIds.reserve(messageCount);
|
||||
|
||||
SmallVector<ScalarSourceReceivePlan, 4> receivePlans;
|
||||
receivePlans.reserve(destinationClasses.size());
|
||||
|
||||
const auto tryEmitProjected = [&](ClassId destinationClass,
|
||||
const SmallVector<int64_t, 8>& channelIds,
|
||||
const SmallVector<int32_t, 8>& sourceCoreIds,
|
||||
const SmallVector<int32_t, 8>& targetCoreIds) -> bool {
|
||||
if (keys.size() != 1)
|
||||
return false;
|
||||
|
||||
MaterializedClass& targetClass = state.classes[destinationClass];
|
||||
if (!targetClass.isBatch)
|
||||
return false;
|
||||
|
||||
auto producerIt = state.projectedTransfers.find(keys.front());
|
||||
if (producerIt == state.projectedTransfers.end())
|
||||
return false;
|
||||
|
||||
auto descriptorIt = producerIt->second.find(destinationClass);
|
||||
if (descriptorIt == producerIt->second.end())
|
||||
return false;
|
||||
|
||||
const ProjectedTransferDescriptor& descriptor = descriptorIt->second;
|
||||
if (descriptor.laneMajorSourceDim0Offsets.size()
|
||||
!= targetClass.cpus.size() * static_cast<size_t>(descriptor.fragmentsPerLane))
|
||||
return false;
|
||||
|
||||
appendProjectedScalarSendLoop(
|
||||
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
|
||||
Value received = appendReceive(
|
||||
state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
|
||||
state.projectedExtractReplacements[descriptor.extractOp][destinationClass] =
|
||||
ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane};
|
||||
return true;
|
||||
};
|
||||
|
||||
for (ClassId destinationClass : destinationClasses) {
|
||||
if (destinationClass == sourceClass.id)
|
||||
continue;
|
||||
@@ -1435,10 +1796,6 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
||||
plan.channelIds.push_back(channelId);
|
||||
plan.sourceCoreIds.push_back(sourceCpu);
|
||||
plan.targetCoreIds.push_back(targetCpu);
|
||||
|
||||
allChannelIds.push_back(channelId);
|
||||
allSourceCoreIds.push_back(sourceCpu);
|
||||
allTargetCoreIds.push_back(targetCpu);
|
||||
};
|
||||
|
||||
if (!targetClass.isBatch)
|
||||
@@ -1447,12 +1804,13 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
||||
for (CpuId targetCpu : targetClass.cpus)
|
||||
appendMessage(static_cast<int32_t>(targetCpu));
|
||||
|
||||
if (tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds))
|
||||
continue;
|
||||
|
||||
appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc);
|
||||
receivePlans.push_back(std::move(plan));
|
||||
}
|
||||
|
||||
if (!allChannelIds.empty())
|
||||
appendSend(state, sourceClass, payload, allChannelIds, allSourceCoreIds, allTargetCoreIds, loc);
|
||||
|
||||
return receivePlans;
|
||||
}
|
||||
|
||||
@@ -1465,7 +1823,7 @@ LogicalResult emitScalarSourceCommunication(
|
||||
|
||||
SmallVector<ClassId, 4> destinationClasses = collectDestinationClassesForKeys(state, keys);
|
||||
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
|
||||
emitScalarSourceSends(state, sourceClass, destinationClasses, payload, loc);
|
||||
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc);
|
||||
|
||||
for (const ScalarSourceReceivePlan& plan : receivePlans) {
|
||||
MaterializedClass& targetClass = state.classes[plan.targetClass];
|
||||
@@ -1716,7 +2074,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
Value laneValue,
|
||||
ArrayRef<size_t> resultIndices);
|
||||
ArrayRef<size_t> resultIndices,
|
||||
std::optional<Value> projectionSlotIndex = std::nullopt);
|
||||
|
||||
FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
@@ -1766,7 +2125,7 @@ FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState&
|
||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
|
||||
|
||||
FailureOr<SmallVector<Value, 4>> produced =
|
||||
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices);
|
||||
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
|
||||
if (failed(produced))
|
||||
return failure();
|
||||
|
||||
@@ -1833,7 +2192,7 @@ FailureOr<Value> insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt
|
||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
|
||||
|
||||
FailureOr<SmallVector<Value, 4>> produced =
|
||||
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices);
|
||||
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
|
||||
if (failed(produced))
|
||||
return failure();
|
||||
|
||||
@@ -2153,6 +2512,19 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
|
||||
return appendInput(state, targetClass, input);
|
||||
}
|
||||
|
||||
bool hasProjectedInputReplacement(
|
||||
MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) {
|
||||
std::optional<tensor::ExtractSliceOp> extract = matchSimpleLaneProjectedInput(batch, inputIndex);
|
||||
if (!extract)
|
||||
return false;
|
||||
|
||||
auto replacementIt = state.projectedExtractReplacements.find(extract->getOperation());
|
||||
if (replacementIt == state.projectedExtractReplacements.end())
|
||||
return false;
|
||||
|
||||
return replacementIt->second.find(classId) != replacementIt->second.end();
|
||||
}
|
||||
|
||||
void mapWeights(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
@@ -2195,6 +2567,9 @@ LogicalResult mapInputs(MaterializerState& state,
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(op);
|
||||
for (auto [index, input] : llvm::enumerate(batch.getInputs())) {
|
||||
if (hasProjectedInputReplacement(state, batch, static_cast<unsigned>(index), targetClass.id))
|
||||
continue;
|
||||
|
||||
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
|
||||
if (failed(mapped))
|
||||
return batch.emitOpError("failed to resolve materialized compute_batch input");
|
||||
@@ -2262,6 +2637,35 @@ SmallVector<Type, 4> collectBatchOutputFragmentTypes(SpatComputeBatch batch) {
|
||||
return types;
|
||||
}
|
||||
|
||||
std::optional<ProjectedExtractReplacement> lookupProjectedExtractReplacement(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
tensor::ExtractSliceOp extract) {
|
||||
auto replacementIt = state.projectedExtractReplacements.find(extract.getOperation());
|
||||
if (replacementIt == state.projectedExtractReplacements.end())
|
||||
return std::nullopt;
|
||||
|
||||
auto classIt = replacementIt->second.find(targetClass.id);
|
||||
if (classIt == replacementIt->second.end())
|
||||
return std::nullopt;
|
||||
|
||||
return classIt->second;
|
||||
}
|
||||
|
||||
FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
tensor::ExtractSliceOp extract,
|
||||
const ProjectedExtractReplacement& replacement,
|
||||
std::optional<Value> projectionSlotIndex) {
|
||||
if (replacement.fragmentsPerLane == 1)
|
||||
return replacement.payload;
|
||||
|
||||
if (!projectionSlotIndex)
|
||||
return targetClass.op->emitError("packed projected extract replacement requires a projection slot index");
|
||||
|
||||
return createDim0ExtractSlice(
|
||||
state, extract.getLoc(), replacement.payload, *projectionSlotIndex, replacement.fragmentType.getDimSize(0));
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value, 4>>
|
||||
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
|
||||
assert(!peers.empty() && "expected at least one peer instance");
|
||||
@@ -2301,6 +2705,19 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra
|
||||
|
||||
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
|
||||
for (Operation& op : sourceBlock.without_terminator()) {
|
||||
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
||||
if (std::optional<ProjectedExtractReplacement> replacement =
|
||||
lookupProjectedExtractReplacement(state, targetClass, extract)) {
|
||||
FailureOr<Value> projected =
|
||||
materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, std::nullopt);
|
||||
if (failed(projected))
|
||||
return failure();
|
||||
|
||||
mapper.map(extract.getResult(), *projected);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Operation* cloned = state.rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
@@ -2503,7 +2920,7 @@ LogicalResult emitPackedRunFanout(MaterializerState& state,
|
||||
assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class");
|
||||
|
||||
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
|
||||
emitScalarSourceSends(state, sourceClass, destinationClasses, packed, loc);
|
||||
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc);
|
||||
|
||||
for (const ScalarSourceReceivePlan& plan : receivePlans) {
|
||||
MaterializedClass& targetClass = state.classes[plan.targetClass];
|
||||
@@ -2522,7 +2939,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
Value laneValue,
|
||||
ArrayRef<size_t> resultIndices) {
|
||||
ArrayRef<size_t> resultIndices,
|
||||
std::optional<Value> projectionSlotIndex) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||
if (!batch)
|
||||
return failure();
|
||||
@@ -2544,6 +2962,19 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
||||
|
||||
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
|
||||
for (Operation& op : sourceBlock.without_terminator()) {
|
||||
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
||||
if (std::optional<ProjectedExtractReplacement> replacement =
|
||||
lookupProjectedExtractReplacement(state, targetClass, extract)) {
|
||||
FailureOr<Value> projected = materializeProjectedExtractReplacement(
|
||||
state, targetClass, extract, *replacement, projectionSlotIndex);
|
||||
if (failed(projected))
|
||||
return failure();
|
||||
|
||||
mapper.map(extract.getResult(), *projected);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Operation* cloned = state.rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
@@ -2637,7 +3068,7 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc);
|
||||
|
||||
FailureOr<SmallVector<Value, 4>> produced =
|
||||
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices);
|
||||
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex);
|
||||
if (failed(produced))
|
||||
return failure();
|
||||
|
||||
@@ -3127,7 +3558,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
|
||||
Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc);
|
||||
|
||||
FailureOr<SmallVector<Value, 4>> produced =
|
||||
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices);
|
||||
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex);
|
||||
if (failed(produced))
|
||||
return failure();
|
||||
|
||||
@@ -3286,6 +3717,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
|
||||
createEmptyMaterializedOps(state);
|
||||
if (failed(collectProducerDestinations(state)))
|
||||
return failure();
|
||||
if (failed(collectProjectedTransfers(state)))
|
||||
return failure();
|
||||
|
||||
for (const ComputeInstance& instance : schedule.dominanceOrderCompute)
|
||||
if (failed(materializeInstanceSlot(state, instance)))
|
||||
|
||||
Reference in New Issue
Block a user