Fix high memory usage caused by MaterializeMergeSchedule.cpp by making the code more robust
This commit is contained in:
+151
-24
@@ -155,8 +155,9 @@ struct ProjectedTransferDescriptor {
|
||||
|
||||
RankedTensorType fragmentType;
|
||||
RankedTensorType payloadType;
|
||||
unsigned sourceProjectedDim = 0;
|
||||
unsigned fragmentsPerLane = 1;
|
||||
SmallVector<int64_t, 16> laneMajorSourceDim0Offsets;
|
||||
SmallVector<int64_t, 16> laneMajorProjectedOffsets;
|
||||
};
|
||||
|
||||
struct ProjectedExtractReplacement {
|
||||
@@ -656,6 +657,29 @@ Value createDim0ExtractSlice(
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Builds a tensor.extract_slice of `source` that is offset along exactly one
// dimension.
//
// Every dimension uses stride 1 and the size given by `resultShape`; every
// dimension other than `sliceDim` starts at offset 0, while `sliceDim` starts
// at `offset`.
//
// Preconditions (checked with assert):
//   * `sliceDim` names a dimension of `source`; otherwise `offset` would be
//     silently ignored and the slice would always start at the origin.
//   * `resultShape` has one entry per dimension of `source`.
//
// Returns the result value of the created extract_slice op.
Value createSingleDimExtractSlice(MaterializerState& state,
                                  Location loc,
                                  Value source,
                                  unsigned sliceDim,
                                  OpFoldResult offset,
                                  ArrayRef<int64_t> resultShape) {
  auto sourceType = cast<RankedTensorType>(source.getType());
  const int64_t rank = sourceType.getRank();
  assert(static_cast<int64_t>(sliceDim) < rank && "slice dimension out of range for source tensor");
  assert(resultShape.size() == static_cast<size_t>(rank) && "result shape rank must match source rank");

  SmallVector<OpFoldResult, 4> offsets;
  SmallVector<OpFoldResult, 4> sizes;
  SmallVector<OpFoldResult, 4> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  // The zero offset and unit stride are identical for every dimension, so
  // build them once instead of once per loop iteration.
  const OpFoldResult zeroOffset(state.rewriter.getIndexAttr(0));
  const OpFoldResult unitStride(state.rewriter.getIndexAttr(1));

  for (int64_t dim = 0; dim < rank; ++dim) {
    offsets.push_back(dim == static_cast<int64_t>(sliceDim) ? offset : zeroOffset);
    sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim]));
    strides.push_back(unitStride);
  }

  return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult();
}
|
||||
|
||||
Value createDim0InsertSlice(
|
||||
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
|
||||
auto fragmentType = cast<RankedTensorType>(fragment.getType());
|
||||
@@ -1238,6 +1262,84 @@ bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
|
||||
return intAttr && intAttr.getInt() == expected;
|
||||
}
|
||||
|
||||
bool isStaticSliceInBounds(int64_t offset,
|
||||
RankedTensorType sourceType,
|
||||
RankedTensorType fragmentType,
|
||||
unsigned sliceDim) {
|
||||
if (offset < 0)
|
||||
return false;
|
||||
if (sliceDim >= static_cast<unsigned>(sourceType.getRank())
|
||||
|| sliceDim >= static_cast<unsigned>(fragmentType.getRank()))
|
||||
return false;
|
||||
|
||||
int64_t sourceDimSize = sourceType.getDimSize(sliceDim);
|
||||
int64_t fragmentDimSize = fragmentType.getDimSize(sliceDim);
|
||||
if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize)
|
||||
return false;
|
||||
|
||||
return offset <= sourceDimSize - fragmentDimSize;
|
||||
}
|
||||
|
||||
// Finds the single slice dimension whose offset satisfies
// `isValueOffset(offset, laneArg)`.
//
// Returns that dimension's index, or std::nullopt when no dimension matches or
// when more than one dimension matches.
std::optional<unsigned> getLaneProjectedDim(ArrayRef<OpFoldResult> offsets, Value laneArg) {
  std::optional<unsigned> match;

  for (size_t index = 0, count = offsets.size(); index < count; ++index) {
    if (isValueOffset(offsets[index], laneArg)) {
      // A second match makes the projection ambiguous.
      if (match.has_value())
        return std::nullopt;
      match = static_cast<unsigned>(index);
    }
  }

  return match;
}
|
||||
|
||||
// Evaluates `value` to a concrete integer, treating `laneArg` as the constant
// `lane`.
//
// Handled forms:
//   * an attribute: must be an IntegerAttr,
//   * the lane argument itself: evaluates to `lane`,
//   * a value defined by arith::ConstantIndexOp, or by arith::ConstantOp
//     holding an IntegerAttr,
//   * a value defined by a single-result affine::AffineApplyOp whose map
//     operands can all be evaluated recursively by this function; the map is
//     then constant-folded.
//
// Any other form, or any folding failure, yields failure().
static FailureOr<int64_t> evaluateProjectedOffsetForLane(OpFoldResult value, Value laneArg, uint32_t lane) {
  // Static case: the offset is already an attribute.
  if (auto staticAttr = dyn_cast<Attribute>(value)) {
    if (auto integer = dyn_cast<IntegerAttr>(staticAttr))
      return integer.getInt();
    return failure();
  }

  Value ssaValue = cast<Value>(value);
  if (ssaValue == laneArg)
    return static_cast<int64_t>(lane);

  if (auto indexConstant = ssaValue.getDefiningOp<arith::ConstantIndexOp>())
    return indexConstant.value();

  if (auto genericConstant = ssaValue.getDefiningOp<arith::ConstantOp>()) {
    if (auto integer = dyn_cast<IntegerAttr>(genericConstant.getValue()))
      return integer.getInt();
  }

  // Everything that remains must be an affine.apply to be evaluable.
  auto applyOp = ssaValue.getDefiningOp<affine::AffineApplyOp>();
  if (!applyOp)
    return failure();

  AffineMap applyMap = applyOp.getAffineMap();
  if (applyMap.getNumResults() != 1)
    return failure();

  // Evaluate each map operand for this lane and wrap it as an index attribute
  // so the map itself can be constant-folded.
  SmallVector<Attribute, 4> evaluatedOperands;
  evaluatedOperands.reserve(applyOp.getMapOperands().size());
  for (Value mapOperand : applyOp.getMapOperands()) {
    FailureOr<int64_t> operandValue = evaluateProjectedOffsetForLane(mapOperand, laneArg, lane);
    if (failed(operandValue))
      return failure();
    evaluatedOperands.push_back(IntegerAttr::get(IndexType::get(ssaValue.getContext()), *operandValue));
  }

  SmallVector<Attribute, 1> folded;
  if (failed(applyMap.constantFold(evaluatedOperands, folded)))
    return failure();
  if (folded.size() != 1)
    return failure();

  if (auto foldedInteger = dyn_cast<IntegerAttr>(folded.front()))
    return foldedInteger.getInt();
  return failure();
}
|
||||
|
||||
std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) {
|
||||
std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
|
||||
std::optional<BlockArgument> laneArg = batch.getLaneArgument();
|
||||
@@ -1269,12 +1371,17 @@ std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeB
|
||||
|| strides.size() != static_cast<size_t>(inputType.getRank()))
|
||||
return std::nullopt;
|
||||
|
||||
if (!isValueOffset(offsets.front(), *laneArg))
|
||||
return std::nullopt;
|
||||
if (!isStaticIndexAttr(sizes.front(), 1) || !isStaticIndexAttr(strides.front(), 1))
|
||||
std::optional<unsigned> projectedDim = getLaneProjectedDim(offsets, *laneArg);
|
||||
if (!projectedDim)
|
||||
return std::nullopt;
|
||||
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
|
||||
if (dim == static_cast<int64_t>(*projectedDim)) {
|
||||
if (!isStaticIndexAttr(sizes[dim], 1) || !isStaticIndexAttr(strides[dim], 1))
|
||||
return std::nullopt;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isStaticIndexAttr(offsets[dim], 0))
|
||||
return std::nullopt;
|
||||
if (!isStaticIndexAttr(sizes[dim], inputType.getDimSize(dim)))
|
||||
@@ -1283,11 +1390,11 @@ std::optional<tensor::ExtractSliceOp> matchSimpleLaneProjectedInput(SpatComputeB
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (fragmentType.getDimSize(0) != 1)
|
||||
return std::nullopt;
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim)
|
||||
if (fragmentType.getDimSize(dim) != inputType.getDimSize(dim))
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
|
||||
int64_t expectedDimSize = dim == *projectedDim ? 1 : inputType.getDimSize(dim);
|
||||
if (fragmentType.getDimSize(dim) != expectedDimSize)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return extract;
|
||||
}
|
||||
@@ -1297,6 +1404,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
Operation* extractOp = nullptr;
|
||||
RankedTensorType fragmentType;
|
||||
unsigned sourceProjectedDim = 0;
|
||||
SmallVector<SmallVector<int64_t, 4>, 8> offsetsByLane;
|
||||
bool invalid = false;
|
||||
};
|
||||
@@ -1341,19 +1449,25 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
if (!extract)
|
||||
continue;
|
||||
|
||||
auto inputType = cast<RankedTensorType>(extract->getSource().getType());
|
||||
auto fragmentType = cast<RankedTensorType>((*extract).getResult().getType());
|
||||
SmallVector<OpFoldResult, 4> offsets = extract->getMixedOffsets();
|
||||
std::optional<unsigned> sourceProjectedDim = getLaneProjectedDim(offsets, *batch.getLaneArgument());
|
||||
if (!sourceProjectedDim)
|
||||
continue;
|
||||
|
||||
PendingProjectedTransferDescriptor& descriptor = pending[*producer][targetClassId];
|
||||
if (descriptor.offsetsByLane.empty()) {
|
||||
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
descriptor.extractOp = extract->getOperation();
|
||||
descriptor.fragmentType = fragmentType;
|
||||
descriptor.sourceProjectedDim = *sourceProjectedDim;
|
||||
descriptor.offsetsByLane.resize(targetClass.cpus.size());
|
||||
}
|
||||
|
||||
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation()
|
||||
|| descriptor.fragmentType != fragmentType) {
|
||||
|| descriptor.fragmentType != fragmentType || descriptor.sourceProjectedDim != *sourceProjectedDim) {
|
||||
descriptor.invalid = true;
|
||||
continue;
|
||||
}
|
||||
@@ -1363,7 +1477,14 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
continue;
|
||||
}
|
||||
|
||||
descriptor.offsetsByLane[targetLane].push_back(static_cast<int64_t>(consumer.laneStart));
|
||||
FailureOr<int64_t> offset =
|
||||
evaluateProjectedOffsetForLane(offsets[*sourceProjectedDim], *batch.getLaneArgument(), consumer.laneStart);
|
||||
if (failed(offset) || !isStaticSliceInBounds(*offset, inputType, fragmentType, *sourceProjectedDim)) {
|
||||
descriptor.invalid = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
descriptor.offsetsByLane[targetLane].push_back(*offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1402,10 +1523,11 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
descriptor.extractOp = pendingDescriptor.extractOp;
|
||||
descriptor.fragmentType = pendingDescriptor.fragmentType;
|
||||
descriptor.payloadType = payloadType;
|
||||
descriptor.sourceProjectedDim = pendingDescriptor.sourceProjectedDim;
|
||||
descriptor.fragmentsPerLane = fragmentsPerLane;
|
||||
descriptor.laneMajorSourceDim0Offsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane);
|
||||
descriptor.laneMajorProjectedOffsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane);
|
||||
for (ArrayRef<int64_t> laneOffsets : pendingDescriptor.offsetsByLane)
|
||||
llvm::append_range(descriptor.laneMajorSourceDim0Offsets, laneOffsets);
|
||||
llvm::append_range(descriptor.laneMajorProjectedOffsets, laneOffsets);
|
||||
|
||||
state.projectedTransfers[producer][targetClassId] = std::move(descriptor);
|
||||
}
|
||||
@@ -1532,11 +1654,12 @@ Value buildProjectedPackedPayload(MaterializerState& state,
|
||||
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
|
||||
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
|
||||
|
||||
Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc);
|
||||
Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorProjectedOffsets, flatIndex, loc);
|
||||
Value fragment = createSingleDimExtractSlice(
|
||||
state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape());
|
||||
|
||||
Value fragment = createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0));
|
||||
|
||||
Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex);
|
||||
Value packedOffset = scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc);
|
||||
Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset);
|
||||
scf::YieldOp::create(state.rewriter, loc, next);
|
||||
|
||||
return loop.getResult(0);
|
||||
@@ -1553,7 +1676,7 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
|
||||
assert(!sourceClass.isBatch && "projected scalar send expects scalar source class");
|
||||
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
|
||||
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
||||
assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorSourceDim0Offsets.size()
|
||||
assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorProjectedOffsets.size()
|
||||
&& "projected send lane count mismatch");
|
||||
|
||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||
@@ -1566,8 +1689,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
|
||||
Value sendPayload;
|
||||
if (descriptor.fragmentsPerLane == 1) {
|
||||
Value offset =
|
||||
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front());
|
||||
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
|
||||
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorProjectedOffsets.front());
|
||||
sendPayload = createSingleDimExtractSlice(
|
||||
state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape());
|
||||
}
|
||||
else {
|
||||
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc);
|
||||
@@ -1594,8 +1718,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
|
||||
|
||||
Value sendPayload;
|
||||
if (descriptor.fragmentsPerLane == 1) {
|
||||
Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets, index, loc);
|
||||
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
|
||||
Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorProjectedOffsets, index, loc);
|
||||
sendPayload = createSingleDimExtractSlice(
|
||||
state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape());
|
||||
}
|
||||
else {
|
||||
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc);
|
||||
@@ -1792,7 +1917,7 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
|
||||
return false;
|
||||
|
||||
const ProjectedTransferDescriptor& descriptor = descriptorIt->second;
|
||||
if (descriptor.laneMajorSourceDim0Offsets.size()
|
||||
if (descriptor.laneMajorProjectedOffsets.size()
|
||||
!= targetClass.cpus.size() * static_cast<size_t>(descriptor.fragmentsPerLane))
|
||||
return false;
|
||||
|
||||
@@ -2701,8 +2826,10 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
|
||||
if (!projectionSlotIndex)
|
||||
return targetClass.op->emitError("packed projected extract replacement requires a projection slot index");
|
||||
|
||||
Value packedOffset = scaleIndexByDim0Size(
|
||||
state, targetClass.op, *projectionSlotIndex, replacement.fragmentType.getDimSize(0), extract.getLoc());
|
||||
return createDim0ExtractSlice(
|
||||
state, extract.getLoc(), replacement.payload, *projectionSlotIndex, replacement.fragmentType.getDimSize(0));
|
||||
state, extract.getLoc(), replacement.payload, packedOffset, replacement.fragmentType.getDimSize(0));
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value, 4>>
|
||||
|
||||
Reference in New Issue
Block a user