fix high memory usage caused by MaterializeMergeSchedule.cpp with more robust code

This commit is contained in:
NiccoloN
2026-05-30 16:12:06 +02:00
parent ff36729140
commit c0238c0d06
@@ -155,8 +155,9 @@ struct ProjectedTransferDescriptor {
RankedTensorType fragmentType;
RankedTensorType payloadType;
unsigned sourceProjectedDim = 0;
unsigned fragmentsPerLane = 1;
SmallVector<int64_t, 16> laneMajorSourceDim0Offsets;
SmallVector<int64_t, 16> laneMajorProjectedOffsets;
};
struct ProjectedExtractReplacement {
@@ -656,6 +657,29 @@ Value createDim0ExtractSlice(
.getResult();
}
Value createSingleDimExtractSlice(MaterializerState& state,
Location loc,
Value source,
unsigned sliceDim,
OpFoldResult offset,
ArrayRef<int64_t> resultShape) {
auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<OpFoldResult, 4> offsets;
SmallVector<OpFoldResult, 4> sizes;
SmallVector<OpFoldResult, 4> strides;
offsets.reserve(sourceType.getRank());
sizes.reserve(sourceType.getRank());
strides.reserve(sourceType.getRank());
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
offsets.push_back(dim == static_cast<int64_t>(sliceDim) ? offset : OpFoldResult(state.rewriter.getIndexAttr(0)));
sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim]));
strides.push_back(state.rewriter.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult();
}
Value createDim0InsertSlice(
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
auto fragmentType = cast<RankedTensorType>(fragment.getType());
@@ -1238,6 +1262,84 @@ bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
return intAttr && intAttr.getInt() == expected;
}
bool isStaticSliceInBounds(int64_t offset,
RankedTensorType sourceType,
RankedTensorType fragmentType,
unsigned sliceDim) {
if (offset < 0)
return false;
if (sliceDim >= static_cast<unsigned>(sourceType.getRank())
|| sliceDim >= static_cast<unsigned>(fragmentType.getRank()))
return false;
int64_t sourceDimSize = sourceType.getDimSize(sliceDim);
int64_t fragmentDimSize = fragmentType.getDimSize(sliceDim);
if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize)
return false;
return offset <= sourceDimSize - fragmentDimSize;
}
std::optional<unsigned> getLaneProjectedDim(ArrayRef<OpFoldResult> offsets, Value laneArg) {
std::optional<unsigned> projectedDim;
for (auto [dim, offset] : llvm::enumerate(offsets)) {
if (!isValueOffset(offset, laneArg))
continue;
if (projectedDim)
return std::nullopt;
projectedDim = static_cast<unsigned>(dim);
}
return projectedDim;
}
static FailureOr<int64_t> evaluateProjectedOffsetForLane(OpFoldResult value, Value laneArg, uint32_t lane) {
if (auto attr = dyn_cast<Attribute>(value)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr)
return failure();
return intAttr.getInt();
}
Value current = cast<Value>(value);
if (current == laneArg)
return static_cast<int64_t>(lane);
if (auto constant = current.getDefiningOp<arith::ConstantIndexOp>())
return constant.value();
if (auto constant = current.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()))
return intAttr.getInt();
if (auto affineApply = current.getDefiningOp<affine::AffineApplyOp>()) {
AffineMap map = affineApply.getAffineMap();
if (map.getNumResults() != 1)
return failure();
SmallVector<Attribute, 4> operandConstants;
operandConstants.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = evaluateProjectedOffsetForLane(operand, laneArg, lane);
if (failed(folded))
return failure();
operandConstants.push_back(IntegerAttr::get(IndexType::get(current.getContext()), *folded));
}
SmallVector<Attribute, 1> foldedResults;
if (failed(map.constantFold(operandConstants, foldedResults)) || foldedResults.size() != 1)
return failure();
auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front());
if (!constantResult)
return failure();
return constantResult.getInt();
}
return failure();
}
std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) {
std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
std::optional<BlockArgument> laneArg = batch.getLaneArgument();
@@ -1269,12 +1371,17 @@ std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeB
|| strides.size() != static_cast<size_t>(inputType.getRank()))
return std::nullopt;
if (!isValueOffset(offsets.front(), *laneArg))
return std::nullopt;
if (!isStaticIndexAttr(sizes.front(), 1) || !isStaticIndexAttr(strides.front(), 1))
std::optional<unsigned> projectedDim = getLaneProjectedDim(offsets, *laneArg);
if (!projectedDim)
return std::nullopt;
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
if (dim == static_cast<int64_t>(*projectedDim)) {
if (!isStaticIndexAttr(sizes[dim], 1) || !isStaticIndexAttr(strides[dim], 1))
return std::nullopt;
continue;
}
if (!isStaticIndexAttr(offsets[dim], 0))
return std::nullopt;
if (!isStaticIndexAttr(sizes[dim], inputType.getDimSize(dim)))
@@ -1283,11 +1390,11 @@ std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeB
return std::nullopt;
}
if (fragmentType.getDimSize(0) != 1)
return std::nullopt;
for (int64_t dim = 1; dim < inputType.getRank(); ++dim)
if (fragmentType.getDimSize(dim) != inputType.getDimSize(dim))
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
int64_t expectedDimSize = dim == *projectedDim ? 1 : inputType.getDimSize(dim);
if (fragmentType.getDimSize(dim) != expectedDimSize)
return std::nullopt;
}
return extract;
}
@@ -1297,6 +1404,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
ProjectedBatchInputKey inputKey;
Operation* extractOp = nullptr;
RankedTensorType fragmentType;
unsigned sourceProjectedDim = 0;
SmallVector<SmallVector<int64_t, 4>, 8> offsetsByLane;
bool invalid = false;
};
@@ -1341,19 +1449,25 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
if (!extract)
continue;
auto inputType = cast<RankedTensorType>(extract->getSource().getType());
auto fragmentType = cast<RankedTensorType>((*extract).getResult().getType());
SmallVector<OpFoldResult, 4> offsets = extract->getMixedOffsets();
std::optional<unsigned> sourceProjectedDim = getLaneProjectedDim(offsets, *batch.getLaneArgument());
if (!sourceProjectedDim)
continue;
PendingProjectedTransferDescriptor& descriptor = pending[*producer][targetClassId];
if (descriptor.offsetsByLane.empty()) {
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
descriptor.extractOp = extract->getOperation();
descriptor.fragmentType = fragmentType;
descriptor.sourceProjectedDim = *sourceProjectedDim;
descriptor.offsetsByLane.resize(targetClass.cpus.size());
}
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation()
|| descriptor.fragmentType != fragmentType) {
|| descriptor.fragmentType != fragmentType || descriptor.sourceProjectedDim != *sourceProjectedDim) {
descriptor.invalid = true;
continue;
}
@@ -1363,7 +1477,14 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
continue;
}
descriptor.offsetsByLane[targetLane].push_back(static_cast<int64_t>(consumer.laneStart));
FailureOr<int64_t> offset =
evaluateProjectedOffsetForLane(offsets[*sourceProjectedDim], *batch.getLaneArgument(), consumer.laneStart);
if (failed(offset) || !isStaticSliceInBounds(*offset, inputType, fragmentType, *sourceProjectedDim)) {
descriptor.invalid = true;
continue;
}
descriptor.offsetsByLane[targetLane].push_back(*offset);
}
}
@@ -1402,10 +1523,11 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
descriptor.extractOp = pendingDescriptor.extractOp;
descriptor.fragmentType = pendingDescriptor.fragmentType;
descriptor.payloadType = payloadType;
descriptor.sourceProjectedDim = pendingDescriptor.sourceProjectedDim;
descriptor.fragmentsPerLane = fragmentsPerLane;
descriptor.laneMajorSourceDim0Offsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane);
descriptor.laneMajorProjectedOffsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane);
for (ArrayRef<int64_t> laneOffsets : pendingDescriptor.offsetsByLane)
llvm::append_range(descriptor.laneMajorSourceDim0Offsets, laneOffsets);
llvm::append_range(descriptor.laneMajorProjectedOffsets, laneOffsets);
state.projectedTransfers[producer][targetClassId] = std::move(descriptor);
}
@@ -1532,11 +1654,12 @@ Value buildProjectedPackedPayload(MaterializerState& state,
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorProjectedOffsets, flatIndex, loc);
Value fragment = createSingleDimExtractSlice(
state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape());
Value fragment = createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
Value packedOffset = scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc);
Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset);
scf::YieldOp::create(state.rewriter, loc, next);
return loop.getResult(0);
@@ -1553,7 +1676,7 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
assert(!sourceClass.isBatch && "projected scalar send expects scalar source class");
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorSourceDim0Offsets.size()
assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorProjectedOffsets.size()
&& "projected send lane count mismatch");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
@@ -1566,8 +1689,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
Value sendPayload;
if (descriptor.fragmentsPerLane == 1) {
Value offset =
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front());
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorProjectedOffsets.front());
sendPayload = createSingleDimExtractSlice(
state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape());
}
else {
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc);
@@ -1594,8 +1718,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
Value sendPayload;
if (descriptor.fragmentsPerLane == 1) {
Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets, index, loc);
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorProjectedOffsets, index, loc);
sendPayload = createSingleDimExtractSlice(
state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape());
}
else {
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc);
@@ -1792,7 +1917,7 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
return false;
const ProjectedTransferDescriptor& descriptor = descriptorIt->second;
if (descriptor.laneMajorSourceDim0Offsets.size()
if (descriptor.laneMajorProjectedOffsets.size()
!= targetClass.cpus.size() * static_cast<size_t>(descriptor.fragmentsPerLane))
return false;
@@ -2701,8 +2826,10 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
if (!projectionSlotIndex)
return targetClass.op->emitError("packed projected extract replacement requires a projection slot index");
Value packedOffset = scaleIndexByDim0Size(
state, targetClass.op, *projectionSlotIndex, replacement.fragmentType.getDimSize(0), extract.getLoc());
return createDim0ExtractSlice(
state, extract.getLoc(), replacement.payload, *projectionSlotIndex, replacement.fragmentType.getDimSize(0));
state, extract.getLoc(), replacement.payload, packedOffset, replacement.fragmentType.getDimSize(0));
}
FailureOr<SmallVector<Value, 4>>