From c0238c0d06def2bd9f16bc3ff445ec5b1c4ca4cd Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Sat, 30 May 2026 16:12:06 +0200 Subject: [PATCH] fix high memory usage caused by MaterializeMergeSchedule.cpp with more robust code --- .../MaterializeMergeSchedule.cpp | 175 +++++++++++++++--- 1 file changed, 151 insertions(+), 24 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 5e4f8b3..d33061a 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -155,8 +155,9 @@ struct ProjectedTransferDescriptor { RankedTensorType fragmentType; RankedTensorType payloadType; + unsigned sourceProjectedDim = 0; unsigned fragmentsPerLane = 1; - SmallVector laneMajorSourceDim0Offsets; + SmallVector laneMajorProjectedOffsets; }; struct ProjectedExtractReplacement { @@ -656,6 +657,29 @@ Value createDim0ExtractSlice( .getResult(); } +Value createSingleDimExtractSlice(MaterializerState& state, + Location loc, + Value source, + unsigned sliceDim, + OpFoldResult offset, + ArrayRef resultShape) { + auto sourceType = cast(source.getType()); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(sourceType.getRank()); + sizes.reserve(sourceType.getRank()); + strides.reserve(sourceType.getRank()); + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + offsets.push_back(dim == static_cast(sliceDim) ? offset : OpFoldResult(state.rewriter.getIndexAttr(0))); + sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim])); + strides.push_back(state.rewriter.getIndexAttr(1)); + } + + return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult(); +} + Value createDim0InsertSlice( MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { auto fragmentType = cast(fragment.getType()); @@ -1238,6 +1262,84 @@ bool isStaticIndexAttr(OpFoldResult value, int64_t expected) { return intAttr && intAttr.getInt() == expected; } +bool isStaticSliceInBounds(int64_t offset, + RankedTensorType sourceType, + RankedTensorType fragmentType, + unsigned sliceDim) { + if (offset < 0) + return false; + if (sliceDim >= static_cast(sourceType.getRank()) + || sliceDim >= static_cast(fragmentType.getRank())) + return false; + + int64_t sourceDimSize = sourceType.getDimSize(sliceDim); + int64_t fragmentDimSize = fragmentType.getDimSize(sliceDim); + if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize) + return false; + + return offset <= sourceDimSize - fragmentDimSize; +} + +std::optional getLaneProjectedDim(ArrayRef offsets, Value laneArg) { + std::optional projectedDim; + for (auto [dim, offset] : llvm::enumerate(offsets)) { + if (!isValueOffset(offset, laneArg)) + continue; + + if (projectedDim) + return std::nullopt; + projectedDim = static_cast(dim); + } + + return projectedDim; +} + +static FailureOr evaluateProjectedOffsetForLane(OpFoldResult value, Value laneArg, uint32_t lane) { + if (auto attr = dyn_cast(value)) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return failure(); + return intAttr.getInt(); + } + + Value current = cast(value); + if (current == laneArg) + return static_cast(lane); + + if (auto constant = current.getDefiningOp()) + return constant.value(); + + if (auto constant = current.getDefiningOp()) + if (auto intAttr = dyn_cast(constant.getValue())) + return intAttr.getInt(); + + if (auto affineApply = current.getDefiningOp()) { + AffineMap map = affineApply.getAffineMap(); + if (map.getNumResults() != 1) + return failure(); + + SmallVector operandConstants; + operandConstants.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr folded = evaluateProjectedOffsetForLane(operand, laneArg, lane); + if (failed(folded)) + return failure(); + operandConstants.push_back(IntegerAttr::get(IndexType::get(current.getContext()), *folded)); + } + + SmallVector foldedResults; + if (failed(map.constantFold(operandConstants, foldedResults)) || foldedResults.size() != 1) + return failure(); + + auto constantResult = dyn_cast(foldedResults.front()); + if (!constantResult) + return failure(); + return constantResult.getInt(); + } + + return failure(); +} + std::optional matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) { std::optional inputArg = batch.getInputArgument(inputIndex); std::optional laneArg = batch.getLaneArgument(); @@ -1269,12 +1371,17 @@ std::optional matchSimpleLaneProjectedInput(SpatComputeB || strides.size() != static_cast(inputType.getRank())) return std::nullopt; - if (!isValueOffset(offsets.front(), *laneArg)) - return std::nullopt; - if (!isStaticIndexAttr(sizes.front(), 1) || !isStaticIndexAttr(strides.front(), 1)) + std::optional projectedDim = getLaneProjectedDim(offsets, *laneArg); + if (!projectedDim) return std::nullopt; - for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { + for (int64_t dim = 0; dim < inputType.getRank(); ++dim) { + if (dim == static_cast(*projectedDim)) { + if (!isStaticIndexAttr(sizes[dim], 1) || !isStaticIndexAttr(strides[dim], 1)) + return std::nullopt; + continue; + } + if (!isStaticIndexAttr(offsets[dim], 0)) return std::nullopt; if (!isStaticIndexAttr(sizes[dim], inputType.getDimSize(dim))) @@ -1283,11 +1390,11 @@ std::optional matchSimpleLaneProjectedInput(SpatComputeB return std::nullopt; } - if (fragmentType.getDimSize(0) != 1) - return std::nullopt; - for (int64_t dim = 1; dim < inputType.getRank(); ++dim) - if (fragmentType.getDimSize(dim) != inputType.getDimSize(dim)) + for (int64_t dim = 0; dim < inputType.getRank(); ++dim) { + int64_t expectedDimSize = dim == *projectedDim ? 1 : inputType.getDimSize(dim); + if (fragmentType.getDimSize(dim) != expectedDimSize) return std::nullopt; + } return extract; } @@ -1297,6 +1404,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { ProjectedBatchInputKey inputKey; Operation* extractOp = nullptr; RankedTensorType fragmentType; + unsigned sourceProjectedDim = 0; SmallVector, 8> offsetsByLane; bool invalid = false; }; @@ -1341,19 +1449,25 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { if (!extract) continue; + auto inputType = cast(extract->getSource().getType()); auto fragmentType = cast((*extract).getResult().getType()); + SmallVector offsets = extract->getMixedOffsets(); + std::optional sourceProjectedDim = getLaneProjectedDim(offsets, *batch.getLaneArgument()); + if (!sourceProjectedDim) + continue; PendingProjectedTransferDescriptor& descriptor = pending[*producer][targetClassId]; if (descriptor.offsetsByLane.empty()) { descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; descriptor.extractOp = extract->getOperation(); descriptor.fragmentType = fragmentType; + descriptor.sourceProjectedDim = *sourceProjectedDim; descriptor.offsetsByLane.resize(targetClass.cpus.size()); } ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation() - || descriptor.fragmentType != fragmentType) { + || descriptor.fragmentType != fragmentType || descriptor.sourceProjectedDim != *sourceProjectedDim) { descriptor.invalid = true; continue; } @@ -1363,7 +1477,14 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { continue; } - descriptor.offsetsByLane[targetLane].push_back(static_cast(consumer.laneStart)); + FailureOr offset = + evaluateProjectedOffsetForLane(offsets[*sourceProjectedDim], *batch.getLaneArgument(), consumer.laneStart); + if (failed(offset) || !isStaticSliceInBounds(*offset, inputType, fragmentType, *sourceProjectedDim)) { + descriptor.invalid = true; + continue; + } + + descriptor.offsetsByLane[targetLane].push_back(*offset); } } @@ -1402,10 +1523,11 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { descriptor.extractOp = pendingDescriptor.extractOp; descriptor.fragmentType = pendingDescriptor.fragmentType; descriptor.payloadType = payloadType; + descriptor.sourceProjectedDim = pendingDescriptor.sourceProjectedDim; descriptor.fragmentsPerLane = fragmentsPerLane; - descriptor.laneMajorSourceDim0Offsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane); + descriptor.laneMajorProjectedOffsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane); for (ArrayRef laneOffsets : pendingDescriptor.offsetsByLane) - llvm::append_range(descriptor.laneMajorSourceDim0Offsets, laneOffsets); + llvm::append_range(descriptor.laneMajorProjectedOffsets, laneOffsets); state.projectedTransfers[producer][targetClassId] = std::move(descriptor); } @@ -1532,11 +1654,12 @@ Value buildProjectedPackedPayload(MaterializerState& state, Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc); + Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorProjectedOffsets, flatIndex, loc); + Value fragment = createSingleDimExtractSlice( + state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape()); - Value fragment = createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0)); - - Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex); + Value packedOffset = scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc); + Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset); scf::YieldOp::create(state.rewriter, loc, next); return loop.getResult(0); @@ -1553,7 +1676,7 @@ void appendProjectedScalarSendLoop(MaterializerState& state, assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); - assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorSourceDim0Offsets.size() + assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorProjectedOffsets.size() && "projected send lane count mismatch"); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); @@ -1566,8 +1689,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state, Value sendPayload; if (descriptor.fragmentsPerLane == 1) { Value offset = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front()); - sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0)); + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorProjectedOffsets.front()); + sendPayload = createSingleDimExtractSlice( + state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); } else { sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc); @@ -1594,8 +1718,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state, Value sendPayload; if (descriptor.fragmentsPerLane == 1) { - Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets, index, loc); - sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0)); + Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorProjectedOffsets, index, loc); + sendPayload = createSingleDimExtractSlice( + state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); } else { sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc); @@ -1792,7 +1917,7 @@ SmallVector emitScalarSourceSends(MaterializerState& return false; const ProjectedTransferDescriptor& descriptor = descriptorIt->second; - if (descriptor.laneMajorSourceDim0Offsets.size() + if (descriptor.laneMajorProjectedOffsets.size() != targetClass.cpus.size() * static_cast(descriptor.fragmentsPerLane)) return false; @@ -2701,8 +2826,10 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state if (!projectionSlotIndex) return targetClass.op->emitError("packed projected extract replacement requires a projection slot index"); + Value packedOffset = scaleIndexByDim0Size( + state, targetClass.op, *projectionSlotIndex, replacement.fragmentType.getDimSize(0), extract.getLoc()); return createDim0ExtractSlice( - state, extract.getLoc(), replacement.payload, *projectionSlotIndex, replacement.fragmentType.getDimSize(0)); + state, extract.getLoc(), replacement.payload, packedOffset, replacement.fragmentType.getDimSize(0)); } FailureOr>