fix high memory usage in IR
This commit is contained in:
+464
-31
@@ -124,6 +124,49 @@ struct BatchRunSendPlan {
|
|||||||
SmallVector<int32_t, 16> targetCoreIds;
|
SmallVector<int32_t, 16> targetCoreIds;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ProjectedBatchInputKey {
|
||||||
|
Operation* consumerOp = nullptr;
|
||||||
|
unsigned inputIndex = 0;
|
||||||
|
|
||||||
|
bool operator==(const ProjectedBatchInputKey& other) const {
|
||||||
|
return consumerOp == other.consumerOp && inputIndex == other.inputIndex;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ProjectedBatchInputKeyInfo {
|
||||||
|
static ProjectedBatchInputKey getEmptyKey() {
|
||||||
|
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<unsigned>::max()};
|
||||||
|
}
|
||||||
|
|
||||||
|
static ProjectedBatchInputKey getTombstoneKey() {
|
||||||
|
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<unsigned>::max()};
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned getHashValue(const ProjectedBatchInputKey& key) {
|
||||||
|
return llvm::hash_combine(key.consumerOp, key.inputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ProjectedTransferDescriptor {
|
||||||
|
ProjectedBatchInputKey inputKey;
|
||||||
|
Operation* extractOp = nullptr;
|
||||||
|
|
||||||
|
RankedTensorType fragmentType;
|
||||||
|
RankedTensorType payloadType;
|
||||||
|
unsigned fragmentsPerLane = 1;
|
||||||
|
SmallVector<int64_t, 16> laneMajorSourceDim0Offsets;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ProjectedExtractReplacement {
|
||||||
|
Value payload;
|
||||||
|
RankedTensorType fragmentType;
|
||||||
|
unsigned fragmentsPerLane = 1;
|
||||||
|
};
|
||||||
|
|
||||||
struct MaterializerState;
|
struct MaterializerState;
|
||||||
|
|
||||||
class AvailableValueStore {
|
class AvailableValueStore {
|
||||||
@@ -158,6 +201,8 @@ struct MaterializerState {
|
|||||||
DenseSet<ClassSlotKey> materializedSlots;
|
DenseSet<ClassSlotKey> materializedSlots;
|
||||||
|
|
||||||
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||||
|
DenseMap<ProducerKey, DenseMap<ClassId, ProjectedTransferDescriptor>, ProducerKeyInfo> projectedTransfers;
|
||||||
|
DenseMap<Operation*, DenseMap<ClassId, ProjectedExtractReplacement>> projectedExtractReplacements;
|
||||||
AvailableValueStore availableValues;
|
AvailableValueStore availableValues;
|
||||||
DenseMap<Value, Value> hostReplacements;
|
DenseMap<Value, Value> hostReplacements;
|
||||||
DenseSet<Operation*> oldComputeOps;
|
DenseSet<Operation*> oldComputeOps;
|
||||||
@@ -1153,6 +1198,197 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true when `offset` holds an SSA value and that value is `expected`.
/// Attribute-valued (static) offsets never match.
bool isValueOffset(OpFoldResult offset, Value expected) {
  if (auto dynamicOffset = dyn_cast<Value>(offset))
    return dynamicOffset == expected;
  return false;
}
|
||||||
|
|
||||||
|
/// Returns true when `value` holds an integer attribute equal to `expected`.
/// SSA values and non-integer attributes never match.
bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
  if (auto attribute = dyn_cast<Attribute>(value))
    if (auto integer = dyn_cast<IntegerAttr>(attribute))
      return integer.getInt() == expected;
  return false;
}
|
||||||
|
|
||||||
|
/// Matches the pattern "batch input used exactly once, by an extract_slice that
/// picks one dim-0 row selected by the lane argument and keeps every other
/// dimension whole".
///
/// Returns the matching extract_slice, or std::nullopt when any of these fail:
///  - the batch exposes both the input argument and the lane argument;
///  - the input argument has a single use, an extract_slice reading from it;
///  - source and result are statically shaped ranked tensors of equal, non-zero rank;
///  - dim 0 uses the lane argument as offset with static size 1 and stride 1;
///  - every other dim uses offset 0, the full source extent, and stride 1;
///  - the result shape is `1 x <remaining source dims>`.
std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) {
  std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
  std::optional<BlockArgument> laneArg = batch.getLaneArgument();
  if (!(inputArg && laneArg))
    return std::nullopt;

  // The argument must feed exactly one operation, and that operation must be
  // an extract_slice whose source is the argument itself.
  if (!inputArg->hasOneUse())
    return std::nullopt;
  auto slice = dyn_cast<tensor::ExtractSliceOp>(*inputArg->getUsers().begin());
  if (!slice || slice.getSource() != *inputArg)
    return std::nullopt;

  auto sourceType = dyn_cast<RankedTensorType>(inputArg->getType());
  auto sliceType = dyn_cast<RankedTensorType>(slice.getResult().getType());
  if (!sourceType || !sliceType)
    return std::nullopt;
  if (!sourceType.hasStaticShape() || !sliceType.hasStaticShape())
    return std::nullopt;

  const int64_t rank = sourceType.getRank();
  if (rank == 0 || rank != sliceType.getRank())
    return std::nullopt;

  SmallVector<OpFoldResult, 4> offsets = slice.getMixedOffsets();
  SmallVector<OpFoldResult, 4> sizes = slice.getMixedSizes();
  SmallVector<OpFoldResult, 4> strides = slice.getMixedStrides();

  const auto expectedEntries = static_cast<size_t>(rank);
  const bool entryCountsMatch =
      offsets.size() == expectedEntries && sizes.size() == expectedEntries && strides.size() == expectedEntries;
  if (!entryCountsMatch)
    return std::nullopt;

  // Leading dimension: one row, addressed by the lane argument.
  const bool laneSelectsSingleRow = isValueOffset(offsets.front(), *laneArg) && isStaticIndexAttr(sizes.front(), 1)
                                 && isStaticIndexAttr(strides.front(), 1);
  if (!laneSelectsSingleRow)
    return std::nullopt;
  if (sliceType.getDimSize(0) != 1)
    return std::nullopt;

  // Trailing dimensions: taken whole, both in the slice parameters and in the
  // result type.
  for (int64_t dim = 1; dim < rank; ++dim) {
    const int64_t extent = sourceType.getDimSize(dim);
    const bool keepsWholeDim = isStaticIndexAttr(offsets[dim], 0) && isStaticIndexAttr(sizes[dim], extent)
                            && isStaticIndexAttr(strides[dim], 1) && sliceType.getDimSize(dim) == extent;
    if (!keepsWholeDim)
      return std::nullopt;
  }

  return slice;
}
|
||||||
|
|
||||||
|
/// Fills `state.projectedTransfers` with one descriptor per (producer, target
/// batch class) pair whose consumers all read the producer through the same
/// lane-projected extract_slice.
///
/// Phase 1 walks the scheduled single-lane batch instances and records, per
/// target lane, the `laneStart` of every consumer that reads a producer living
/// in a different class. Phase 2 keeps only the records that stayed consistent
/// and carry the same non-zero number of fragments on every lane.
///
/// Fails (with a diagnostic on the consumer op) when a considered consumer has
/// no CPU assignment or its lane inside the target class cannot be found.
LogicalResult collectProjectedTransfers(MaterializerState& state) {
  // Scratch record accumulated during phase 1.
  struct PendingProjectedTransferDescriptor {
    ProjectedBatchInputKey inputKey;
    Operation* extractOp = nullptr;
    RankedTensorType fragmentType;
    // One offset list per lane of the target class.
    SmallVector<SmallVector<int64_t, 4>, 8> offsetsByLane;
    // Set once conflicting consumers or an out-of-range lane are seen.
    bool invalid = false;
  };

  DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;

  // ---- Phase 1: gather per-lane offsets --------------------------------------
  for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) {
    auto batch = dyn_cast<SpatComputeBatch>(consumer.op);
    const bool isSingleLaneBatch = batch && consumer.laneCount == 1;
    if (!isSingleLaneBatch)
      continue;

    auto consumerCpuIt = state.schedule.computeToCpuMap.find(consumer);
    if (consumerCpuIt == state.schedule.computeToCpuMap.end())
      return consumer.op->emitError("projected transfer collection expected scheduled consumer");

    ClassId targetClassId = state.cpuToClass.lookup(consumerCpuIt->second);
    MaterializedClass& targetClass = state.classes[targetClassId];
    if (!targetClass.isBatch)
      continue;

    auto laneIt = targetClass.cpuToLane.find(consumerCpuIt->second);
    if (laneIt == targetClass.cpuToLane.end())
      return consumer.op->emitError("projected transfer collection could not recover target lane");
    unsigned targetLane = laneIt->second;

    for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
      const auto operandIndex = static_cast<unsigned>(inputIndex);

      std::optional<ProducerKey> producer = getProducerKey(input, &consumer);
      if (!producer)
        continue;

      auto producerCpuIt = state.schedule.computeToCpuMap.find(producer->instance);
      if (producerCpuIt == state.schedule.computeToCpuMap.end())
        continue;

      // Only values crossing a class boundary need a transfer.
      if (state.cpuToClass.lookup(producerCpuIt->second) == targetClassId)
        continue;

      std::optional<tensor::ExtractSliceOp> extract = matchSimpleLaneProjectedInput(batch, operandIndex);
      if (!extract)
        continue;

      Operation* extractOperation = extract->getOperation();
      auto fragmentType = cast<RankedTensorType>(extract->getResult().getType());
      ProjectedBatchInputKey currentKey {batch.getOperation(), operandIndex};

      // The entry is created on first access; an empty lane table marks a
      // record that has not been initialized yet.
      PendingProjectedTransferDescriptor& entry = pending[*producer][targetClassId];
      if (entry.offsetsByLane.empty()) {
        entry.inputKey = currentKey;
        entry.extractOp = extractOperation;
        entry.fragmentType = fragmentType;
        entry.offsetsByLane.resize(targetClass.cpus.size());
      }

      const bool sameProjection =
          entry.inputKey == currentKey && entry.extractOp == extractOperation && entry.fragmentType == fragmentType;
      if (!sameProjection || targetLane >= entry.offsetsByLane.size()) {
        entry.invalid = true;
        continue;
      }

      entry.offsetsByLane[targetLane].push_back(static_cast<int64_t>(consumer.laneStart));
    }
  }

  // ---- Phase 2: publish the consistent records -------------------------------
  for (auto& producerEntry : pending) {
    const ProducerKey& producer = producerEntry.first;

    for (auto& classEntry : producerEntry.second) {
      const ClassId& targetClassId = classEntry.first;
      PendingProjectedTransferDescriptor& collected = classEntry.second;

      if (collected.invalid || collected.offsetsByLane.empty())
        continue;

      const unsigned fragmentsPerLane = collected.offsetsByLane.front().size();
      if (fragmentsPerLane == 0)
        continue;

      // Every lane must receive the same number of fragments.
      const bool uniform =
          llvm::all_of(collected.offsetsByLane, [fragmentsPerLane](const SmallVector<int64_t, 4>& laneOffsets) {
            return laneOffsets.size() == fragmentsPerLane;
          });
      if (!uniform)
        continue;

      // The per-lane payload stacks `fragmentsPerLane` fragments along dim 0.
      SmallVector<int64_t, 4> payloadShape(collected.fragmentType.getShape());
      payloadShape[0] *= static_cast<int64_t>(fragmentsPerLane);

      ProjectedTransferDescriptor descriptor;
      descriptor.inputKey = collected.inputKey;
      descriptor.extractOp = collected.extractOp;
      descriptor.fragmentType = collected.fragmentType;
      descriptor.payloadType = RankedTensorType::get(
          payloadShape, collected.fragmentType.getElementType(), collected.fragmentType.getEncoding());
      descriptor.fragmentsPerLane = fragmentsPerLane;

      // Flatten the per-lane lists in lane order.
      descriptor.laneMajorSourceDim0Offsets.reserve(collected.offsetsByLane.size() * fragmentsPerLane);
      for (const SmallVector<int64_t, 4>& laneOffsets : collected.offsetsByLane)
        llvm::append_range(descriptor.laneMajorSourceDim0Offsets, laneOffsets);

      state.projectedTransfers[producer][targetClassId] = std::move(descriptor);
    }
  }

  return success();
}
|
||||||
|
|
||||||
SmallVector<ProducerKey, 8> getOutputKeysForPeers(ArrayRef<ComputeInstance> peers, size_t resultIndex) {
|
SmallVector<ProducerKey, 8> getOutputKeysForPeers(ArrayRef<ComputeInstance> peers, size_t resultIndex) {
|
||||||
SmallVector<ProducerKey, 8> keys;
|
SmallVector<ProducerKey, 8> keys;
|
||||||
keys.reserve(peers.size());
|
keys.reserve(peers.size());
|
||||||
@@ -1237,6 +1473,111 @@ void appendScalarSendLoop(MaterializerState& state,
|
|||||||
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
|
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Emits a loop that gathers the `descriptor.fragmentsPerLane` fragments
/// belonging to one lane out of `fullPayload` and stacks them along dim 0 into
/// a fresh tensor.
///
/// \param state       Materializer state; its rewriter emits the new ops.
/// \param anchor      Operation handed to the index-constant/table helpers.
/// \param fullPayload Tensor the fragments are sliced from.
/// \param descriptor  Projection description; must have `fragmentsPerLane > 1`.
/// \param laneIndex   Index value selecting the lane whose fragments are packed.
/// \param loc         Location for the emitted ops.
/// \returns The loop result holding the packed fragments.
///
/// The rewriter's insertion point is restored to where it was right after the
/// loop was created.
Value buildProjectedPackedPayload(MaterializerState& state,
                                  Operation* anchor,
                                  Value fullPayload,
                                  const ProjectedTransferDescriptor& descriptor,
                                  Value laneIndex,
                                  Location loc) {
  assert(descriptor.fragmentsPerLane > 1 && "use direct fragment path for single-fragment projection");

  // Forward the encoding as well: `payloadType` is built with the fragment's
  // encoding and the receiving side is typed with `payloadType`, so the
  // accumulator must not silently drop it. A null encoding behaves exactly as
  // the three-argument form.
  Value init = tensor::EmptyOp::create(state.rewriter,
                                       loc,
                                       descriptor.payloadType.getShape(),
                                       descriptor.payloadType.getElementType(),
                                       descriptor.payloadType.getEncoding())
                   .getResult();

  Value lowerBound = createIndexConstant(state, anchor, 0);
  // Serves both as the loop upper bound and as the lane stride below; it is
  // defined before the loop, so it is usable inside the body.
  Value fragmentsPerLane = createIndexConstant(state, anchor, descriptor.fragmentsPerLane);
  Value step = createIndexConstant(state, anchor, 1);

  auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, fragmentsPerLane, step, ValueRange {init});

  // Drop a terminator the builder may have created; the body yields the
  // updated accumulator instead.
  Block* body = loop.getBody();
  if (!body->empty())
    if (auto yield = dyn_cast<scf::YieldOp>(body->back()))
      state.rewriter.eraseOp(yield);

  OpBuilder::InsertionGuard guard(state.rewriter);
  state.rewriter.setInsertionPointToEnd(body);

  Value fragmentIndex = loop.getInductionVar();
  Value acc = body->getArgument(1);

  // Position inside the lane-major offset table: lane * fragmentsPerLane + fragment.
  Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
  Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();

  Value sourceOffset =
      createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);

  Value fragment =
      createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));

  Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
  scf::YieldOp::create(state.rewriter, loc, next);

  return loop.getResult(0);
}
|
||||||
|
|
||||||
|
/// Emits, before the terminator of the scalar source class body, one channel
/// send per destination lane carrying only the fragments that lane reads.
///
/// \param state         Materializer state; its rewriter emits the new ops.
/// \param sourceClass   Scalar (non-batch) class that owns `payload`.
/// \param payload       Full value the fragments are sliced from.
/// \param descriptor    Projection description (fragment type, offsets, ...).
/// \param channelIds    One channel id per message.
/// \param sourceCoreIds One source core id per message.
/// \param targetCoreIds One target core id per message.
/// \param loc           Location for the emitted ops.
///
/// With a single message the send is emitted straight-line; otherwise a loop
/// over the messages is emitted and ids/offsets are looked up by the induction
/// variable. In both shapes a lane with one fragment sends a direct slice,
/// while a lane with several fragments sends a packed tensor.
///
/// The rewriter's insertion point is moved to the source class body and is not
/// restored to its value on entry.
void appendProjectedScalarSendLoop(MaterializerState& state,
                                   MaterializedClass& sourceClass,
                                   Value payload,
                                   const ProjectedTransferDescriptor& descriptor,
                                   ArrayRef<int64_t> channelIds,
                                   ArrayRef<int32_t> sourceCoreIds,
                                   ArrayRef<int32_t> targetCoreIds,
                                   Location loc) {
  assert(!sourceClass.isBatch && "projected scalar send expects scalar source class");
  assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
  assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
  assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorSourceDim0Offsets.size()
         && "projected send lane count mismatch");

  state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());

  if (channelIds.size() == 1) {
    Value channelId = createIndexConstant(state, sourceClass.op, channelIds.front());
    Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds.front());
    Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds.front());

    Value sendPayload;
    if (descriptor.fragmentsPerLane == 1) {
      Value offset = createIndexConstant(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front());
      sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
    }
    else {
      // The lane index is only consumed by the packing loop, so it is created
      // here rather than unconditionally, which would leave a dead constant in
      // the IR on the single-fragment path.
      Value laneIndex = createIndexConstant(state, sourceClass.op, 0);
      sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc);
    }

    SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
    return;
  }

  Value lowerBound = createIndexConstant(state, sourceClass.op, 0);
  Value upperBound = createIndexConstant(state, sourceClass.op, static_cast<int64_t>(channelIds.size()));
  Value step = createIndexConstant(state, sourceClass.op, 1);

  auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});

  OpBuilder::InsertionGuard guard(state.rewriter);
  state.rewriter.setInsertionPointToStart(loop.getBody());

  // The induction variable is both the message number and the lane number.
  Value index = loop.getInductionVar();
  Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc);
  Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc);
  Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc);

  Value sendPayload;
  if (descriptor.fragmentsPerLane == 1) {
    Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets, index, loc);
    sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
  }
  else {
    sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc);
  }

  SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
}
|
||||||
|
|
||||||
void appendSend(MaterializerState& state,
|
void appendSend(MaterializerState& state,
|
||||||
MaterializedClass& sourceClass,
|
MaterializedClass& sourceClass,
|
||||||
Value payload,
|
Value payload,
|
||||||
@@ -1394,6 +1735,7 @@ SmallVector<ClassId, 4> collectDestinationClassesForKeys(MaterializerState& stat
|
|||||||
|
|
||||||
SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState& state,
|
SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState& state,
|
||||||
MaterializedClass& sourceClass,
|
MaterializedClass& sourceClass,
|
||||||
|
ArrayRef<ProducerKey> keys,
|
||||||
ArrayRef<ClassId> destinationClasses,
|
ArrayRef<ClassId> destinationClasses,
|
||||||
Value payload,
|
Value payload,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
@@ -1401,25 +1743,44 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
|||||||
|
|
||||||
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
||||||
|
|
||||||
size_t messageCount = 0;
|
|
||||||
for (ClassId destinationClass : destinationClasses) {
|
|
||||||
if (destinationClass == sourceClass.id)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
MaterializedClass& targetClass = state.classes[destinationClass];
|
|
||||||
messageCount += targetClass.isBatch ? targetClass.cpus.size() : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t, 8> allChannelIds;
|
|
||||||
SmallVector<int32_t, 8> allSourceCoreIds;
|
|
||||||
SmallVector<int32_t, 8> allTargetCoreIds;
|
|
||||||
allChannelIds.reserve(messageCount);
|
|
||||||
allSourceCoreIds.reserve(messageCount);
|
|
||||||
allTargetCoreIds.reserve(messageCount);
|
|
||||||
|
|
||||||
SmallVector<ScalarSourceReceivePlan, 4> receivePlans;
|
SmallVector<ScalarSourceReceivePlan, 4> receivePlans;
|
||||||
receivePlans.reserve(destinationClasses.size());
|
receivePlans.reserve(destinationClasses.size());
|
||||||
|
|
||||||
|
const auto tryEmitProjected = [&](ClassId destinationClass,
|
||||||
|
const SmallVector<int64_t, 8>& channelIds,
|
||||||
|
const SmallVector<int32_t, 8>& sourceCoreIds,
|
||||||
|
const SmallVector<int32_t, 8>& targetCoreIds) -> bool {
|
||||||
|
if (keys.size() != 1)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
MaterializedClass& targetClass = state.classes[destinationClass];
|
||||||
|
if (!targetClass.isBatch)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto producerIt = state.projectedTransfers.find(keys.front());
|
||||||
|
if (producerIt == state.projectedTransfers.end())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto descriptorIt = producerIt->second.find(destinationClass);
|
||||||
|
if (descriptorIt == producerIt->second.end())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
const ProjectedTransferDescriptor& descriptor = descriptorIt->second;
|
||||||
|
if (descriptor.laneMajorSourceDim0Offsets.size()
|
||||||
|
!= targetClass.cpus.size() * static_cast<size_t>(descriptor.fragmentsPerLane))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
appendProjectedScalarSendLoop(
|
||||||
|
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
|
Value received = appendReceive(
|
||||||
|
state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
|
state.projectedExtractReplacements[descriptor.extractOp][destinationClass] =
|
||||||
|
ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane};
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
for (ClassId destinationClass : destinationClasses) {
|
for (ClassId destinationClass : destinationClasses) {
|
||||||
if (destinationClass == sourceClass.id)
|
if (destinationClass == sourceClass.id)
|
||||||
continue;
|
continue;
|
||||||
@@ -1435,10 +1796,6 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
|||||||
plan.channelIds.push_back(channelId);
|
plan.channelIds.push_back(channelId);
|
||||||
plan.sourceCoreIds.push_back(sourceCpu);
|
plan.sourceCoreIds.push_back(sourceCpu);
|
||||||
plan.targetCoreIds.push_back(targetCpu);
|
plan.targetCoreIds.push_back(targetCpu);
|
||||||
|
|
||||||
allChannelIds.push_back(channelId);
|
|
||||||
allSourceCoreIds.push_back(sourceCpu);
|
|
||||||
allTargetCoreIds.push_back(targetCpu);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!targetClass.isBatch)
|
if (!targetClass.isBatch)
|
||||||
@@ -1447,12 +1804,13 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
|||||||
for (CpuId targetCpu : targetClass.cpus)
|
for (CpuId targetCpu : targetClass.cpus)
|
||||||
appendMessage(static_cast<int32_t>(targetCpu));
|
appendMessage(static_cast<int32_t>(targetCpu));
|
||||||
|
|
||||||
|
if (tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc);
|
||||||
receivePlans.push_back(std::move(plan));
|
receivePlans.push_back(std::move(plan));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!allChannelIds.empty())
|
|
||||||
appendSend(state, sourceClass, payload, allChannelIds, allSourceCoreIds, allTargetCoreIds, loc);
|
|
||||||
|
|
||||||
return receivePlans;
|
return receivePlans;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1465,7 +1823,7 @@ LogicalResult emitScalarSourceCommunication(
|
|||||||
|
|
||||||
SmallVector<ClassId, 4> destinationClasses = collectDestinationClassesForKeys(state, keys);
|
SmallVector<ClassId, 4> destinationClasses = collectDestinationClassesForKeys(state, keys);
|
||||||
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
|
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
|
||||||
emitScalarSourceSends(state, sourceClass, destinationClasses, payload, loc);
|
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc);
|
||||||
|
|
||||||
for (const ScalarSourceReceivePlan& plan : receivePlans) {
|
for (const ScalarSourceReceivePlan& plan : receivePlans) {
|
||||||
MaterializedClass& targetClass = state.classes[plan.targetClass];
|
MaterializedClass& targetClass = state.classes[plan.targetClass];
|
||||||
@@ -1716,7 +2074,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
|||||||
MaterializedClass& targetClass,
|
MaterializedClass& targetClass,
|
||||||
const ComputeInstance& instance,
|
const ComputeInstance& instance,
|
||||||
Value laneValue,
|
Value laneValue,
|
||||||
ArrayRef<size_t> resultIndices);
|
ArrayRef<size_t> resultIndices,
|
||||||
|
std::optional<Value> projectionSlotIndex = std::nullopt);
|
||||||
|
|
||||||
FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState& state,
|
FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState& state,
|
||||||
MaterializedClass& targetClass,
|
MaterializedClass& targetClass,
|
||||||
@@ -1766,7 +2125,7 @@ FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState&
|
|||||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
|
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
|
||||||
|
|
||||||
FailureOr<SmallVector<Value, 4>> produced =
|
FailureOr<SmallVector<Value, 4>> produced =
|
||||||
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices);
|
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
|
||||||
if (failed(produced))
|
if (failed(produced))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -1833,7 +2192,7 @@ FailureOr<Value> insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt
|
|||||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
|
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
|
||||||
|
|
||||||
FailureOr<SmallVector<Value, 4>> produced =
|
FailureOr<SmallVector<Value, 4>> produced =
|
||||||
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices);
|
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
|
||||||
if (failed(produced))
|
if (failed(produced))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -2153,6 +2512,19 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
|
|||||||
return appendInput(state, targetClass, input);
|
return appendInput(state, targetClass, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true when input `inputIndex` of `batch` matches the lane-projected
/// extract pattern and a replacement for that extract has been registered for
/// class `classId`.
bool hasProjectedInputReplacement(
    MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) {
  std::optional<tensor::ExtractSliceOp> projectedExtract = matchSimpleLaneProjectedInput(batch, inputIndex);
  if (!projectedExtract)
    return false;

  auto perClassIt = state.projectedExtractReplacements.find(projectedExtract->getOperation());
  if (perClassIt == state.projectedExtractReplacements.end())
    return false;

  return perClassIt->second.count(classId) != 0;
}
|
||||||
|
|
||||||
void mapWeights(MaterializerState& state,
|
void mapWeights(MaterializerState& state,
|
||||||
MaterializedClass& targetClass,
|
MaterializedClass& targetClass,
|
||||||
const ComputeInstance& instance,
|
const ComputeInstance& instance,
|
||||||
@@ -2195,6 +2567,9 @@ LogicalResult mapInputs(MaterializerState& state,
|
|||||||
|
|
||||||
auto batch = cast<SpatComputeBatch>(op);
|
auto batch = cast<SpatComputeBatch>(op);
|
||||||
for (auto [index, input] : llvm::enumerate(batch.getInputs())) {
|
for (auto [index, input] : llvm::enumerate(batch.getInputs())) {
|
||||||
|
if (hasProjectedInputReplacement(state, batch, static_cast<unsigned>(index), targetClass.id))
|
||||||
|
continue;
|
||||||
|
|
||||||
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
|
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
|
||||||
if (failed(mapped))
|
if (failed(mapped))
|
||||||
return batch.emitOpError("failed to resolve materialized compute_batch input");
|
return batch.emitOpError("failed to resolve materialized compute_batch input");
|
||||||
@@ -2262,6 +2637,35 @@ SmallVector<Type, 4> collectBatchOutputFragmentTypes(SpatComputeBatch batch) {
|
|||||||
return types;
|
return types;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Looks up the replacement registered for `extract` in the class
/// `targetClass`. Returns a copy of the record, or std::nullopt when either the
/// extract or the class has no entry.
std::optional<ProjectedExtractReplacement> lookupProjectedExtractReplacement(MaterializerState& state,
                                                                             MaterializedClass& targetClass,
                                                                             tensor::ExtractSliceOp extract) {
  auto perExtractIt = state.projectedExtractReplacements.find(extract.getOperation());
  if (perExtractIt != state.projectedExtractReplacements.end()) {
    auto& perClass = perExtractIt->second;
    auto hit = perClass.find(targetClass.id);
    if (hit != perClass.end())
      return hit->second;
  }

  return std::nullopt;
}
|
||||||
|
|
||||||
|
/// Produces the value that replaces the result of `extract`.
///
/// When the replacement holds a single fragment, its payload is returned as
/// is. When it holds several, the fragment at dim-0 position
/// `*projectionSlotIndex` is sliced out of the payload; a missing slot index is
/// reported as an error on the target class op.
FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state,
                                                        MaterializedClass& targetClass,
                                                        tensor::ExtractSliceOp extract,
                                                        const ProjectedExtractReplacement& replacement,
                                                        std::optional<Value> projectionSlotIndex) {
  const bool isPacked = replacement.fragmentsPerLane != 1;
  if (!isPacked)
    return replacement.payload;

  if (!projectionSlotIndex.has_value())
    return targetClass.op->emitError("packed projected extract replacement requires a projection slot index");

  Value selected = createDim0ExtractSlice(
      state, extract.getLoc(), replacement.payload, *projectionSlotIndex, replacement.fragmentType.getDimSize(0));
  return selected;
}
|
||||||
|
|
||||||
FailureOr<SmallVector<Value, 4>>
|
FailureOr<SmallVector<Value, 4>>
|
||||||
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
|
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
|
||||||
assert(!peers.empty() && "expected at least one peer instance");
|
assert(!peers.empty() && "expected at least one peer instance");
|
||||||
@@ -2301,6 +2705,19 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra
|
|||||||
|
|
||||||
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
|
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
|
||||||
for (Operation& op : sourceBlock.without_terminator()) {
|
for (Operation& op : sourceBlock.without_terminator()) {
|
||||||
|
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
||||||
|
if (std::optional<ProjectedExtractReplacement> replacement =
|
||||||
|
lookupProjectedExtractReplacement(state, targetClass, extract)) {
|
||||||
|
FailureOr<Value> projected =
|
||||||
|
materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, std::nullopt);
|
||||||
|
if (failed(projected))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
mapper.map(extract.getResult(), *projected);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Operation* cloned = state.rewriter.clone(op, mapper);
|
Operation* cloned = state.rewriter.clone(op, mapper);
|
||||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||||
mapper.map(oldResult, newResult);
|
mapper.map(oldResult, newResult);
|
||||||
@@ -2503,7 +2920,7 @@ LogicalResult emitPackedRunFanout(MaterializerState& state,
|
|||||||
assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class");
|
assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class");
|
||||||
|
|
||||||
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
|
SmallVector<ScalarSourceReceivePlan, 4> receivePlans =
|
||||||
emitScalarSourceSends(state, sourceClass, destinationClasses, packed, loc);
|
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc);
|
||||||
|
|
||||||
for (const ScalarSourceReceivePlan& plan : receivePlans) {
|
for (const ScalarSourceReceivePlan& plan : receivePlans) {
|
||||||
MaterializedClass& targetClass = state.classes[plan.targetClass];
|
MaterializedClass& targetClass = state.classes[plan.targetClass];
|
||||||
@@ -2522,7 +2939,8 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
|||||||
MaterializedClass& targetClass,
|
MaterializedClass& targetClass,
|
||||||
const ComputeInstance& instance,
|
const ComputeInstance& instance,
|
||||||
Value laneValue,
|
Value laneValue,
|
||||||
ArrayRef<size_t> resultIndices) {
|
ArrayRef<size_t> resultIndices,
|
||||||
|
std::optional<Value> projectionSlotIndex) {
|
||||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||||
if (!batch)
|
if (!batch)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -2544,6 +2962,19 @@ FailureOr<SmallVector<Value, 4>> cloneBatchBodyForLane(MaterializerState& state,
|
|||||||
|
|
||||||
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
|
Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
|
||||||
for (Operation& op : sourceBlock.without_terminator()) {
|
for (Operation& op : sourceBlock.without_terminator()) {
|
||||||
|
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
||||||
|
if (std::optional<ProjectedExtractReplacement> replacement =
|
||||||
|
lookupProjectedExtractReplacement(state, targetClass, extract)) {
|
||||||
|
FailureOr<Value> projected = materializeProjectedExtractReplacement(
|
||||||
|
state, targetClass, extract, *replacement, projectionSlotIndex);
|
||||||
|
if (failed(projected))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
mapper.map(extract.getResult(), *projected);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Operation* cloned = state.rewriter.clone(op, mapper);
|
Operation* cloned = state.rewriter.clone(op, mapper);
|
||||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||||
mapper.map(oldResult, newResult);
|
mapper.map(oldResult, newResult);
|
||||||
@@ -2637,7 +3068,7 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
|||||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc);
|
Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc);
|
||||||
|
|
||||||
FailureOr<SmallVector<Value, 4>> produced =
|
FailureOr<SmallVector<Value, 4>> produced =
|
||||||
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices);
|
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex);
|
||||||
if (failed(produced))
|
if (failed(produced))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -3127,7 +3558,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
|
|||||||
Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc);
|
Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc);
|
||||||
|
|
||||||
FailureOr<SmallVector<Value, 4>> produced =
|
FailureOr<SmallVector<Value, 4>> produced =
|
||||||
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices);
|
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex);
|
||||||
if (failed(produced))
|
if (failed(produced))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -3286,6 +3717,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
|
|||||||
createEmptyMaterializedOps(state);
|
createEmptyMaterializedOps(state);
|
||||||
if (failed(collectProducerDestinations(state)))
|
if (failed(collectProducerDestinations(state)))
|
||||||
return failure();
|
return failure();
|
||||||
|
if (failed(collectProjectedTransfers(state)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
for (const ComputeInstance& instance : schedule.dominanceOrderCompute)
|
for (const ComputeInstance& instance : schedule.dominanceOrderCompute)
|
||||||
if (failed(materializeInstanceSlot(state, instance)))
|
if (failed(materializeInstanceSlot(state, instance)))
|
||||||
|
|||||||
Reference in New Issue
Block a user