From e8f09fd67f99540b0bc21a0fafada4c2eac425d2 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 29 Jun 2026 12:22:33 +0200 Subject: [PATCH] robba --- src/PIM/Common/IR/AddressAnalysis.cpp | 48 +- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 66 +- .../BatchCoreLoweringPatterns.cpp | 397 ++++----- src/PIM/Conversion/SpatialToPim/Common.cpp | 376 ++++++++ src/PIM/Conversion/SpatialToPim/Common.hpp | 39 + .../SpatialToPim/CoreLoweringPatterns.cpp | 81 +- .../SpatialToPim/ReturnPathNormalization.cpp | 35 +- .../MaterializeMergeSchedule.cpp | 826 +++++++++++++++--- 8 files changed, 1376 insertions(+), 492 deletions(-) diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index 6e6add0..21d93ec 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -34,12 +34,25 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg llvm::FailureOr compileIndexValueImpl(mlir::Value value); llvm::FailureOr compileContiguousAddressExprImpl(mlir::Value value); +mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge); template CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) { return CompiledIndexExpr(std::make_shared(std::forward(args)...)); } +static mlir::Value resolveForYieldedAliasToInit(mlir::scf::ForOp forOp, + mlir::Value yieldedValue, + const StaticValueKnowledge* knowledge) { + yieldedValue = resolveLoopCarriedAliasImpl(yieldedValue, knowledge); + if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { + if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 + && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) + return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); + } + return yieldedValue; +} + mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { value = resolveAlias(value, knowledge); @@ -60,15 +73,8 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow auto result = mlir::dyn_cast(value); if (result) { auto yieldOp = mlir::dyn_cast(forOp.getBody()->getTerminator()); - if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) { - mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); - if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { - if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 - && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) - return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); - } - return yieldedValue; - } + if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) + return resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge); } } @@ -515,16 +521,7 @@ llvm::FailureOr resolveContiguousAddressImpl(mlir::Va return mlir::failure(); auto yieldOp = mlir::cast(forOp.getBody()->getTerminator()); - mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); - if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { - if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 - && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { - value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); - continue; - } - } - - value = yieldedValue; + value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge); continue; } @@ -643,16 +640,7 @@ llvm::FailureOr compileContiguousAddressExprImpl(mlir::Valu return mlir::failure(); auto yieldOp = mlir::cast(forOp.getBody()->getTerminator()); - mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber()); - if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { - if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 - && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { - value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1]; - continue; - } - } - - value = yieldedValue; + value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), nullptr); continue; } @@ -862,7 +850,7 @@ llvm::FailureOr CompiledAddressExpr::evaluate(const S auto resolvedOffset = byteOffset.evaluate(knowledge); if (failed(resolvedOffset)) return mlir::failure(); - return ResolvedContiguousAddress {base, *resolvedOffset}; + return ResolvedContiguousAddress {resolveAlias(base, &knowledge), *resolvedOffset}; } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 23080a4..a76fd6d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1334,6 +1334,38 @@ static Value affineAddConst( return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor); } +static Value affineMulConst( + PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) { + if (factor == 1) + return value; + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor); +} + +static Value affineFloorDivConst( + PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) { + assert(divisor > 0 && "expected positive affine floordiv divisor"); + if (divisor == 1) + return value; + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor); +} + +static Value affineModConst( + PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) { + assert(modulus > 0 && "expected positive affine mod divisor"); + if (modulus == 1) + return getOrCreateIndexConstant(rewriter, constantAnchor, 0); + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor); +} + static Value createConvInputPatch(Value input, RankedTensorType patchType, Value batchIndex, @@ -2414,11 +2446,6 @@ static Value createIm2colRows(const ConvLoweringState& state, Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches); - Value cChunkStart = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkStart); - Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch); - Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); - Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight); - Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.strideWidth); auto im2colLoop = buildNormalizedScfFor( rewriter, @@ -2429,13 +2456,17 @@ static Value createIm2colRows(const ConvLoweringState& state, ValueRange {im2colInit}, [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value im2colAcc = iterArgs.front(); - Value globalPatchIndex = arith::AddIOp::create(rewriter, nestedLoc, patchIndex, cChunkStart); - Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch); - Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch); - Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); - Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); - Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight); - Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth); + Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp); + Value batchIndex = + affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp); + Value batchPatchIndex = + affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp); + Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp); + Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp); + Value inputHeightOffset = + affineMulConst(rewriter, nestedLoc, outHeightIndex, state.strideHeight, anchorOp); + Value inputWidthOffset = + affineMulConst(rewriter, nestedLoc, outWidthIndex, state.strideWidth, anchorOp); auto patchType = RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType); @@ -2844,11 +2875,9 @@ static FailureOr createConvOutputFromRowStripHwc(Value inputHwc, Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices); - Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim); Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn); - Value cPatchWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.wHeight * state.wWidth); - Value localHeightOffset = arith::MulIOp::create(rewriter, loc, args.lane, c1); + Value localHeightOffset = args.lane; Value packedRowInit = tensor::EmptyOp::create(rewriter, loc, ArrayRef {1, state.outWidth, state.numChannelsOut}, elementType); auto widthLoop = buildNormalizedScfFor( @@ -2859,7 +2888,7 @@ static FailureOr createConvOutputFromRowStripHwc(Value inputHwc, c1, ValueRange {packedRowInit}, [&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl& widthYielded) { - Value localWidthOffset = arith::MulIOp::create(rewriter, widthLoc, widthIndex, c1); + Value localWidthOffset = widthIndex; Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef {1, patchSize}, elementType); auto rowLoop = buildNormalizedScfFor( rewriter, @@ -2878,7 +2907,8 @@ static FailureOr createConvOutputFromRowStripHwc(Value inputHwc, rewriter, rowLoc, flatPatchType, channelPatch, SmallVector {{0, 1, 2}}); Value rowChunk = tensor::ExpandShapeOp::create( rewriter, rowLoc, rowChunkType, flatPatch, SmallVector {{0, 1}}); - Value flatOffset = arith::MulIOp::create(rewriter, rowLoc, channel, cPatchWidth); + Value flatOffset = affineMulConst( + rewriter, rowLoc, channel, state.wHeight * state.wWidth, anchorOp); SmallVector rowOffsets {rewriter.getIndexAttr(0), flatOffset}; SmallVector rowSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)}; @@ -2905,7 +2935,7 @@ static FailureOr createConvOutputFromRowStripHwc(Value inputHwc, c1, ValueRange {zeroRow}, [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl& reduceYielded) { - Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar); + Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp); SmallVector aOffsets {rewriter.getIndexAttr(0), kOffset}; SmallVector aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)}; Value aTile = tensor::ExtractSliceOp::create( diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 59c03c7..8f32848 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" @@ -10,6 +11,7 @@ #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" @@ -131,46 +133,92 @@ static FailureOr getDirectReturnOperandIndex(OpResult result) { return result.getUses().begin()->getOperandNumber(); } -struct BatchFragmentAssemblyPlan { - unsigned returnIndex = 0; - int64_t localSourceElementOffset = 0; - int64_t fragmentByteSize = 0; - SmallVector hostOffsetsByLane; -}; +static FailureOr> +collectFragmentAssemblyCopiesFromBlueprint(spatial::SpatBlueprintOp blueprint, + IRMapping& mapper, + int64_t lane, + unsigned hostTargetIndex, + Value fixedSource = {}) { + SmallVector copies; + auto resultType = dyn_cast(blueprint.getOutput().getType()); + if (!resultType || !resultType.hasStaticShape()) + return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results"); -static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef values, Location loc) { - assert(!values.empty() && "expected lane-indexed values"); - if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); })) - return getOrCreateIndexConstant(rewriter, anchor, values.front()); + std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); + std::optional> fragmentStridesAttr = blueprint.getFragmentStrides(); + if (!operandIndicesAttr || !fragmentStridesAttr) + return blueprint.emitOpError( + "fragment assembly lowering requires explicit operand indices and unit strides"); - if (values.size() >= 2) { - int64_t step = values[1] - values[0]; - bool arithmetic = llvm::all_of(llvm::seq(2, values.size()), [&](size_t index) { - return values[index] == values.front() + static_cast(index) * step; - }); - if (arithmetic) { - Value base = getOrCreateIndexConstant(rewriter, anchor, values.front()); - if (step == 0) - return base; - Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step); - Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult(); - return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult(); + ArrayRef operandIndices = *operandIndicesAttr; + std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); + if (!sourceOffsetsAttr) + return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets"); + ArrayRef sourceOffsets = *sourceOffsetsAttr; + ArrayRef flatOffsets = blueprint.getFragmentOffsets(); + ArrayRef flatSizes = blueprint.getFragmentSizes(); + ArrayRef flatStrides = *fragmentStridesAttr; + int64_t rank = resultType.getRank(); + + SmallVector fragmentOperands {blueprint.getInput()}; + llvm::append_range(fragmentOperands, blueprint.getFragments()); + if (failed(validateFragmentAssemblyMetadata(blueprint, + rank, + fragmentOperands.size(), + operandIndices, + sourceOffsets, + flatOffsets, + flatSizes, + flatStrides))) + return failure(); + + SmallVector hostStrides = computeRowMajorStrides(resultType.getShape()); + for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { + Value source = fixedSource ? fixedSource : mapper.lookupOrDefault(fragmentOperands[operandIndices[fragmentIndex]]); + auto sourceType = dyn_cast(source.getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands"); + + size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType()); + SmallVector fragmentOffsets; + SmallVector fragmentSizes; + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t flatIndex = fragmentIndex * rank + dim; + if (flatStrides[flatIndex] != 1) + return blueprint.emitOpError("fragment assembly lowering only supports unit strides"); + fragmentOffsets.push_back(flatOffsets[flatIndex]); + fragmentSizes.push_back(flatSizes[flatIndex]); } + + if (failed(forEachContiguousDestinationChunk( + resultType.getShape(), + fragmentOffsets, + fragmentSizes, + [&](ArrayRef chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult { + int64_t hostElementOffset = 0; + for (auto [dim, offset] : llvm::enumerate(chunkOffsets)) + hostElementOffset += offset * hostStrides[dim]; + + FragmentAssemblyCopy copy; + copy.source = source; + copy.sourceType = sourceType; + copy.hostTargetIndex = hostTargetIndex; + copy.lane = lane; + copy.sourceByteOffset = (sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast(elementSize); + copy.hostByteOffset = hostElementOffset * static_cast(elementSize); + copy.byteSize = chunkElements * static_cast(elementSize); + copies.push_back(copy); + return success(); + }))) + return failure(); } - Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front()); - for (auto [lane, value] : llvm::enumerate(values.drop_front())) { - Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast(lane + 1)); - Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue); - Value candidate = getOrCreateIndexConstant(rewriter, anchor, value); - selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected); - } - return selected; + return copies; } -static FailureOr> -analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) { - SmallVector plans; +static FailureOr> +collectTopLevelFragmentAssemblyCopies(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) { + SmallVector copies; if (!packedResultType.hasStaticShape() || laneCount == 0) return failure(); @@ -187,15 +235,14 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu std::optional mode = blueprint.getMode(); std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); - std::optional> stridesAttr = blueprint.getFragmentStrides(); - if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) + if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr) return failure(); if (!blueprint.getOutput().hasOneUse() || !isa(*blueprint.getOutput().getUsers().begin())) return failure(); - unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber(); auto hostResultType = dyn_cast(blueprint.getOutput().getType()); - if (!hostResultType || !hostResultType.hasStaticShape()) + std::optional> stridesAttr = blueprint.getFragmentStrides(); + if (!hostResultType || !hostResultType.hasStaticShape() || !stridesAttr) return failure(); ArrayRef operandIndices = *operandIndicesAttr; @@ -204,6 +251,7 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu ArrayRef flatSizes = blueprint.getFragmentSizes(); ArrayRef flatStrides = *stridesAttr; int64_t rank = hostResultType.getRank(); + unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber(); SmallVector fragmentOperands {blueprint.getInput()}; llvm::append_range(fragmentOperands, blueprint.getFragments()); if (failed(validateFragmentAssemblyMetadata(blueprint, @@ -215,16 +263,15 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu flatSizes, flatStrides))) return failure(); + SmallVector hostStrides = computeRowMajorStrides(hostResultType.getShape()); for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { if (operandIndices[fragmentIndex] != static_cast(use.getOperandNumber())) continue; int64_t sourceElementOffset = sourceOffsets[fragmentIndex]; int64_t lane = sourceElementOffset / payloadElementCount; - int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount; if (lane < 0 || lane >= static_cast(laneCount)) return failure(); - SmallVector fragmentOffsets; SmallVector fragmentSizes; for (int64_t dim = 0; dim < rank; ++dim) { @@ -236,44 +283,31 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu } if (failed(forEachContiguousDestinationChunk( - hostResultType.getShape(), - fragmentOffsets, - fragmentSizes, - [&](ArrayRef chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult { - int64_t hostElementOffset = 0; - SmallVector hostStrides = computeRowMajorStrides(hostResultType.getShape()); - for (auto [dim, offset] : llvm::enumerate(chunkOffsets)) - hostElementOffset += offset * hostStrides[dim]; - int64_t hostByteOffset = hostElementOffset * static_cast(elementSize); - int64_t fragmentByteSize = chunkElements * static_cast(elementSize); - int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset; + hostResultType.getShape(), + fragmentOffsets, + fragmentSizes, + [&](ArrayRef chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult { + int64_t hostElementOffset = 0; + for (auto [dim, offset] : llvm::enumerate(chunkOffsets)) + hostElementOffset += offset * hostStrides[dim]; - auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) { - return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset - && plan.fragmentByteSize == fragmentByteSize; - }); - if (planIt == plans.end()) { - BatchFragmentAssemblyPlan plan; - plan.returnIndex = returnIndex; - plan.localSourceElementOffset = chunkSourceOffset; - plan.fragmentByteSize = fragmentByteSize; - plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits::min()); - plan.hostOffsetsByLane[static_cast(lane)] = hostByteOffset; - plans.push_back(std::move(plan)); - return success(); - } - - planIt->hostOffsetsByLane[static_cast(lane)] = hostByteOffset; - return success(); - }))) + FragmentAssemblyCopy copy; + copy.source = result; + copy.sourceType = packedResultType; + copy.hostTargetIndex = returnIndex; + copy.lane = lane; + copy.sourceByteOffset = + ((sourceElementOffset % payloadElementCount) + relativeSourceOffset) * static_cast(elementSize); + copy.hostByteOffset = hostElementOffset * static_cast(elementSize); + copy.byteSize = chunkElements * static_cast(elementSize); + copies.push_back(copy); + return success(); + }))) return failure(); } } - for (const BatchFragmentAssemblyPlan& plan : plans) - if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits::min(); })) - return failure(); - return plans; + return copies; } static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) { @@ -284,22 +318,6 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult(); } -static SmallVector getStaticIndexAttrs(Builder& builder, ArrayRef values) { - SmallVector attrs; - attrs.reserve(values.size()); - for (int64_t value : values) - attrs.push_back(builder.getIndexAttr(value)); - return attrs; -} - -static SmallVector getUnitStrides(Builder& builder, int64_t rank) { - SmallVector strides; - strides.reserve(rank); - for (int64_t dim = 0; dim < rank; ++dim) - strides.push_back(builder.getIndexAttr(1)); - return strides; -} - static Value createHostTargetOffset(IRRewriter& rewriter, Location loc, ShapedType destinationType, @@ -351,123 +369,6 @@ static Value createHostTargetOffset(IRRewriter& rewriter, mapper); } -static SmallVector buildFragmentOffsets(IRRewriter& rewriter, - Location loc, - ArrayRef baseOffsets, - ArrayRef fragmentOffsets, - IRMapping& mapper) { - SmallVector combined; - combined.reserve(fragmentOffsets.size()); - for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) { - if (auto attr = dyn_cast(baseOffset)) { - int64_t base = cast(attr).getInt(); - combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim])); - continue; - } - - Value dynamicBase = mapper.lookupOrDefault(cast(baseOffset)); - if (fragmentOffsets[dim] == 0) { - combined.push_back(dynamicBase); - continue; - } - - Value staticOffset = - getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]); - combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult()); - } - return combined; -} - -static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, - spatial::SpatBlueprintOp blueprint, - Value hostTarget, - ArrayRef baseOffsets, - IRMapping& mapper) { - auto hostTargetType = dyn_cast(hostTarget.getType()); - auto resultType = dyn_cast(blueprint.getOutput().getType()); - if (!hostTargetType || !resultType || !resultType.hasStaticShape()) - return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results"); - - std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); - std::optional> fragmentStridesAttr = blueprint.getFragmentStrides(); - if (!operandIndicesAttr || !fragmentStridesAttr) - return blueprint.emitOpError( - "fragment assembly lowering requires explicit operand indices and unit strides"); - - ArrayRef operandIndices = *operandIndicesAttr; - std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); - if (!sourceOffsetsAttr) - return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets"); - ArrayRef sourceOffsets = *sourceOffsetsAttr; - ArrayRef flatOffsets = blueprint.getFragmentOffsets(); - ArrayRef flatSizes = blueprint.getFragmentSizes(); - ArrayRef flatStrides = *fragmentStridesAttr; - int64_t rank = resultType.getRank(); - - SmallVector fragmentOperands {blueprint.getInput()}; - llvm::append_range(fragmentOperands, blueprint.getFragments()); - if (failed(validateFragmentAssemblyMetadata(blueprint, - rank, - fragmentOperands.size(), - operandIndices, - sourceOffsets, - flatOffsets, - flatSizes, - flatStrides))) - return failure(); - - for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { - int64_t operandIndex = operandIndices[fragmentIndex]; - - SmallVector fragmentOffsets; - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t flatIndex = fragmentIndex * rank + dim; - if (flatStrides[flatIndex] != 1) - return blueprint.emitOpError("fragment assembly lowering only supports unit strides"); - fragmentOffsets.push_back(flatOffsets[flatIndex]); - } - - Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]); - auto sourceType = dyn_cast(source.getType()); - if (!sourceType || !sourceType.hasStaticShape()) - return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands"); - - SmallVector fragmentShape; - fragmentShape.reserve(rank); - for (int64_t dim = 0; dim < rank; ++dim) - fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]); - - Value fragment = source; - if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) { - FailureOr> extractOffsets = getStaticSliceOffsetsForElementOffset( - blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice"); - if (failed(extractOffsets)) - return failure(); - fragment = tensor::ExtractSliceOp::create(rewriter, - blueprint.getLoc(), - source, - getStaticIndexAttrs(rewriter, *extractOffsets), - getStaticIndexAttrs(rewriter, fragmentShape), - getUnitStrides(rewriter, rank)); - } - - hostTarget = tensor::InsertSliceOp::create(rewriter, - blueprint.getLoc(), - fragment, - hostTarget, - buildFragmentOffsets(rewriter, - blueprint.getLoc(), - baseOffsets, - fragmentOffsets, - mapper), - getStaticIndexAttrs(rewriter, fragmentShape), - getUnitStrides(rewriter, rank)) - .getResult(); - } - - return hostTarget; -} - } // namespace LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, @@ -505,10 +406,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds)); SmallVector returnOperandIndices; - SmallVector, 4> fragmentAssemblyPlansByResult; + SmallVector, 4> fragmentAssemblyRunsByResult; if (computeBatchOp.getNumResults() != 0) { returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits::max()); - fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults()); + fragmentAssemblyRunsByResult.resize(computeBatchOp.getNumResults()); for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) { if (result.use_empty()) continue; @@ -522,12 +423,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul if (!resultType || !resultType.hasStaticShape()) return computeBatchOp.emitOpError( "resultful compute_batch publication lowering requires static ranked tensor results"); - FailureOr> fragmentAssemblyPlans = - analyzeTopLevelFragmentAssemblyUses(cast(result), resultType, computeBatchOp.getLaneCount()); - if (failed(fragmentAssemblyPlans)) - return computeBatchOp.emitOpError( - "resultful compute_batch lowering currently requires each result to be used directly by func.return"); - fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end()); + FailureOr> fragmentAssemblyCopies = + collectTopLevelFragmentAssemblyCopies(cast(result), resultType, computeBatchOp.getLaneCount()); + if (failed(fragmentAssemblyCopies)) + return computeBatchOp.emitOpError("failed to collect top-level fragment assembly copies for compute_batch result"); + FailureOr> fragmentAssemblyRuns = + groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, computeBatchOp.getLaneCount()); + if (failed(fragmentAssemblyRuns)) + return computeBatchOp.emitOpError("failed to group top-level fragment assembly copies into regular runs"); + fragmentAssemblyRunsByResult[resultIndex].assign(fragmentAssemblyRuns->begin(), fragmentAssemblyRuns->end()); } } @@ -614,8 +518,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul if (resultIndex >= returnOperandIndices.size()) return insertSlice.emitOpError("result index out of range while lowering host batch output"); bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits::max(); - bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size() - && !fragmentAssemblyPlansByResult[resultIndex].empty(); + bool hasFragmentAssembly = resultIndex < fragmentAssemblyRunsByResult.size() + && !fragmentAssemblyRunsByResult[resultIndex].empty(); if (!hasDirectReturn && !hasFragmentAssembly) continue; @@ -626,27 +530,23 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul auto mappedSourceType = dyn_cast(mappedSource.getType()); if (!mappedSourceType || !mappedSourceType.hasStaticShape()) return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source"); - for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) { - Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc()); - auto sizeAttr = pim::getCheckedI32Attr( - rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size"); - if (failed(sizeAttr)) + DenseMap updatedOutputs; + for (const FragmentAssemblyCopyRun& run : fragmentAssemblyRunsByResult[resultIndex]) { + Value outputTensor = updatedOutputs.lookup(run.hostTargetIndex); + if (!outputTensor) + outputTensor = outputTensors[run.hostTargetIndex](rewriter, insertSlice.getLoc()); + FragmentAssemblyCopyRun mappedRun = run; + mappedRun.source = mappedSource; + FailureOr updated = + emitFragmentAssemblyCopyRuns(rewriter, + insertSlice.getLoc(), + ArrayRef {mappedRun}, + outputTensor, + coreBatchOp.getOperation(), + laneArg); + if (failed(updated)) return failure(); - Value hostTargetOffset = - createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc()); - Value deviceSourceOffset = getOrCreateIndexConstant( - rewriter, coreBatchOp.getOperation(), - plan.localSourceElementOffset * static_cast(getElementTypeSizeInBytes(mappedSourceType.getElementType()))); - outputTensor = - pim::PimMemCopyDevToHostOp::create(rewriter, - insertSlice.getLoc(), - outputTensor.getType(), - hostTargetOffset, - deviceSourceOffset, - outputTensor, - mappedSource, - *sizeAttr) - .getOutput(); + updatedOutputs[run.hostTargetIndex] = *updated; } continue; } @@ -657,11 +557,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul insertSlice.getSource().getDefiningOp()) { std::optional modeAttr = blueprint.getMode(); if (modeAttr && *modeAttr == "fragment_assembly") { - FailureOr updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter, - blueprint, - hostTarget, - insertSlice.getMixedOffsets(), - mapper); + FailureOr> fragmentAssemblyCopies = + collectFragmentAssemblyCopiesFromBlueprint(blueprint, mapper, /*lane=*/0, /*hostTargetIndex=*/0); + if (failed(fragmentAssemblyCopies)) + return failure(); + FailureOr> fragmentAssemblyRuns = + groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, /*laneCount=*/1); + if (failed(fragmentAssemblyRuns)) + return failure(); + SmallVector zeroOffsets(hostTargetType.getRank(), 0); + Value baseHostOffset = createHostTargetOffset(rewriter, + blueprint.getLoc(), + hostTargetType, + insertSlice.getMixedOffsets(), + zeroOffsets, + mapper); + FailureOr updatedHostTarget = emitFragmentAssemblyCopyRuns(rewriter, + blueprint.getLoc(), + *fragmentAssemblyRuns, + hostTarget, + coreBatchOp.getOperation(), + std::nullopt, + baseHostOffset); if (failed(updatedHostTarget)) return failure(); hostOutputTensors[resultIndex] = *updatedHostTarget; diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 38ee17c..f5859d2 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -1,12 +1,17 @@ #include "mlir/IR/ValueRange.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + #include "llvm/ADT/STLExtras.h" #include +#include #include "Common.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace llvm; using namespace mlir; @@ -186,4 +191,375 @@ forEachContiguousDestinationChunk(ArrayRef destShape, return visit(visit, 0); } +static mlir::Value +createSteppedOffset(OpBuilder& builder, Location loc, mlir::Value start, mlir::Value index, int64_t stepBytes) { + if (stepBytes == 0) + return start; + mlir::Value step = arith::ConstantIndexOp::create(builder, loc, stepBytes); + mlir::Value scaled = arith::MulIOp::create(builder, loc, index, step).getResult(); + return arith::AddIOp::create(builder, loc, start, scaled).getResult(); +} + +static mlir::Value createIndexedOffset(OpBuilder& builder, + Location loc, + mlir::Value indexArg, + ArrayRef values) { + assert(!values.empty() && "expected lane-indexed values"); + if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); })) + return arith::ConstantIndexOp::create(builder, loc, values.front()); + + if (values.size() >= 2) { + int64_t step = values[1] - values[0]; + bool arithmetic = llvm::all_of(llvm::seq(2, values.size()), [&](size_t index) { + return values[index] == values.front() + static_cast(index) * step; + }); + if (arithmetic) { + mlir::Value base = arith::ConstantIndexOp::create(builder, loc, values.front()); + mlir::Value stepValue = arith::ConstantIndexOp::create(builder, loc, step); + mlir::Value scaledIndex = arith::MulIOp::create(builder, loc, indexArg, stepValue).getResult(); + return arith::AddIOp::create(builder, loc, base, scaledIndex).getResult(); + } + } + + mlir::Value selected = arith::ConstantIndexOp::create(builder, loc, values.front()); + for (auto [lane, value] : llvm::enumerate(values.drop_front())) { + mlir::Value indexValue = arith::ConstantIndexOp::create(builder, loc, static_cast(lane + 1)); + mlir::Value cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, indexArg, indexValue); + mlir::Value candidate = arith::ConstantIndexOp::create(builder, loc, value); + selected = arith::SelectOp::create(builder, loc, cmp, candidate, selected); + } + return selected; +} + +struct FragmentAssemblyCopyRunFamily { + FragmentAssemblyCopyRun prototype; + SmallVector sourceRunStartDeltas; + SmallVector hostRunStartDeltas; +}; + +static bool computeUniformRunStartDelta(ArrayRef prototypeStarts, + ArrayRef runStarts, + int64_t& delta) { + if (prototypeStarts.size() != runStarts.size() || prototypeStarts.empty()) + return false; + + delta = runStarts.front() - prototypeStarts.front(); + return llvm::all_of(llvm::zip_equal(prototypeStarts, runStarts), [&](auto pair) { + auto [prototypeStart, runStart] = pair; + return runStart - prototypeStart == delta; + }); +} + +static bool canMergeFragmentAssemblyCopyRunIntoFamily(const FragmentAssemblyCopyRunFamily& family, + const FragmentAssemblyCopyRun& run, + int64_t& sourceRunStartDelta, + int64_t& hostRunStartDelta) { + const FragmentAssemblyCopyRun& prototype = family.prototype; + if (prototype.source != run.source || prototype.sourceType != run.sourceType + || prototype.hostTargetIndex != run.hostTargetIndex || prototype.count != run.count + || prototype.sourceStepBytes != run.sourceStepBytes || prototype.hostStepBytes != run.hostStepBytes + || prototype.byteSize != run.byteSize) + return false; + + if (!computeUniformRunStartDelta(prototype.sourceStartBytesByLane, run.sourceStartBytesByLane, sourceRunStartDelta)) + return false; + return computeUniformRunStartDelta(prototype.hostStartBytesByLane, run.hostStartBytesByLane, hostRunStartDelta); +} + +static SmallVector +groupFragmentAssemblyCopyRunFamilies(ArrayRef runs) { + auto compareRunStarts = [](ArrayRef lhs, ArrayRef rhs) { + return std::lexicographical_compare(lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); + }; + + SmallVector sortedRuns(runs.begin(), runs.end()); + llvm::sort(sortedRuns, [&](const FragmentAssemblyCopyRun& lhs, const FragmentAssemblyCopyRun& rhs) { + if (lhs.hostTargetIndex != rhs.hostTargetIndex) + return lhs.hostTargetIndex < rhs.hostTargetIndex; + if (lhs.source != rhs.source) + return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer(); + if (lhs.byteSize != rhs.byteSize) + return lhs.byteSize < rhs.byteSize; + if (lhs.count != rhs.count) + return lhs.count < rhs.count; + if (lhs.sourceStepBytes != rhs.sourceStepBytes) + return lhs.sourceStepBytes < rhs.sourceStepBytes; + if (lhs.hostStepBytes != rhs.hostStepBytes) + return lhs.hostStepBytes < rhs.hostStepBytes; + if (compareRunStarts(lhs.sourceStartBytesByLane, rhs.sourceStartBytesByLane)) + return true; + if (compareRunStarts(rhs.sourceStartBytesByLane, lhs.sourceStartBytesByLane)) + return false; + return compareRunStarts(lhs.hostStartBytesByLane, rhs.hostStartBytesByLane); + }); + + SmallVector families; + for (const FragmentAssemblyCopyRun& run : sortedRuns) { + int64_t sourceRunStartDelta = 0; + int64_t hostRunStartDelta = 0; + if (!families.empty() + && canMergeFragmentAssemblyCopyRunIntoFamily( + families.back(), run, sourceRunStartDelta, hostRunStartDelta)) { + families.back().sourceRunStartDeltas.push_back(sourceRunStartDelta); + families.back().hostRunStartDeltas.push_back(hostRunStartDelta); + continue; + } + + FragmentAssemblyCopyRunFamily family; + family.prototype = run; + family.sourceRunStartDeltas.push_back(0); + family.hostRunStartDeltas.push_back(0); + families.push_back(std::move(family)); + } + + return families; +} + +FailureOr> +groupFragmentAssemblyCopyRuns(ArrayRef copies, uint32_t laneCount) { + if (laneCount == 0) + return failure(); + + struct LaneLocalCopyRun { + FragmentAssemblyCopyRun run; + int64_t lane = 0; + }; + + SmallVector sortedCopies(copies.begin(), copies.end()); + llvm::sort(sortedCopies, [](const FragmentAssemblyCopy& lhs, const FragmentAssemblyCopy& rhs) { + if (lhs.hostTargetIndex != rhs.hostTargetIndex) + return lhs.hostTargetIndex < rhs.hostTargetIndex; + if (lhs.source != rhs.source) + return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer(); + if (lhs.lane != rhs.lane) + return lhs.lane < rhs.lane; + if (lhs.byteSize != rhs.byteSize) + return lhs.byteSize < rhs.byteSize; + if (lhs.sourceByteOffset != rhs.sourceByteOffset) + return lhs.sourceByteOffset < rhs.sourceByteOffset; + return lhs.hostByteOffset < rhs.hostByteOffset; + }); + + SmallVector laneRuns; + for (const FragmentAssemblyCopy& copy : sortedCopies) { + if (copy.lane < 0 || copy.lane >= static_cast(laneCount)) + return failure(); + + if (!laneRuns.empty()) { + LaneLocalCopyRun& laneRun = laneRuns.back(); + FragmentAssemblyCopyRun& run = laneRun.run; + if (run.source == copy.source && run.sourceType == copy.sourceType + && run.hostTargetIndex == copy.hostTargetIndex && laneRun.lane == copy.lane && run.byteSize == copy.byteSize + && run.sourceStartBytesByLane.size() == 1 && run.hostStartBytesByLane.size() == 1) { + int64_t previousSourceOffset = run.sourceStartBytesByLane.front() + (run.count - 1) * run.sourceStepBytes; + int64_t previousHostOffset = run.hostStartBytesByLane.front() + (run.count - 1) * run.hostStepBytes; + int64_t sourceDelta = copy.sourceByteOffset - previousSourceOffset; + int64_t hostDelta = copy.hostByteOffset - previousHostOffset; + if (run.count == 1) { + run.sourceStepBytes = sourceDelta; + run.hostStepBytes = hostDelta; + ++run.count; + continue; + } + if (run.sourceStepBytes == sourceDelta && run.hostStepBytes == hostDelta) { + ++run.count; + continue; + } + } + } + + LaneLocalCopyRun laneRun; + laneRun.run.source = copy.source; + laneRun.run.sourceType = copy.sourceType; + laneRun.run.hostTargetIndex = copy.hostTargetIndex; + laneRun.run.count = 1; + laneRun.run.byteSize = copy.byteSize; + laneRun.run.sourceStartBytesByLane.push_back(copy.sourceByteOffset); + laneRun.run.hostStartBytesByLane.push_back(copy.hostByteOffset); + laneRun.lane = copy.lane; + laneRuns.push_back(std::move(laneRun)); + } + + SmallVector mergedRuns; + for (const LaneLocalCopyRun& laneRun : laneRuns) { + size_t laneIndex = static_cast(laneRun.lane); + auto mergedIt = llvm::find_if(mergedRuns, [&](const FragmentAssemblyCopyRun& run) { + return run.source == laneRun.run.source && run.sourceType == laneRun.run.sourceType + && run.hostTargetIndex == laneRun.run.hostTargetIndex && run.count == laneRun.run.count + && run.byteSize == laneRun.run.byteSize && run.sourceStepBytes == laneRun.run.sourceStepBytes + && run.hostStepBytes == laneRun.run.hostStepBytes && laneIndex < run.sourceStartBytesByLane.size() + && run.sourceStartBytesByLane[laneIndex] == std::numeric_limits::min(); + }); + + if (mergedIt == mergedRuns.end()) { + FragmentAssemblyCopyRun merged = laneRun.run; + merged.sourceStartBytesByLane.assign(laneCount, std::numeric_limits::min()); + merged.hostStartBytesByLane.assign(laneCount, std::numeric_limits::min()); + merged.sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front(); + merged.hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front(); + mergedRuns.push_back(std::move(merged)); + continue; + } + + mergedIt->sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front(); + mergedIt->hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front(); + } + + for (const FragmentAssemblyCopyRun& run : mergedRuns) { + if (llvm::any_of(run.sourceStartBytesByLane, + [](int64_t value) { return value == std::numeric_limits::min(); })) + return failure(); + if (llvm::any_of(run.hostStartBytesByLane, + [](int64_t value) { return value == std::numeric_limits::min(); })) + return failure(); + } + + return mergedRuns; +} + +static FailureOr emitFragmentAssemblyCopyRun(OpBuilder& builder, + Location loc, + const FragmentAssemblyCopyRun& run, + mlir::Value hostTarget, + Operation* anchor, + std::optional laneArg, + mlir::Value baseHostOffset, + mlir::Value sourceRunStartDelta = {}, + mlir::Value hostRunStartDelta = {}) { + auto sizeAttr = pim::getCheckedI32Attr(builder, anchor, run.byteSize, "fragment assembly host copy byte size"); + if (failed(sizeAttr)) + return failure(); + + mlir::Value hostStart; + mlir::Value sourceStart; + if (laneArg) { + hostStart = createIndexedOffset(builder, loc, *laneArg, run.hostStartBytesByLane); + sourceStart = createIndexedOffset(builder, loc, *laneArg, run.sourceStartBytesByLane); + } else { + hostStart = arith::ConstantIndexOp::create(builder, loc, run.hostStartBytesByLane.front()); + sourceStart = arith::ConstantIndexOp::create(builder, loc, run.sourceStartBytesByLane.front()); + } + + if (hostRunStartDelta) + hostStart = arith::AddIOp::create(builder, loc, hostStart, hostRunStartDelta).getResult(); + if (sourceRunStartDelta) + sourceStart = arith::AddIOp::create(builder, loc, sourceStart, sourceRunStartDelta).getResult(); + if (baseHostOffset) + hostStart = arith::AddIOp::create(builder, loc, baseHostOffset, hostStart).getResult(); + + if (run.count == 1) { + return pim::PimMemCopyDevToHostOp::create(builder, + loc, + hostTarget.getType(), + hostStart, + sourceStart, + hostTarget, + run.source, + *sizeAttr) + .getOutput(); + } + + mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0); + mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, run.count); + mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1); + FailureOr loop = buildNormalizedScfFor( + builder, + loc, + lowerBound, + upperBound, + step, + ValueRange {hostTarget}, + [&](OpBuilder& loopBuilder, + Location bodyLoc, + mlir::Value flatIndex, + ValueRange iterArgs, + SmallVectorImpl& yielded) { + mlir::Value hostOffset = createSteppedOffset(loopBuilder, bodyLoc, hostStart, flatIndex, run.hostStepBytes); + mlir::Value sourceOffset = + createSteppedOffset(loopBuilder, bodyLoc, sourceStart, flatIndex, run.sourceStepBytes); + mlir::Value copied = + pim::PimMemCopyDevToHostOp::create(loopBuilder, + bodyLoc, + iterArgs.front().getType(), + hostOffset, + sourceOffset, + iterArgs.front(), + run.source, + *sizeAttr) + .getOutput(); + yielded.push_back(copied); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + +static FailureOr emitFragmentAssemblyCopyRunFamily(OpBuilder& builder, + Location loc, + const FragmentAssemblyCopyRunFamily& family, + mlir::Value hostTarget, + Operation* anchor, + std::optional laneArg, + mlir::Value baseHostOffset) { + if (family.sourceRunStartDeltas.size() == 1) + return emitFragmentAssemblyCopyRun( + builder, loc, family.prototype, hostTarget, anchor, laneArg, baseHostOffset); + + mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0); + mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, family.sourceRunStartDeltas.size()); + mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1); + FailureOr outerLoop = buildNormalizedScfFor( + builder, + loc, + lowerBound, + upperBound, + step, + ValueRange {hostTarget}, + [&](OpBuilder& loopBuilder, + Location bodyLoc, + mlir::Value runIndex, + ValueRange iterArgs, + SmallVectorImpl& yielded) { + mlir::Value sourceRunStartDelta = + createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.sourceRunStartDeltas); + mlir::Value hostRunStartDelta = + createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.hostRunStartDeltas); + FailureOr copied = emitFragmentAssemblyCopyRun(loopBuilder, + bodyLoc, + family.prototype, + iterArgs.front(), + anchor, + laneArg, + baseHostOffset, + sourceRunStartDelta, + hostRunStartDelta); + if (failed(copied)) + return failure(); + yielded.push_back(*copied); + return success(); + }); + if (failed(outerLoop)) + return failure(); + return outerLoop->results.front(); +} + +FailureOr emitFragmentAssemblyCopyRuns(IRRewriter& rewriter, + Location loc, + ArrayRef runs, + mlir::Value hostTarget, + Operation* anchor, + std::optional laneArg, + mlir::Value baseHostOffset) { + for (const FragmentAssemblyCopyRunFamily& family : groupFragmentAssemblyCopyRunFamilies(runs)) { + FailureOr updatedHostTarget = + emitFragmentAssemblyCopyRunFamily(rewriter, loc, family, hostTarget, anchor, laneArg, baseHostOffset); + if (failed(updatedHostTarget)) + return failure(); + hostTarget = *updatedHostTarget; + } + + return hostTarget; +} + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index cc4a1ef..ea3317b 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -1,8 +1,14 @@ #pragma once +#include + #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/STLFunctionalExtras.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Value.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Support/LogicalResult.h" @@ -59,6 +65,39 @@ forEachContiguousDestinationChunk(llvm::ArrayRef destShape, llvm::function_ref, int64_t, int64_t)> callback); +struct FragmentAssemblyCopy { + mlir::Value source; + mlir::RankedTensorType sourceType; + unsigned hostTargetIndex = 0; + int64_t lane = 0; + int64_t sourceByteOffset = 0; + int64_t hostByteOffset = 0; + int64_t byteSize = 0; +}; + +struct FragmentAssemblyCopyRun { + mlir::Value source; + mlir::RankedTensorType sourceType; + unsigned hostTargetIndex = 0; + int64_t count = 0; + int64_t sourceStepBytes = 0; + int64_t hostStepBytes = 0; + int64_t byteSize = 0; + mlir::SmallVector sourceStartBytesByLane; + mlir::SmallVector hostStartBytesByLane; +}; + +mlir::FailureOr> +groupFragmentAssemblyCopyRuns(llvm::ArrayRef copies, uint32_t laneCount = 1); + +mlir::FailureOr emitFragmentAssemblyCopyRuns(mlir::IRRewriter& rewriter, + mlir::Location loc, + llvm::ArrayRef runs, + mlir::Value hostTarget, + mlir::Operation* anchor, + std::optional laneArg = std::nullopt, + mlir::Value baseHostOffset = {}); + inline mlir::tensor::EmptyOp createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) { return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType()); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 25521af..e4300ea 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/IRMapping.h" @@ -8,6 +9,7 @@ #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" @@ -30,22 +32,9 @@ static bool isChannelUseChainOp(Operation* op) { pim::PimTransposeOp>(op); } -static Value createStaticHostTargetOffset(IRRewriter& rewriter, - Location loc, - ShapedType destinationType, - ArrayRef fragmentOffsets) { - int64_t elementBytes = static_cast(getElementTypeSizeInBytes(destinationType.getElementType())); - SmallVector strides = computeRowMajorStrides(destinationType.getShape()); - - int64_t byteOffset = 0; - for (auto [dim, offset] : llvm::enumerate(fragmentOffsets)) - byteOffset += offset * strides[dim] * elementBytes; - return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset); -} - static FailureOr lowerFragmentAssemblyBlueprint(IRRewriter& rewriter, - spatial::SpatBlueprintOp blueprint, - IRMapping& mapping) { + spatial::SpatBlueprintOp blueprint, + IRMapping& mapping) { auto resultType = dyn_cast(blueprint.getOutput().getType()); if (!resultType || !resultType.hasStaticShape()) return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result"); @@ -77,56 +66,54 @@ static FailureOr lowerFragmentAssemblyBlueprint(IRRewriter& rewriter, flatStrides))) return failure(); - Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType); + SmallVector hostStrides = computeRowMajorStrides(resultType.getShape()); + SmallVector copies; for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { int64_t operandIndex = operandIndices[fragmentIndex]; SmallVector fragmentOffsets; - int64_t fragmentElements = 1; + SmallVector fragmentSizes; for (int64_t dim = 0; dim < rank; ++dim) { int64_t flatIndex = fragmentIndex * rank + dim; if (flatStrides[flatIndex] != 1) return blueprint.emitOpError("fragment assembly lowering only supports unit strides"); fragmentOffsets.push_back(flatOffsets[flatIndex]); - fragmentElements *= flatSizes[flatIndex]; + fragmentSizes.push_back(flatSizes[flatIndex]); } Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]); - auto sourceType = dyn_cast(source.getType()); + auto sourceType = dyn_cast(source.getType()); if (!sourceType || !sourceType.hasStaticShape()) return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands"); + size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType()); + if (failed(forEachContiguousDestinationChunk( + resultType.getShape(), + fragmentOffsets, + fragmentSizes, + [&](ArrayRef chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult { + int64_t hostElementOffset = 0; + for (auto [dim, offset] : llvm::enumerate(chunkOffsets)) + hostElementOffset += offset * hostStrides[dim]; - int64_t fragmentBytes = - fragmentElements * static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); - auto sizeAttr = pim::getCheckedI32Attr(rewriter, - blueprint.getOperation(), - fragmentBytes, - "fragment assembly host copy size"); - if (failed(sizeAttr)) + FragmentAssemblyCopy copy; + copy.source = source; + copy.sourceType = sourceType; + copy.sourceByteOffset = + (sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast(elementSize); + copy.hostByteOffset = hostElementOffset * static_cast(elementSize); + copy.byteSize = chunkElements * static_cast(elementSize); + copies.push_back(copy); + return success(); + }))) return failure(); - - Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets); - auto deviceSourceOffsetBytes = pim::checkedMul(static_cast(sourceOffsets[fragmentIndex]), - static_cast(getElementTypeSizeInBytes(sourceType.getElementType())), - blueprint, - "fragment assembly device source offset"); - if (failed(deviceSourceOffsetBytes)) - return failure(); - Value deviceSourceOffset = getOrCreateIndexConstant(rewriter, - rewriter.getInsertionBlock()->getParentOp(), - static_cast(*deviceSourceOffsetBytes)); - currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter, - blueprint.getLoc(), - currentOutput.getType(), - hostTargetOffset, - deviceSourceOffset, - currentOutput, - source, - *sizeAttr) - .getOutput(); } - return currentOutput; + Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType); + FailureOr> runs = groupFragmentAssemblyCopyRuns(copies); + if (failed(runs)) + return failure(); + return emitFragmentAssemblyCopyRuns( + rewriter, blueprint.getLoc(), *runs, currentOutput, blueprint.getOperation()); } static void diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 35dda7a..718921a 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinOps.h" @@ -11,6 +12,7 @@ #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" @@ -638,6 +640,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low flatSizes, flatStrides))) return ReturnPathLoweringResult::Failure; + SmallVector copies; for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { if (operandIndices[fragmentIndex] != static_cast(operandNumber)) continue; @@ -675,29 +678,27 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low failedChunk = true; return failure(); } - auto sizeAttr = - pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size"); - if (failed(sizeAttr)) { - failedChunk = true; - return failure(); - } - - outputTensor = - pim::PimMemCopyDevToHostOp::create(rewriter, - blueprint.getLoc(), - outputTensor.getType(), - getOrCreateIndexConstant(rewriter, producerOp, *hostOffset), - getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset), - outputTensor, - storedValue, - *sizeAttr) - .getOutput(); + FragmentAssemblyCopy copy; + copy.source = storedValue; + copy.sourceType = sourceType; + copy.hostByteOffset = *hostOffset; + copy.sourceByteOffset = *sourceOffset; + copy.byteSize = *fragmentBytes; + copies.push_back(copy); return success(); }))) failedChunk = true; if (failedChunk) return ReturnPathLoweringResult::Failure; } + FailureOr> runs = groupFragmentAssemblyCopyRuns(copies); + if (failed(runs)) + return ReturnPathLoweringResult::Failure; + FailureOr updatedOutput = + emitFragmentAssemblyCopyRuns(rewriter, blueprint.getLoc(), *runs, outputTensor, producerOp); + if (failed(updatedOutput)) + return ReturnPathLoweringResult::Failure; + outputTensor = *updatedOutput; markOpToRemove(blueprint.getOperation()); } return ReturnPathLoweringResult::Handled; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 8fe4cd1..31219e5 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -228,6 +228,7 @@ struct IndexedBatchRunValue { ClassId targetClass = 0; Operation* sourceOp = nullptr; size_t resultIndex = 0; + Value packed; RankedTensorType fragmentType; SmallVector slots; MessageVector messages; @@ -290,10 +291,16 @@ struct ProjectedFragmentLayout { SmallVector loopTripCounts; }; +struct StaticProjectedLoopInfo { + BlockArgument iv; + int64_t lowerBound = 0; + int64_t step = 1; + int64_t tripCount = 1; +}; + struct ProjectedTransferDescriptor { ProjectedBatchInputKey inputKey; Operation* extractOp = nullptr; - ProjectedFragmentLayout layout; RankedTensorType payloadType; SmallVector, 16> fragmentOffsets; @@ -319,16 +326,44 @@ struct PendingProjectedHostOutputFragment { Location loc; }; -struct CloneIndexingContext { - std::optional runSlotIndex; - std::optional projectionSlotIndex; +enum class TensorDemandActionKind { + DestinationFanout, + SameClassIndexedFragment, + TerminalBlueprintPublication, + WholeTensorBarrier }; -struct StaticProjectedLoopInfo { - BlockArgument iv; - int64_t lowerBound = 0; - int64_t step = 1; - int64_t tripCount = 1; +enum class WholeTensorBarrierReason { + FunctionReturnWithoutBlueprint, + DenseLogicalConsumer +}; + +struct TensorDemandAction { + TensorDemandActionKind kind = TensorDemandActionKind::DestinationFanout; + std::optional destinationClass; + std::optional barrierReason; +}; + +struct RunOutputDemand { + size_t resultIndex = 0; + Value originalOutput; + RankedTensorType fragmentType; + SmallVector actions; +}; + +struct CompactRunPlan { + SmallVector outputs; +}; + +enum class BatchInputDemandKind { + LaneFragment, + ProjectedFragment, + WholeTensorBarrier +}; + +struct BatchInputDemand { + BatchInputDemandKind kind = BatchInputDemandKind::LaneFragment; + std::optional wholeTensorProducer; }; struct AffineProjectedInputSliceMatch { @@ -340,6 +375,11 @@ struct AffineProjectedInputSliceMatch { SmallVector loops; }; +struct CloneIndexingContext { + std::optional runSlotIndex; + std::optional projectionSlotIndex; +}; + struct MaterializerState; FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerState& state, MaterializedClass& sourceClass, @@ -367,6 +407,11 @@ FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState StringRef context, std::optional producer = std::nullopt, IRMapping* mapper = nullptr); +FailureOr materializeWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType, + Location loc); FailureOr localizeMaterializedClassOperand(MaterializerState& state, MaterializedClass& targetClass, Value value, @@ -389,6 +434,16 @@ bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consu std::optional getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex); +std::optional getProjectedWholeBatchReplacementProducer(MaterializerState& state, + SpatComputeBatch batch, + unsigned inputIndex); +std::optional getProjectedWholeBatchReplacementProducer(MaterializerState& state, + tensor::ExtractSliceOp extract); +FailureOr materializeProjectedWholeBatchExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract, + ProducerKey producer, + IRMapping* mapper = nullptr); class AvailableValueStore { public: @@ -822,6 +877,47 @@ bool canUseProjectedLaneInput(MaterializerState& state, consumerBatch, *match, laneProducer, logicalConsumer.laneStart); } +FailureOr classifyComputeBatchInputDemand(MaterializerState& state, + MaterializedClass& targetClass, + SpatComputeBatch consumerBatch, + unsigned inputIndex, + Value input, + ComputeInstance logicalConsumer) { + if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input)) { + if (canUseProjectedLaneInput(state, consumerBatch, inputIndex, input, logicalConsumer)) + return BatchInputDemand { + .kind = BatchInputDemandKind::ProjectedFragment, .wholeTensorProducer = std::nullopt}; + + if (getProjectedWholeBatchReplacementProducer(state, consumerBatch, inputIndex)) + return BatchInputDemand { + .kind = BatchInputDemandKind::ProjectedFragment, .wholeTensorProducer = std::nullopt}; + + auto inputArg = consumerBatch.getInputArgument(inputIndex); + if (!inputArg) + return consumerBatch.emitOpError("expected compute_batch input block argument while classifying input demand") + << " #" << inputIndex; + + bool hasUses = false; + for (OpOperand& use : inputArg->getUses()) { + hasUses = true; + if (!isa(use.getOwner())) + return BatchInputDemand { + .kind = BatchInputDemandKind::WholeTensorBarrier, .wholeTensorProducer = wholeBatchProducer}; + } + if (!hasUses) + return BatchInputDemand {.kind = BatchInputDemandKind::LaneFragment, .wholeTensorProducer = std::nullopt}; + + return targetClass.op->emitError("failed to classify compute_batch input demand") + << " reason=direct whole-batch input only has projected uses, but no projected fragment path was proven" + << " consumerOp='" << consumerBatch->getName() << "' inputIndex=" << inputIndex + << " producerOp='" << wholeBatchProducer->instance.op->getName() << "' resultIndex=" + << wholeBatchProducer->resultIndex << " sourceClass=" << targetClass.id + << " valueType=" << input.getType(); + } + + return BatchInputDemand {.kind = BatchInputDemandKind::LaneFragment, .wholeTensorProducer = std::nullopt}; +} + SmallVector collectProducerKeysForBatchInputDestinations(MaterializerState& state, SpatComputeBatch consumerBatch, unsigned inputIndex, @@ -1300,9 +1396,10 @@ FailureOr appendBatchPublicationResult(MaterializerState& state, state.rewriter.eraseOp(yieldOp); } - state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + state.rewriter.setInsertionPoint(inParallelOp); Value firstOffset = scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc); + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset); return result.getResultNumber(); } @@ -1340,14 +1437,29 @@ Operation* getEnclosingSpatialComputeLikeOp(Value value) { return nullptr; } -bool isTensorValueLocalToMaterializedClass(Value value, const MaterializedClass& targetClass) { - if (!isa(value.getType())) - return true; +static bool isValueDefinedInMaterializedClass(Value value, const MaterializedClass& targetClass) { if (isConstantLike(value)) return true; - Region& targetRegion = *targetClass.body->getParent(); - return isDefinedInsideRegion(value, targetRegion); + if (auto blockArg = dyn_cast(value)) { + for (Operation* current = blockArg.getOwner()->getParentOp(); current; current = current->getParentOp()) + if (current == targetClass.op) + return true; + return false; + } + + if (Operation* definingOp = value.getDefiningOp()) + for (Operation* current = definingOp; current; current = current->getParentOp()) + if (current == targetClass.op) + return true; + + return false; +} + +bool isTensorValueLocalToMaterializedClass(Value value, const MaterializedClass& targetClass) { + if (!isa(value.getType())) + return true; + return isValueDefinedInMaterializedClass(value, targetClass); } bool isTensorValueDefinedInDifferentMaterializedClass(Value value, const MaterializedClass& targetClass) { @@ -1391,11 +1503,7 @@ Block* getBlockByIndex(Region& region, unsigned blockIndex) { } static bool isValueLegalInMaterializedClassBody(Value value, const MaterializedClass& targetClass) { - if (isConstantLike(value)) - return true; - - Region& targetRegion = *targetClass.body->getParent(); - return isDefinedInsideRegion(value, targetRegion); + return isValueDefinedInMaterializedClass(value, targetClass); } std::string stringifyOperationForMaterializerDebug(Operation* op) { @@ -1785,6 +1893,25 @@ FailureOr rematerializeTensorValueInClass(MaterializerState& state, .getResult(); } + if (auto empty = value.getDefiningOp()) { + auto emptyType = dyn_cast(empty.getType()); + if (!emptyType) + return anchor->emitError("expected ranked tensor.empty while rematerializing tensor capture"); + + SmallVector dynamicSizes; + dynamicSizes.reserve(empty.getDynamicSizes().size()); + for (Value dynamicSize : empty.getDynamicSizes()) { + FailureOr localizedSize = + rematerializeIndexValueInClass(state, targetClass, dynamicSize, anchor->getLoc(), mapper); + if (failed(localizedSize)) + return failure(); + dynamicSizes.push_back(*localizedSize); + } + return tensor::EmptyOp::create( + state.rewriter, anchor->getLoc(), emptyType.getShape(), emptyType.getElementType(), dynamicSizes) + .getResult(); + } + return failure(); } @@ -1801,7 +1928,8 @@ FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState if (!isa(value.getType()) || isConstantLike(value) || isTensorValueLocalToMaterializedClass(value, targetClass)) return value; - if (value.getDefiningOp() || value.getDefiningOp()) { + if (value.getDefiningOp() || value.getDefiningOp() + || value.getDefiningOp()) { FailureOr rematerialized = rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); if (failed(rematerialized)) return failure(); @@ -2179,6 +2307,16 @@ Value getPackedSliceForRunIndex(MaterializerState& state, return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } +Value getPackedSliceForDynamicRunIndex(MaterializerState& state, + Operation* anchor, + Value packed, + RankedTensorType fragmentType, + Value index, + Location loc) { + Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); + return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); +} + FailureOr createReceiveConcatLoop(MaterializerState& state, MaterializedClass& targetClass, RankedTensorType concatType, @@ -3079,6 +3217,60 @@ getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, un return match; } +std::optional +getProjectedWholeBatchReplacementProducer(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex) { + std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex); + if (!match) + return std::nullopt; + + Value input = batch.getInputs()[inputIndex]; + std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input); + if (!wholeBatchProducer) + return std::nullopt; + + if (canUseProjectedLaneInput( + state, + batch, + inputIndex, + input, + ComputeInstance {batch.getOperation(), 0, 1})) { + return std::nullopt; + } + + auto producerBatch = dyn_cast_or_null(wholeBatchProducer->instance.op); + if (!producerBatch) + return std::nullopt; + + if (failed(getBatchResultProjectionInsert(producerBatch, wholeBatchProducer->resultIndex))) + return std::nullopt; + + return wholeBatchProducer; +} + +std::optional +getProjectedWholeBatchReplacementProducer(MaterializerState& state, tensor::ExtractSliceOp extract) { + auto sourceArg = dyn_cast(extract.getSource()); + if (!sourceArg) + return std::nullopt; + + auto batch = dyn_cast_or_null(sourceArg.getOwner()->getParentOp()); + if (!batch) + return std::nullopt; + + for (unsigned inputIndex = 0; inputIndex < batch.getInputs().size(); ++inputIndex) { + std::optional inputArg = batch.getInputArgument(inputIndex); + if (!inputArg || *inputArg != sourceArg) + continue; + + std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex); + if (!match || match->extract != extract) + return std::nullopt; + return getProjectedWholeBatchReplacementProducer(state, batch, inputIndex); + } + + return std::nullopt; +} + FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane); FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) { @@ -3511,10 +3703,10 @@ collectScalarTargetProjectedDescriptor(MaterializerState& state, if (!(combined->inputKey == descriptor.inputKey) || combined->extractOp != descriptor.extractOp || combined->layout.fragmentType != descriptor.layout.fragmentType || combined->layout.fragmentShape != descriptor.layout.fragmentShape + || combined->layout.fragmentsPerLogicalSlot != descriptor.layout.fragmentsPerLogicalSlot || combined->layout.loopLowerBounds != descriptor.layout.loopLowerBounds || combined->layout.loopSteps != descriptor.layout.loopSteps - || combined->layout.loopTripCounts != descriptor.layout.loopTripCounts - || combined->layout.fragmentsPerLogicalSlot != descriptor.layout.fragmentsPerLogicalSlot) + || combined->layout.loopTripCounts != descriptor.layout.loopTripCounts) return std::nullopt; combined->layout.payloadFragmentCount += descriptor.layout.payloadFragmentCount; @@ -4062,6 +4254,8 @@ FailureOr buildScalarSourceFanoutPlan(MaterializerState& }); if (groupIt == fanoutPlan.projectedSendGroups.end()) { ProjectedScalarSendGroup group; + group.descriptor.inputKey = projectedDescriptor.inputKey; + group.descriptor.extractOp = projectedDescriptor.extractOp; group.descriptor.layout = projectedDescriptor.layout; group.descriptor.payloadType = projectedDescriptor.payloadType; fanoutPlan.projectedSendGroups.push_back(std::move(group)); @@ -5698,13 +5892,23 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt return failure(); if (packedType != fragmentType) { - if (packedType.getRank() == 0 || packedType.getDimSize(0) % static_cast(keys.size()) != 0) + size_t keysPerPublishedPayload = keys.size(); + if (sourceClass.isBatch) { + if (sourceClass.cpus.empty() || keys.size() % sourceClass.cpus.size() != 0) + return sourceClass.op->emitError( + "projected packed host publication requires a stable per-lane key partition") + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size() + << " laneCount=" << sourceClass.cpus.size(); + keysPerPublishedPayload = keys.size() / sourceClass.cpus.size(); + } + + if (packedType.getRank() == 0 || packedType.getDimSize(0) % static_cast(keysPerPublishedPayload) != 0) return sourceClass.op->emitError( "projected packed host publication requires either direct fragment operands or evenly dim-0 packed fragments") << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); SmallVector packedFragmentShape(packedType.getShape()); - packedFragmentShape[0] /= static_cast(keys.size()); + packedFragmentShape[0] /= static_cast(keysPerPublishedPayload); if (packedFragmentShape != fragmentShape) return sourceClass.op->emitError( "projected packed host publication fragment shape does not match projected slice size") @@ -5714,11 +5918,16 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt int64_t payloadElementCount = packedType.getNumElements(); int64_t fragmentElementCount = fragmentType.getNumElements(); int64_t fragmentsPerPublishedPayload = payloadElementCount / fragmentElementCount; - if (fragmentsPerPublishedPayload <= 0 || static_cast(keys.size()) % fragmentsPerPublishedPayload != 0) + size_t keysPerPublishedPayload = keys.size(); + if (sourceClass.isBatch) + keysPerPublishedPayload /= sourceClass.cpus.size(); + if (fragmentsPerPublishedPayload <= 0 + || static_cast(keysPerPublishedPayload) % fragmentsPerPublishedPayload != 0) return sourceClass.op->emitOpError( "projected packed host publication requires a deterministic publication packing layout") << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); + DenseMap publishedFragmentOrdinals; for (auto [fragmentIndex, key] : llvm::enumerate(keys)) { if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != keys.front().resultIndex || key.instance.laneCount != 1) return sourceClass.op->emitError("projected packed host publication requires one-lane keys from one producer result"); @@ -5739,8 +5948,10 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt FailureOr publishedLaneIndex = getPublicationLaneForProducerKey(state, sourceClass, key); if (failed(publishedLaneIndex)) return failure(); + int64_t ordinalWithinPublishedPayload = + sourceClass.isBatch ? publishedFragmentOrdinals[*publishedLaneIndex]++ : static_cast(fragmentIndex); int64_t localFragmentOffsetWithinPublishedPayload = - (static_cast(fragmentIndex) % fragmentsPerPublishedPayload) * fragmentElementCount; + (ordinalWithinPublishedPayload % fragmentsPerPublishedPayload) * fragmentElementCount; state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { originalOutput, @@ -5880,7 +6091,8 @@ FailureOr resolveInputValue(MaterializerState& state, MaterializedClass& targetClass, Value input, const ComputeInstance& consumerInstance, - CloneIndexingContext indexing) { + CloneIndexingContext indexing, + bool allowWholeBatchFallback = true) { auto rejectNonLocalResolvedValue = [&](Value resolved) -> FailureOr { if (!isTensorValueDefinedInDifferentMaterializedClass(resolved, targetClass)) return resolved; @@ -5918,9 +6130,16 @@ FailureOr resolveInputValue(MaterializerState& state, if (!llvm::is_contained(slot.keys, *producer)) continue; - MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); - Value received = - appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); + Value received = Value(); + if (indexedRun->packed) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + received = getPackedSliceForRunIndex( + state, targetClass.op, indexedRun->packed, indexedRun->fragmentType, slotIndex, consumerInstance.op->getLoc()); + } else { + MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); + received = + appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); + } for (ProducerKey slotKey : slot.keys) state.availableValues.record(slotKey, targetClass.id, received); return rejectNonLocalResolvedValue(received); @@ -5928,6 +6147,12 @@ FailureOr resolveInputValue(MaterializerState& state, } if (isWholeBatchProducerKey(*producer)) { + if (!allowWholeBatchFallback) { + consumerInstance.op->emitError("failed to resolve compute_batch input without a compact fragment path") + << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; + return failure(); + } FailureOr wholeBatch = materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); if (failed(wholeBatch)) @@ -5966,10 +6191,11 @@ bool hasProjectedInputReplacement(MaterializerState& state, return false; auto replacementIt = state.projectedExtractReplacements.find(match->extract.getOperation()); - if (replacementIt == state.projectedExtractReplacements.end()) - return false; + if (replacementIt != state.projectedExtractReplacements.end() + && replacementIt->second.find(classId) != replacementIt->second.end()) + return true; - return replacementIt->second.find(classId) != replacementIt->second.end(); + return getProjectedWholeBatchReplacementProducer(state, batch, inputIndex).has_value(); } void mapWeights(MaterializerState& state, @@ -6043,19 +6269,25 @@ LogicalResult mapInputs(MaterializerState& state, if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) continue; + FailureOr demand = + classifyComputeBatchInputDemand(state, targetClass, batch, static_cast(index), input, instance); + if (failed(demand)) + return batch.emitOpError("failed to classify materialized compute_batch input") << " #" << index; + FailureOr mapped = failure(); - if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input); - wholeBatchProducer && !canUseProjectedLaneInput(state, batch, static_cast(index), input, instance)) { + if (demand->kind == BatchInputDemandKind::WholeTensorBarrier) { + assert(demand->wholeTensorProducer && "whole-tensor input demand must carry a producer"); mapped = materializeWholeBatchInput( - state, targetClass, *wholeBatchProducer, input.getType(), batch.getOperation()->getLoc()); + state, targetClass, *demand->wholeTensorProducer, input.getType(), batch.getOperation()->getLoc()); if (failed(mapped)) return batch.emitOpError("failed to materialize whole-batch compute_batch input") - << " #" << index << " from '" << wholeBatchProducer->instance.op->getName() - << "' laneStart=" << wholeBatchProducer->instance.laneStart - << " laneCount=" << wholeBatchProducer->instance.laneCount - << " resultIndex=" << wholeBatchProducer->resultIndex; + << " #" << index << " from '" << demand->wholeTensorProducer->instance.op->getName() + << "' laneStart=" << demand->wholeTensorProducer->instance.laneStart + << " laneCount=" << demand->wholeTensorProducer->instance.laneCount + << " resultIndex=" << demand->wholeTensorProducer->resultIndex; } else { - mapped = resolveInputValue(state, targetClass, input, instance, indexing); + mapped = resolveInputValue( + state, targetClass, input, instance, indexing, /*allowWholeBatchFallback=*/false); if (failed(mapped)) return batch.emitOpError("failed to resolve materialized compute_batch input"); } @@ -6161,6 +6393,59 @@ std::optional lookupProjectedExtractReplacement(Mat return classIt->second; } +FailureOr materializeProjectedWholeBatchExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract, + ProducerKey producer, + IRMapping* mapper) { + FailureOr fullSource = + materializeWholeBatchInput(state, targetClass, producer, extract.getSource().getType(), extract.getLoc()); + if (failed(fullSource)) + return failure(); + + auto remapFoldResult = [&](OpFoldResult value) -> FailureOr { + if (auto mappedValue = dyn_cast_if_present(value)) { + FailureOr localized = + rematerializeIndexValueInClass(state, targetClass, mapper ? mapper->lookupOrDefault(mappedValue) : mappedValue, + extract.getLoc(), mapper); + if (failed(localized)) + return failure(); + return OpFoldResult(*localized); + } + return value; + }; + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(extract.getMixedOffsets().size()); + sizes.reserve(extract.getMixedSizes().size()); + strides.reserve(extract.getMixedStrides().size()); + + for (OpFoldResult value : extract.getMixedOffsets()) { + FailureOr localized = remapFoldResult(value); + if (failed(localized)) + return failure(); + offsets.push_back(*localized); + } + for (OpFoldResult value : extract.getMixedSizes()) { + FailureOr localized = remapFoldResult(value); + if (failed(localized)) + return failure(); + sizes.push_back(*localized); + } + for (OpFoldResult value : extract.getMixedStrides()) { + FailureOr localized = remapFoldResult(value); + if (failed(localized)) + return failure(); + strides.push_back(*localized); + } + + auto resultType = cast(extract.getType()); + return extractStaticSliceOrIdentity( + state.rewriter, extract.getLoc(), *fullSource, resultType, offsets, sizes, strides); +} + LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state, MaterializedClass& targetClass, Operation& originalOp, @@ -6184,6 +6469,23 @@ LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& sta state.rewriter.eraseOp(clonedExtract); return success(); } + + if (std::optional producer = + getProjectedWholeBatchReplacementProducer(state, originalExtract)) { + auto clonedExtract = dyn_cast(&clonedOp); + if (!clonedExtract) + return targetClass.op->emitError("projected whole-batch replacement lost extract structure during cloning"); + + state.rewriter.setInsertionPoint(clonedExtract); + FailureOr projected = materializeProjectedWholeBatchExtractReplacement( + state, targetClass, clonedExtract, *producer, &mapper); + if (failed(projected)) + return failure(); + + clonedExtract.getResult().replaceAllUsesWith(*projected); + state.rewriter.eraseOp(clonedExtract); + return success(); + } } if (originalOp.getNumRegions() != clonedOp.getNumRegions()) @@ -6264,6 +6566,16 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state, mapper.map(extract.getResult(), *projected); continue; } + + if (std::optional producer = getProjectedWholeBatchReplacementProducer(state, extract)) { + FailureOr projected = + materializeProjectedWholeBatchExtractReplacement(state, targetClass, extract, *producer, &mapper); + if (failed(projected)) + return failure(); + + mapper.map(extract.getResult(), *projected); + continue; + } } for (Value operand : op.getOperands()) { @@ -6418,7 +6730,9 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d1 = getAffineDimExpr(1, context); AffineMap packedIndexMap = AffineMap::get( - /*dimCount=*/2, /*symbolCount=*/0, d0 * replacement.layout.fragmentsPerLogicalSlot + d1); + /*dimCount=*/2, + /*symbolCount=*/0, + d0 * replacement.layout.fragmentsPerLogicalSlot + d1); return createOrFoldAffineApply(state.rewriter, extract.getLoc(), packedIndexMap, @@ -6445,6 +6759,8 @@ FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, Location loc) { if (!targetClass.isBatch) return targetClass.op->emitError("indexed batch run receive requires a batch target class"); + if (run.packed) + return getPackedSliceForDynamicRunIndex(state, targetClass.op, run.packed, run.fragmentType, runSlotIndex, loc); if (failed(run.messages.verify(targetClass.op))) return failure(); @@ -7077,6 +7393,129 @@ bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, return false; } +bool hasMaterializationRunResultSameClassConsumer(MaterializerState& state, + ClassId classId, + ArrayRef run, + size_t resultIndex) { + for (const MaterializationRunSlot& slot : run) + for (const ComputeInstance& peer : slot.peers) + if (hasSameClassConsumer(state, {peer, resultIndex}, classId)) + return true; + return false; +} + +StringRef describeWholeTensorBarrierReason(WholeTensorBarrierReason reason) { + switch (reason) { + case WholeTensorBarrierReason::FunctionReturnWithoutBlueprint: + return "function return or external use without spat.blueprint assembly"; + case WholeTensorBarrierReason::DenseLogicalConsumer: + return "consumer requires a dense logical tensor"; + } + llvm_unreachable("unknown whole-tensor barrier reason"); +} + +FailureOr classifyRunOutputDemand(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + ArrayRef destinationClasses, + size_t resultIndex) { + auto sourceBatch = dyn_cast(getMaterializationRunSourceOp(run)); + if (!sourceBatch) + return targetClass.op->emitError("compact batch demand classification expects a compute_batch source"); + + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); + ArrayRef firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); + if (resultIndex >= firstOriginalOutputs.size() || resultIndex >= fragmentTypes.size()) + return targetClass.op->emitError("compact batch demand classification found an invalid output index") + << " resultIndex=" << resultIndex; + + auto fragmentType = dyn_cast(fragmentTypes[resultIndex]); + if (!fragmentType || !fragmentType.hasStaticShape() || fragmentType.getRank() == 0) + return targetClass.op->emitError("compact batch demand classification requires static ranked fragment metadata") + << " resultIndex=" << resultIndex << " fragmentType=" << fragmentTypes[resultIndex]; + + RunOutputDemand demand; + demand.resultIndex = resultIndex; + demand.originalOutput = firstOriginalOutputs[resultIndex]; + demand.fragmentType = fragmentType; + + for (ClassId destinationClass : destinationClasses) + demand.actions.push_back(TensorDemandAction { + .kind = TensorDemandActionKind::DestinationFanout, + .destinationClass = destinationClass, + .barrierReason = std::nullopt}); + + if (hasMaterializationRunResultSameClassConsumer(state, targetClass.id, run, resultIndex)) + demand.actions.push_back(TensorDemandAction { + .kind = TensorDemandActionKind::SameClassIndexedFragment, + .destinationClass = std::nullopt, + .barrierReason = std::nullopt}); + + if (!hasMaterializationRunResultLiveExternalUse(state, run, resultIndex)) + return demand; + + Value originalOutput = demand.originalOutput; + if (!isTerminalHostBatchOutput(originalOutput, state.oldComputeOps)) { + demand.actions.push_back(TensorDemandAction { + .kind = TensorDemandActionKind::WholeTensorBarrier, + .destinationClass = std::nullopt, + .barrierReason = WholeTensorBarrierReason::FunctionReturnWithoutBlueprint}); + return demand; + } + + auto outputType = dyn_cast(originalOutput.getType()); + if (!outputType || !outputType.hasStaticShape()) + return targetClass.op->emitError("failed to classify compact batch output demand") + << " reason=terminal blueprint publication requires static ranked output metadata" + << " producerOp='" << sourceBatch->getName() << "' resultIndex=" << resultIndex + << " sourceClass=" << targetClass.id << " valueType=" << originalOutput.getType(); + + if (failed(getBatchResultProjectionInsert(sourceBatch, resultIndex))) + return targetClass.op->emitError("failed to classify compact batch output demand") + << " reason=terminal blueprint publication is missing projection metadata" + << " producerOp='" << sourceBatch->getName() << "' resultIndex=" << resultIndex + << " sourceClass=" << targetClass.id << " valueType=" << originalOutput.getType(); + + demand.actions.push_back(TensorDemandAction { + .kind = TensorDemandActionKind::TerminalBlueprintPublication, + .destinationClass = std::nullopt, + .barrierReason = std::nullopt}); + return demand; +} + +bool hasWholeTensorBarrier(const RunOutputDemand& demand) { + return llvm::any_of(demand.actions, [](const TensorDemandAction& action) { + return action.kind == TensorDemandActionKind::WholeTensorBarrier; + }); +} + +FailureOr> +tryBuildCompactRunPlan(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + ArrayRef groups) { + if (run.size() < 2 || run.front().peers.empty()) + return std::optional {}; + + CompactRunPlan plan; + for (const OutputDestinationGroup& group : groups) { + for (size_t resultIndex : group.resultIndices) { + FailureOr demand = + classifyRunOutputDemand(state, targetClass, run, group.destinationClasses, resultIndex); + if (failed(demand)) + return failure(); + if (hasWholeTensorBarrier(*demand)) + return std::optional {}; + if (!demand->actions.empty()) + plan.outputs.push_back(std::move(*demand)); + } + } + + if (plan.outputs.empty()) + return std::optional {}; + return std::optional {std::move(plan)}; +} + void markMaterializationRunSlots(MaterializerState& state, ClassId classId, SlotId startSlot, @@ -7190,41 +7629,6 @@ bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, Cla return false; } -bool canCompactBatchClassRun(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run) { - if (run.size() < 2) - return false; - if (run.front().peers.empty()) - return false; - - ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); - - for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { - (void) ignored; - for (const MaterializationRunSlot& slot : run) { - if (slot.peers.empty()) - return false; - - for (const ComputeInstance& peer : slot.peers) { - ArrayRef peerOutputs = getComputeInstanceOutputValuesCached(state, peer); - if (resultIndex >= peerOutputs.size()) - return false; - - Value originalOutput = peerOutputs[resultIndex]; - if (hasLiveExternalUseCached(state, originalOutput)) - return false; - - ProducerKey key {peer, resultIndex}; - if (hasSameClassConsumer(state, key, targetClass.id)) - return false; - } - } - } - - return true; -} - Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { auto batch = cast(targetClass.op); auto laneArg = batch.getLaneArgument(); @@ -7267,15 +7671,16 @@ Value createBatchClassRunSourceLane(MaterializerState& state, LogicalResult buildBatchRunSendPlans(MaterializerState& state, MaterializedClass& sourceClass, ArrayRef run, - const OutputDestinationGroup& group, + const CompactRunPlan& compactPlan, SmallVectorImpl& plans) { assert(sourceClass.isBatch && "batch run send planning expects a materialized batch source"); - for (size_t resultIndex : group.resultIndices) { - for (ClassId destinationClass : group.destinationClasses) { - if (destinationClass == sourceClass.id) - return sourceClass.op->emitError("batch-target run compaction cannot handle same-class consumers"); - + for (const RunOutputDemand& output : compactPlan.outputs) { + for (const TensorDemandAction& action : output.actions) { + if (action.kind != TensorDemandActionKind::DestinationFanout) + continue; + assert(action.destinationClass && "destination fanout action must carry a destination class"); + ClassId destinationClass = *action.destinationClass; MaterializedClass& targetClass = state.classes[destinationClass]; if (targetClass.isBatch && targetClass.cpus.size() != sourceClass.cpus.size()) @@ -7283,7 +7688,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state, "cannot compact batch run communication between batch classes of different sizes"); BatchRunSendPlan plan; - plan.resultIndex = resultIndex; + plan.resultIndex = output.resultIndex; plan.destinationClass = destinationClass; size_t messageCount = run.size() * sourceClass.cpus.size(); @@ -7401,6 +7806,29 @@ LogicalResult recordIndexedBatchRunReceives(MaterializerState& state, return success(); } +LogicalResult recordLocalIndexedBatchRunValue(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + size_t resultIndex, + Value packed, + RankedTensorType fragmentType) { + IndexedBatchRunValue indexedRun; + indexedRun.targetClass = targetClass.id; + indexedRun.sourceOp = run.front().peers.front().op; + indexedRun.resultIndex = resultIndex; + indexedRun.packed = packed; + indexedRun.fragmentType = fragmentType; + indexedRun.slots.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + PackedScalarRunSlot indexedSlot; + indexedSlot.keys = getMaterializationRunSlotOutputKeys(slot, resultIndex); + indexedRun.slots.push_back(std::move(indexedSlot)); + } + + state.availableValues.recordIndexedBatchRun(std::move(indexedRun)); + return success(); +} + LogicalResult appendBatchRunReceives(MaterializerState& state, MaterializedClass& sourceClass, ArrayRef run, @@ -7414,77 +7842,188 @@ LogicalResult appendBatchRunReceives(MaterializerState& state, return recordIndexedBatchRunReceives(state, run, plan, fragmentType); } +FailureOr> materializeCompactBatchOutputGroupLoop(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + const CompactRunPlan& plan) { + assert(targetClass.isBatch && "compact batch output loop expects a batch target class"); + assert(!run.empty() && "expected non-empty compact batch run"); + assert(!run.front().peers.empty() && "expected non-empty compact batch run slot"); + + auto sourceBatch = dyn_cast(getMaterializationRunSourceOp(run)); + if (!sourceBatch) + return failure(); + + Location loc = getMaterializationRunLoc(run); + + SmallVector initValues; + initValues.reserve(plan.outputs.size()); + for (const RunOutputDemand& output : plan.outputs) { + FailureOr packedType = getPackedRunTensorType(output.fragmentType, run.size()); + if (failed(packedType)) + return sourceBatch.emitOpError("cannot compact batch run for non-static ranked output") + << " resultIndex=" << output.resultIndex; + initValues.push_back( + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange(initValues), + [&](OpBuilder&, Location, Value slotIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); + SmallVector resultIndices; + resultIndices.reserve(plan.outputs.size()); + for (const RunOutputDemand& output : plan.outputs) + resultIndices.push_back(output.resultIndex); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + getScheduledChunkForLogicalInstance(state, run.front().peers.front()), + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); + if (failed(produced)) + return failure(); + + yielded.reserve(produced->size()); + for (auto [outputIndex, output] : llvm::enumerate(*produced)) { + auto fragmentType = dyn_cast(output.getType()); + if (!fragmentType || !fragmentType.hasStaticShape()) + return failure(); + Value firstOffset = scaleIndexByDim0Size( + state, targetClass.op, slotIndex, fragmentType.getDimSize(0), loc); + yielded.push_back(createDim0InsertSlice(state, loc, output, iterArgs[outputIndex], firstOffset)); + } + return success(); + }); + if (failed(loop)) + return failure(); + + SmallVector results; + results.reserve(loop->results.size()); + for (Value result : loop->results) + results.push_back(result); + return results; +} + +LogicalResult emitPackedBatchRunSends(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + const CompactRunPlan& plan, + ArrayRef packedOutputs, + Location loc) { + SmallVector sendPlans; + if (failed(buildBatchRunSendPlans(state, targetClass, run, plan, sendPlans))) + return failure(); + if (sendPlans.empty()) + return success(); + + DenseMap packedOutputIndexByResult; + for (auto [packedIndex, output] : llvm::enumerate(plan.outputs)) + packedOutputIndexByResult[output.resultIndex] = packedIndex; + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + for (const BatchRunSendPlan& sendPlan : sendPlans) { + auto packedIt = packedOutputIndexByResult.find(sendPlan.resultIndex); + if (packedIt == packedOutputIndexByResult.end()) + return failure(); + + const RunOutputDemand& output = plan.outputs[packedIt->second]; + Value fragment = getPackedSliceForDynamicRunIndex( + state, targetClass.op, packedOutputs[packedIt->second], output.fragmentType, slotIndex, loc); + appendBatchRunSend(state, targetClass, fragment, sendPlan, flatIndex, loc); + } + return success(); + }); + if (failed(loop)) + return failure(); + + for (const BatchRunSendPlan& sendPlan : sendPlans) { + auto packedIt = packedOutputIndexByResult.find(sendPlan.resultIndex); + if (packedIt == packedOutputIndexByResult.end()) + return targetClass.op->emitError("missing packed output for compact batch run send plan"); + if (failed(appendBatchRunReceives( + state, targetClass, run, sendPlan, plan.outputs[packedIt->second].fragmentType, loc))) + return failure(); + } + + return success(); +} + LogicalResult materializeBatchClassRun(MaterializerState& state, MaterializedClass& targetClass, SlotId startSlot, - ArrayRef run) { + ArrayRef run, + const CompactRunPlan& plan) { assert(targetClass.isBatch && "batch-target run materialization expects a materialized batch class"); assert(!run.empty() && "expected non-empty batch-target run"); - if (!canCompactBatchClassRun(state, targetClass, run)) - return failure(); - markMaterializationRunSlots(state, targetClass.id, startSlot, run); - SmallVector groups = groupBatchRunOutputsByDestination(state, run); - auto sourceBatch = cast(run.front().peers.front().op); - SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); Location loc = sourceBatch.getLoc(); - for (const OutputDestinationGroup& group : groups) { - SmallVector sendPlans; - if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) - return failure(); + FailureOr> packedOutputs = materializeCompactBatchOutputGroupLoop(state, targetClass, run, plan); + if (failed(packedOutputs)) + return failure(); - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + for (auto [packedIndex, output] : llvm::enumerate(plan.outputs)) { + SmallVector keys = getMaterializationRunOutputKeys(run, output.resultIndex); - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { - Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - getScheduledChunkForLogicalInstance(state, run.front().peers.front()), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); - if (failed(produced)) + for (const TensorDemandAction& action : output.actions) { + switch (action.kind) { + case TensorDemandActionKind::DestinationFanout: + break; + case TensorDemandActionKind::SameClassIndexedFragment: + if (failed(recordLocalIndexedBatchRunValue( + state, targetClass, run, output.resultIndex, (*packedOutputs)[packedIndex], output.fragmentType))) return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); - if (resultIt == group.resultIndices.end()) - return failure(); - - size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); - appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); - } - return success(); - }); - if (failed(loop)) - return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) - return failure(); - - if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) - return failure(); + break; + case TensorDemandActionKind::TerminalBlueprintPublication: { + FailureOr recordedProjectedHostFragments = recordProjectedScalarHostFragmentsFromPackedValue( + state, targetClass, keys, (*packedOutputs)[packedIndex], output.originalOutput, loc); + if (failed(recordedProjectedHostFragments)) + return failure(); + if (!*recordedProjectedHostFragments) + return sourceBatch.emitOpError("compact batch blueprint publication requires explicit fragment assembly metadata") + << " resultIndex=" << output.resultIndex; + break; + } + case TensorDemandActionKind::WholeTensorBarrier: + return sourceBatch.emitOpError("compact batch materialization reached a whole-tensor barrier unexpectedly") + << " resultIndex=" << output.resultIndex << " reason=" + << describeWholeTensorBarrierReason(*action.barrierReason); + } } } + if (failed(emitPackedBatchRunSends(state, targetClass, run, plan, *packedOutputs, loc))) + return failure(); + return success(); } @@ -7516,8 +8055,15 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, if (!targetClass.isBatch) return materializeScalarBatchRun(state, targetClass, startLogicalSlot, *run); - if (succeeded(materializeBatchClassRun(state, targetClass, startLogicalSlot, *run))) + SmallVector groups = groupBatchRunOutputsByDestination(state, *run); + FailureOr> plan = tryBuildCompactRunPlan(state, targetClass, *run, groups); + if (failed(plan)) + return failure(); + if (*plan) { + if (failed(materializeBatchClassRun(state, targetClass, startLogicalSlot, *run, **plan))) + return failure(); return success(); + } } }