From 78e97f9fd860d9b53830126c330fb2edf19fdd2e Mon Sep 17 00:00:00 2001 From: ilgeco Date: Fri, 26 Jun 2026 17:45:27 +0200 Subject: [PATCH] Bose --- .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 32 + .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 8 + .../ONNXToSpatial/LowerSpatialPlansPass.cpp | 66 +- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 12 +- .../ONNXToSpatial/ONNXToSpatialVerifier.cpp | 18 +- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 13 +- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 5 +- .../SpatialLayoutPlanningPass.cpp | 12 +- .../BatchCoreLoweringPatterns.cpp | 84 +- src/PIM/Conversion/SpatialToPim/Common.cpp | 14 +- src/PIM/Conversion/SpatialToPim/Common.hpp | 4 +- .../SpatialToPim/CoreLoweringPatterns.cpp | 50 +- src/PIM/Conversion/SpatialToPim/Patterns.cpp | 8 +- .../SpatialToPim/ReturnPathNormalization.cpp | 72 +- .../SpatialToPim/SpatialToPimPass.cpp | 2 +- src/PIM/Dialect/Spatial/Spatial.td | 5 +- src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 133 + src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 30 +- .../MaterializeMergeSchedule.cpp | 242 +- .../MaterializeMergeSchedule.cpp.bkk | 9510 ----------------- .../MaterializeMergeSchedule.cpp.orig | 7548 ------------- .../MaterializeMergeSchedule.cpp.rej | 128 - .../MergeComputeNodesPass.cpp | 6 +- 23 files changed, 513 insertions(+), 17489 deletions(-) delete mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.bkk delete mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.orig delete mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.rej diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index d97da5a..d4d8987 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -163,6 +163,38 @@ Value extractAxisSlice( .getResult(); } +Value extractStaticSliceOrIdentity(RewriterBase& rewriter, + Location loc, + Value source, + RankedTensorType resultType, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto sourceType = cast(source.getType()); + size_t rank = static_cast(sourceType.getRank()); + + bool isIdentitySlice = + sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank + && strides.size() == rank; + if (isIdentitySlice) { + ArrayRef sourceShape = sourceType.getShape(); + for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) { + std::optional staticOffset = mlir::getConstantIntValue(offset); + std::optional staticSize = mlir::getConstantIntValue(size); + std::optional staticStride = mlir::getConstantIntValue(stride); + if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) { + isIdentitySlice = false; + break; + } + } + } + + if (isIdentitySlice) + return source; + + return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult(); +} + Value insertStaticSlice( PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef offsets) { auto sourceType = cast(source.getType()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index b803969..28cc6d0 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -105,6 +105,14 @@ llvm::DenseMap> sliceVectorPerCrossbarPer mlir::Value extractAxisSlice( mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size); +mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value source, + mlir::RankedTensorType resultType, + llvm::ArrayRef offsets, + llvm::ArrayRef sizes, + llvm::ArrayRef strides); + mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, diff --git a/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp b/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp index be74eb2..e933cb9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp @@ -44,17 +44,17 @@ static FailureOr getRowStripValue(llvm::DenseMapsecond; } -static FailureOr buildRowStripValue(spatial::SpatReconciliatorOp reconciliator, +static FailureOr buildRowStripValue(spatial::SpatBlueprintOp blueprint, Value physicalValue) { - auto logicalType = dyn_cast(reconciliator.getOutput().getType()); + auto logicalType = dyn_cast(blueprint.getOutput().getType()); if (!logicalType) - return reconciliator.emitOpError("requires ranked logical output type"), failure(); + return blueprint.emitOpError("requires ranked logical output type"), failure(); RowStripPhysicalValue value; value.physicalValue = physicalValue; value.logicalType = logicalType; - value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end()); - value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end()); - value.indexMap = reconciliator.getIndexMap().str(); + value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end()); + value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end()); + value.indexMap = blueprint.getIndexMap().str(); return value; } @@ -175,7 +175,7 @@ struct LowerSpatialPlansPass final : PassWrapper bool { if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc))) return true; - moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage; + moduleOp.emitError() << "logical Spatial graph verification failed " << stage; signalPassFailure(); return false; }; @@ -185,11 +185,11 @@ struct LowerSpatialPlansPass final : PassWrapper(&op)) { FailureOr rowStripInput = getRowStripValue(rowStripValues, planOp.getInput()); - auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) { - auto reconciliator = dyn_cast(user); - return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout; + auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) { + auto blueprint = dyn_cast(user); + return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout; }); - if (rowStripReconciliator != planOp.getResult().getUsers().end()) { + if (rowStripBlueprint != planOp.getResult().getUsers().end()) { rewriter.setInsertionPoint(planOp); FailureOr lowered = lowerSelectedConv2DPlan( planOp, @@ -201,15 +201,15 @@ struct LowerSpatialPlansPass final : PassWrapper(*rowStripReconciliator); - FailureOr rowStripValue = buildRowStripValue(reconciliator, *lowered); + auto blueprint = cast(*rowStripBlueprint); + FailureOr rowStripValue = buildRowStripValue(blueprint, *lowered); if (failed(rowStripValue)) { signalPassFailure(); return; } - rowStripValues[reconciliator.getResult()] = *rowStripValue; + rowStripValues[blueprint.getResult()] = *rowStripValue; eraseAfterLowering.insert(planOp); - eraseAfterLowering.insert(reconciliator); + eraseAfterLowering.insert(blueprint); continue; } rewriter.setInsertionPoint(planOp); @@ -226,12 +226,12 @@ struct LowerSpatialPlansPass final : PassWrapper(&op)) { if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) { - auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) { - auto reconciliator = dyn_cast(user); - return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout; + auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) { + auto blueprint = dyn_cast(user); + return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout; }); - if (outputReconciliator == planOp.getResult().getUsers().end()) { - planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result"); + if (outputBlueprint == planOp.getResult().getUsers().end()) { + planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result"); signalPassFailure(); return; } @@ -244,15 +244,15 @@ struct LowerSpatialPlansPass final : PassWrapper(*outputReconciliator); - FailureOr output = buildRowStripValue(reconciliator, *lowered); + auto blueprint = cast(*outputBlueprint); + FailureOr output = buildRowStripValue(blueprint, *lowered); if (failed(output)) { signalPassFailure(); return; } - rowStripValues[reconciliator.getResult()] = *output; + rowStripValues[blueprint.getResult()] = *output; eraseAfterLowering.insert(planOp); - eraseAfterLowering.insert(reconciliator); + eraseAfterLowering.insert(blueprint); continue; } @@ -279,7 +279,7 @@ struct LowerSpatialPlansPass final : PassWrapper rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput()); if (failed(rowStripValue)) { - materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization"); + materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization"); signalPassFailure(); return; } @@ -293,18 +293,18 @@ struct LowerSpatialPlansPass final : PassWrapper(&op)) { - if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) { - rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput()); + if (auto blueprintOp = dyn_cast(&op)) { + if (blueprintOp.getPhysicalLayout() == kDenseLayout) { + rewriter.replaceOp(blueprintOp, blueprintOp.getInput()); continue; } - if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) { - reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet"); + if (blueprintOp.getPhysicalLayout() != kRowStripLayout) { + blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet"); signalPassFailure(); return; } - if (!eraseAfterLowering.contains(reconciliatorOp)) { - reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans"); + if (!eraseAfterLowering.contains(blueprintOp)) { + blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans"); signalPassFailure(); return; } @@ -385,7 +385,7 @@ struct LowerSpatialPlansPass final : PassWrapper(op) || op->getDialect()->getNamespace() == "onnx") { op->emitOpError("operation must not remain after LowerSpatialPlans"); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index de87200..c2ab151 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -46,9 +46,9 @@ static void populateEmptyFunction(func::FuncOp funcOp) { SmallVector computeBatches(funcOp.getOps()); SmallVector convPlans(funcOp.getOps()); SmallVector reluPlans(funcOp.getOps()); - SmallVector reconciliators(funcOp.getOps()); + SmallVector blueprints(funcOp.getOps()); SmallVector materializers(funcOp.getOps()); - if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty() + if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !blueprints.empty() || !materializers.empty()) { return; } @@ -160,7 +160,7 @@ void ONNXToSpatialPass::runOnOperation() { } if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { - moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion"); + moduleOp.emitError("logical Spatial graph verification failed after ONNX conversion"); signalPassFailure(); return; } @@ -181,7 +181,7 @@ void ONNXToSpatialPass::runOnOperation() { annotateWeightsConstants(*entryFunc); if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { - moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation"); + moduleOp.emitError("logical Spatial graph verification failed after weight annotation"); signalPassFailure(); return; } @@ -199,7 +199,7 @@ void ONNXToSpatialPass::runOnOperation() { [](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); }); if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { - moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites"); + moduleOp.emitError("logical Spatial graph verification failed before post rewrites"); signalPassFailure(); return; } @@ -214,7 +214,7 @@ void ONNXToSpatialPass::runOnOperation() { populateEmptyFunction(*entryFunc); if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { - moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial"); + moduleOp.emitError("logical Spatial graph verification failed after ONNX-to-Spatial"); signalPassFailure(); return; } diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp index e05bdff..bbb6604 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp @@ -15,7 +15,7 @@ namespace onnx_mlir { namespace { -constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK"; +constexpr StringLiteral kPhaseMarker = "phase-check"; void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) { func.walk([&](Operation* op) { @@ -114,14 +114,14 @@ void verifyScheduledInputs(ComputeOpTy compute, } template -void verifyNoNestedFragmentAssemblyReconciliators(ComputeOpTy compute, +void verifyNoNestedFragmentAssemblyBlueprints(ComputeOpTy compute, pim::CappedDiagnosticReporter& diagnostics) { - compute.getBody().walk([&](spatial::SpatReconciliatorOp reconciliator) { - std::optional mode = reconciliator.getMode(); + compute.getBody().walk([&](spatial::SpatBlueprintOp blueprint) { + std::optional mode = blueprint.getMode(); if (!mode || *mode != "fragment_assembly") return; - diagnostics.report(reconciliator.getOperation(), [&](Operation* illegalOp) { - illegalOp->emitOpError("fragment assembly reconciliator must be host-level after merge materialization"); + diagnostics.report(blueprint.getOperation(), [&](Operation* illegalOp) { + illegalOp->emitOpError("fragment assembly blueprint must be host-level after merge materialization"); }); }); } @@ -133,7 +133,7 @@ void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter spatial::SpatGraphComputeBatch, spatial::SpatConv2DPlanOp, spatial::SpatReluPlanOp, - spatial::SpatReconciliatorOp, + spatial::SpatBlueprintOp, spatial::SpatMaterializeLayoutOp>(&op)) { continue; } @@ -203,11 +203,11 @@ LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) { verifyScheduledTopLevelOps(funcOp, diagnostics); for (auto compute : funcOp.getOps()) { verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics); - verifyNoNestedFragmentAssemblyReconciliators(compute, diagnostics); + verifyNoNestedFragmentAssemblyBlueprints(compute, diagnostics); } for (auto batch : funcOp.getOps()) { verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics); - verifyNoNestedFragmentAssemblyReconciliators(batch, diagnostics); + verifyNoNestedFragmentAssemblyBlueprints(batch, diagnostics); } if (failed(verifyNoComputeBodyCaptures(funcOp))) return failure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 2709f1e..23080a4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -2242,8 +2242,8 @@ static FailureOr rewriteInputKTiledConv(const ConvLoweringState& state, rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides); SmallVector bOffsets {kOffset, rewriter.getIndexAttr(0)}; SmallVector bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)}; - Value bTile = tensor::ExtractSliceOp::create( - rewriter, reduceLoc, weightTileType, weightArg, bOffsets, bSizes, unitStrides); + Value bTile = extractStaticSliceOrIdentity( + rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides); Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult(); reduceYielded.push_back( spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult()); @@ -2912,8 +2912,13 @@ static FailureOr createConvOutputFromRowStripHwc(Value inputHwc, rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2)); SmallVector bOffsets {kOffset, rewriter.getIndexAttr(0)}; SmallVector bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)}; - Value bTile = tensor::ExtractSliceOp::create( - rewriter, reduceLoc, paddedWeightTileType, args.weights.front(), bOffsets, bSizes, getUnitStrides(rewriter, 2)); + Value bTile = extractStaticSliceOrIdentity(rewriter, + reduceLoc, + args.weights.front(), + paddedWeightTileType, + bOffsets, + bSizes, + getUnitStrides(rewriter, 2)); Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult(); reduceYielded.push_back( spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index e6914d5..5b8a7c4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -285,9 +285,8 @@ static FailureOr createVmmBatch(Value a, SmallVector bSizes {rewriter.getIndexAttr(crossbarSize.getValue()), rewriter.getIndexAttr(crossbarSize.getValue())}; SmallVector unitStrides = getUnitStrides(rewriter, 2); - Value bTile = - tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides) - .getResult(); + Value bTile = extractStaticSliceOrIdentity( + rewriter, loc, args.weights.front(), bTileType, bOffsets, bSizes, unitStrides); Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult(); SmallVector pieceOffsets {args.lane, rewriter.getIndexAttr(0)}; diff --git a/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp index 8845f76..815ef49 100644 --- a/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp @@ -90,10 +90,10 @@ static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan, return SelectedLayout::NchwRowStrip; } -static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) { +static spatial::SpatBlueprintOp insertRowStripBlueprint(IRRewriter& rewriter, Value value) { auto outputType = cast(value.getType()); auto [offsets, sizes] = buildRowStripMetadata(outputType); - return spatial::SpatReconciliatorOp::create(rewriter, + return spatial::SpatBlueprintOp::create(rewriter, value.getLoc(), outputType, value, @@ -189,12 +189,12 @@ struct SpatialLayoutPlanningPass final : PassWrapper(use.getOwner()); - if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType()) + auto blueprint = dyn_cast(use.getOwner()); + if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType()) return failure(); - std::optional mode = reconciliator.getMode(); - std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); - std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); - std::optional> stridesAttr = reconciliator.getFragmentStrides(); + std::optional mode = blueprint.getMode(); + std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); + std::optional> stridesAttr = blueprint.getFragmentStrides(); if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) return failure(); - if (!reconciliator.getOutput().hasOneUse() || !isa(*reconciliator.getOutput().getUsers().begin())) + if (!blueprint.getOutput().hasOneUse() || !isa(*blueprint.getOutput().getUsers().begin())) return failure(); - unsigned returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber(); - auto hostResultType = dyn_cast(reconciliator.getOutput().getType()); + unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber(); + auto hostResultType = dyn_cast(blueprint.getOutput().getType()); if (!hostResultType || !hostResultType.hasStaticShape()) return failure(); ArrayRef operandIndices = *operandIndicesAttr; ArrayRef sourceOffsets = *sourceOffsetsAttr; - ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); - ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatOffsets = blueprint.getFragmentOffsets(); + ArrayRef flatSizes = blueprint.getFragmentSizes(); ArrayRef flatStrides = *stridesAttr; int64_t rank = hostResultType.getRank(); - SmallVector fragmentOperands {reconciliator.getInput()}; - llvm::append_range(fragmentOperands, reconciliator.getFragments()); - if (failed(validateFragmentAssemblyMetadata(reconciliator, + SmallVector fragmentOperands {blueprint.getInput()}; + llvm::append_range(fragmentOperands, blueprint.getFragments()); + if (failed(validateFragmentAssemblyMetadata(blueprint, rank, fragmentOperands.size(), operandIndices, @@ -379,34 +379,34 @@ static SmallVector buildFragmentOffsets(IRRewriter& rewriter, } static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, - spatial::SpatReconciliatorOp reconciliator, + spatial::SpatBlueprintOp blueprint, Value hostTarget, ArrayRef baseOffsets, IRMapping& mapper) { auto hostTargetType = dyn_cast(hostTarget.getType()); - auto resultType = dyn_cast(reconciliator.getOutput().getType()); + auto resultType = dyn_cast(blueprint.getOutput().getType()); if (!hostTargetType || !resultType || !resultType.hasStaticShape()) - return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results"); + return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results"); - std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); - std::optional> fragmentStridesAttr = reconciliator.getFragmentStrides(); + std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); + std::optional> fragmentStridesAttr = blueprint.getFragmentStrides(); if (!operandIndicesAttr || !fragmentStridesAttr) - return reconciliator.emitOpError( + return blueprint.emitOpError( "fragment assembly lowering requires explicit operand indices and unit strides"); ArrayRef operandIndices = *operandIndicesAttr; - std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); + std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); if (!sourceOffsetsAttr) - return reconciliator.emitOpError("fragment assembly lowering requires explicit source offsets"); + return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets"); ArrayRef sourceOffsets = *sourceOffsetsAttr; - ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); - ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatOffsets = blueprint.getFragmentOffsets(); + ArrayRef flatSizes = blueprint.getFragmentSizes(); ArrayRef flatStrides = *fragmentStridesAttr; int64_t rank = resultType.getRank(); - SmallVector fragmentOperands {reconciliator.getInput()}; - llvm::append_range(fragmentOperands, reconciliator.getFragments()); - if (failed(validateFragmentAssemblyMetadata(reconciliator, + SmallVector fragmentOperands {blueprint.getInput()}; + llvm::append_range(fragmentOperands, blueprint.getFragments()); + if (failed(validateFragmentAssemblyMetadata(blueprint, rank, fragmentOperands.size(), operandIndices, @@ -423,14 +423,14 @@ static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, for (int64_t dim = 0; dim < rank; ++dim) { int64_t flatIndex = fragmentIndex * rank + dim; if (flatStrides[flatIndex] != 1) - return reconciliator.emitOpError("fragment assembly lowering only supports unit strides"); + return blueprint.emitOpError("fragment assembly lowering only supports unit strides"); fragmentOffsets.push_back(flatOffsets[flatIndex]); } Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]); auto sourceType = dyn_cast(source.getType()); if (!sourceType || !sourceType.hasStaticShape()) - return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands"); + return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands"); SmallVector fragmentShape; fragmentShape.reserve(rank); @@ -440,11 +440,11 @@ static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, Value fragment = source; if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) { FailureOr> extractOffsets = getStaticSliceOffsetsForElementOffset( - reconciliator, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice"); + blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice"); if (failed(extractOffsets)) return failure(); fragment = tensor::ExtractSliceOp::create(rewriter, - reconciliator.getLoc(), + blueprint.getLoc(), source, getStaticIndexAttrs(rewriter, *extractOffsets), getStaticIndexAttrs(rewriter, fragmentShape), @@ -452,11 +452,11 @@ static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, } hostTarget = tensor::InsertSliceOp::create(rewriter, - reconciliator.getLoc(), + blueprint.getLoc(), fragment, hostTarget, buildFragmentOffsets(rewriter, - reconciliator.getLoc(), + blueprint.getLoc(), baseOffsets, fragmentOffsets, mapper), @@ -585,13 +585,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul if (isa(op)) continue; - if (auto reconciliator = dyn_cast(op)) { - std::optional modeAttr = reconciliator.getMode(); + if (auto blueprint = dyn_cast(op)) { + std::optional modeAttr = blueprint.getMode(); if (modeAttr && *modeAttr == "fragment_assembly") { - for (Operation* user : reconciliator.getOutput().getUsers()) { + for (Operation* user : blueprint.getOutput().getUsers()) { if (!isa(user)) - return reconciliator.emitOpError( - "fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users"); + return blueprint.emitOpError( + "fragment assembly blueprint lowering expects only tensor.parallel_insert_slice users"); } continue; } @@ -653,12 +653,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc()); auto hostTargetType = cast(hostTarget.getType()); - if (auto reconciliator = - insertSlice.getSource().getDefiningOp()) { - std::optional modeAttr = reconciliator.getMode(); + if (auto blueprint = + insertSlice.getSource().getDefiningOp()) { + std::optional modeAttr = blueprint.getMode(); if (modeAttr && *modeAttr == "fragment_assembly") { FailureOr updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter, - reconciliator, + blueprint, hostTarget, insertSlice.getMixedOffsets(), mapper); diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 970063c..38ee17c 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -73,7 +73,7 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType()); } -LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator, +LogicalResult validateFragmentAssemblyMetadata(spatial::SpatBlueprintOp blueprint, int64_t resultRank, size_t operandCount, ArrayRef operandIndices, @@ -82,19 +82,19 @@ LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reco ArrayRef flatSizes, ArrayRef flatStrides) { if (operandIndices.size() != sourceOffsets.size()) - return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match"); + return blueprint.emitOpError("fragment assembly operand index and source offset counts must match"); if (flatOffsets.size() != flatSizes.size()) - return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths"); + return blueprint.emitOpError("fragment assembly offset and size arrays must have matching lengths"); if (flatStrides.size() != flatOffsets.size()) - return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths"); + return blueprint.emitOpError("fragment assembly stride and offset arrays must have matching lengths"); if (flatOffsets.size() != operandIndices.size() * static_cast(resultRank)) - return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment"); + return blueprint.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment"); for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) { if (operandIndex < 0 || operandIndex >= static_cast(operandCount)) - return reconciliator.emitOpError("fragment assembly operand index is out of range"); + return blueprint.emitOpError("fragment assembly operand index is out of range"); if (sourceOffsets[fragmentIndex] < 0) - return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative"); + return blueprint.emitOpError("fragment assembly source offsets must be nonnegative"); } return success(); diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index aa1c7cb..cc4a1ef 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -9,7 +9,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" namespace onnx_mlir::spatial { -class SpatReconciliatorOp; +class SpatBlueprintOp; } namespace onnx_mlir { @@ -36,7 +36,7 @@ mlir::SmallVector getOpOperandsSortedByUses(mlir::Operation* operat mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation); -mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatReconciliatorOp reconciliator, +mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatBlueprintOp blueprint, int64_t resultRank, size_t operandCount, llvm::ArrayRef operandIndices, diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index ca8c844..25521af 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -43,31 +43,31 @@ static Value createStaticHostTargetOffset(IRRewriter& rewriter, return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset); } -static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, - spatial::SpatReconciliatorOp reconciliator, +static FailureOr lowerFragmentAssemblyBlueprint(IRRewriter& rewriter, + spatial::SpatBlueprintOp blueprint, IRMapping& mapping) { - auto resultType = dyn_cast(reconciliator.getOutput().getType()); + auto resultType = dyn_cast(blueprint.getOutput().getType()); if (!resultType || !resultType.hasStaticShape()) - return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result"); + return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result"); - std::optional modeAttr = reconciliator.getMode(); - std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); - std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); - std::optional> fragmentStridesAttr = reconciliator.getFragmentStrides(); + std::optional modeAttr = blueprint.getMode(); + std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); + std::optional> fragmentStridesAttr = blueprint.getFragmentStrides(); if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr) - return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata"); + return blueprint.emitOpError("fragment assembly lowering requires explicit fragment metadata"); ArrayRef operandIndices = *operandIndicesAttr; ArrayRef sourceOffsets = *sourceOffsetsAttr; - ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); - ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatOffsets = blueprint.getFragmentOffsets(); + ArrayRef flatSizes = blueprint.getFragmentSizes(); ArrayRef flatStrides = *fragmentStridesAttr; int64_t rank = resultType.getRank(); - SmallVector fragmentOperands {reconciliator.getInput()}; - llvm::append_range(fragmentOperands, reconciliator.getFragments()); - if (failed(validateFragmentAssemblyMetadata(reconciliator, + SmallVector fragmentOperands {blueprint.getInput()}; + llvm::append_range(fragmentOperands, blueprint.getFragments()); + if (failed(validateFragmentAssemblyMetadata(blueprint, rank, fragmentOperands.size(), operandIndices, @@ -77,7 +77,7 @@ static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, flatStrides))) return failure(); - Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType); + Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType); for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { int64_t operandIndex = operandIndices[fragmentIndex]; @@ -86,7 +86,7 @@ static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, for (int64_t dim = 0; dim < rank; ++dim) { int64_t flatIndex = fragmentIndex * rank + dim; if (flatStrides[flatIndex] != 1) - return reconciliator.emitOpError("fragment assembly lowering only supports unit strides"); + return blueprint.emitOpError("fragment assembly lowering only supports unit strides"); fragmentOffsets.push_back(flatOffsets[flatIndex]); fragmentElements *= flatSizes[flatIndex]; } @@ -94,21 +94,21 @@ static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]); auto sourceType = dyn_cast(source.getType()); if (!sourceType || !sourceType.hasStaticShape()) - return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands"); + return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands"); int64_t fragmentBytes = fragmentElements * static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); auto sizeAttr = pim::getCheckedI32Attr(rewriter, - reconciliator.getOperation(), + blueprint.getOperation(), fragmentBytes, "fragment assembly host copy size"); if (failed(sizeAttr)) return failure(); - Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets); + Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets); auto deviceSourceOffsetBytes = pim::checkedMul(static_cast(sourceOffsets[fragmentIndex]), static_cast(getElementTypeSizeInBytes(sourceType.getElementType())), - reconciliator, + blueprint, "fragment assembly device source offset"); if (failed(deviceSourceOffsetBytes)) return failure(); @@ -116,7 +116,7 @@ static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, rewriter.getInsertionBlock()->getParentOp(), static_cast(*deviceSourceOffsetBytes)); currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter, - reconciliator.getLoc(), + blueprint.getLoc(), currentOutput.getType(), hostTargetOffset, deviceSourceOffset, @@ -230,13 +230,13 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatSchedule mapping.map(*weightArg, weight); } for (Operation& op : block.without_terminator()) { - if (auto reconciliator = dyn_cast(op)) { - std::optional modeAttr = reconciliator.getMode(); + if (auto blueprint = dyn_cast(op)) { + std::optional modeAttr = blueprint.getMode(); if (modeAttr && *modeAttr == "fragment_assembly") { - auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping); + auto lowered = lowerFragmentAssemblyBlueprint(rewriter, blueprint, mapping); if (failed(lowered)) return false; - mapping.map(reconciliator.getOutput(), *lowered); + mapping.map(blueprint.getOutput(), *lowered); continue; } } diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index b675ff6..3901d40 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -31,11 +31,11 @@ static SmallVector getUnitStrides(Builder& builder, int64_t ran return strides; } -struct LowerFragmentAssemblyReconciliatorPattern - : OpConversionPattern { +struct LowerFragmentAssemblyBlueprintPattern + : OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op, + LogicalResult matchAndRewrite(spatial::SpatBlueprintOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { std::optional modeAttr = op.getMode(); @@ -125,7 +125,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) { void populateCoreBodyPatterns(RewritePatternSet& patterns) { raptor::populateWithGenerated(patterns); populateTransposeLoweringPatterns(patterns); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 3a83b08..35dda7a 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -149,36 +149,36 @@ static std::optional analyzeReturnUse(Value value) { }; } -static FailureOr, 4>> +static FailureOr, 4>> analyzeTopLevelFragmentAssemblyUses(Value value) { - SmallVector, 4> uses; + SmallVector, 4> uses; for (OpOperand& use : value.getUses()) { - auto reconciliator = dyn_cast(use.getOwner()); - if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType()) + auto blueprint = dyn_cast(use.getOwner()); + if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType()) return failure(); - std::optional mode = reconciliator.getMode(); + std::optional mode = blueprint.getMode(); if (!mode || *mode != "fragment_assembly") return failure(); - if (!reconciliator.getOutput().hasOneUse() || !isa(*reconciliator.getOutput().getUsers().begin())) + if (!blueprint.getOutput().hasOneUse() || !isa(*blueprint.getOutput().getUsers().begin())) return failure(); - std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); - std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); - std::optional> stridesAttr = reconciliator.getFragmentStrides(); - auto resultType = dyn_cast(reconciliator.getOutput().getType()); + std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); + std::optional> stridesAttr = blueprint.getFragmentStrides(); + auto resultType = dyn_cast(blueprint.getOutput().getType()); if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape()) return failure(); - SmallVector fragmentOperands {reconciliator.getInput()}; - llvm::append_range(fragmentOperands, reconciliator.getFragments()); - if (failed(validateFragmentAssemblyMetadata(reconciliator, + SmallVector fragmentOperands {blueprint.getInput()}; + llvm::append_range(fragmentOperands, blueprint.getFragments()); + if (failed(validateFragmentAssemblyMetadata(blueprint, resultType.getRank(), fragmentOperands.size(), *operandIndicesAttr, *sourceOffsetsAttr, - reconciliator.getFragmentOffsets(), - reconciliator.getFragmentSizes(), + blueprint.getFragmentOffsets(), + blueprint.getFragmentSizes(), *stridesAttr))) return failure(); - uses.emplace_back(reconciliator, use.getOperandNumber()); + uses.emplace_back(blueprint, use.getOperandNumber()); } return uses; } @@ -593,7 +593,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low } } - FailureOr, 4>> fragmentAssemblyUses = + FailureOr, 4>> fragmentAssemblyUses = analyzeTopLevelFragmentAssemblyUses(producedValue); if (succeeded(fragmentAssemblyUses)) { auto sourceType = dyn_cast(storedValue.getType()); @@ -603,35 +603,35 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low } size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType()); - for (auto [reconciliator, operandNumber] : *fragmentAssemblyUses) { + for (auto [blueprint, operandNumber] : *fragmentAssemblyUses) { rewriter.setInsertionPointAfterValue(storedValue); - std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); - std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); - std::optional> stridesAttr = reconciliator.getFragmentStrides(); + std::optional> operandIndicesAttr = blueprint.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); + std::optional> stridesAttr = blueprint.getFragmentStrides(); if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) { - reconciliator.emitOpError( + blueprint.emitOpError( "fragment assembly lowering requires explicit operand, source-offset, and stride metadata"); return ReturnPathLoweringResult::Failure; } - size_t returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber(); + size_t returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber(); Value outputTensor = outputTensors[returnIndex](rewriter, loc); auto outputType = dyn_cast(outputTensor.getType()); - auto resultType = dyn_cast(reconciliator.getOutput().getType()); + auto resultType = dyn_cast(blueprint.getOutput().getType()); if (!outputType || !resultType || !resultType.hasStaticShape()) { - reconciliator.emitOpError("fragment assembly lowering requires static ranked host outputs"); + blueprint.emitOpError("fragment assembly lowering requires static ranked host outputs"); return ReturnPathLoweringResult::Failure; } ArrayRef operandIndices = *operandIndicesAttr; ArrayRef sourceOffsets = *sourceOffsetsAttr; - ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); - ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatOffsets = blueprint.getFragmentOffsets(); + ArrayRef flatSizes = blueprint.getFragmentSizes(); ArrayRef flatStrides = *stridesAttr; int64_t rank = resultType.getRank(); - if (failed(validateFragmentAssemblyMetadata(reconciliator, + if (failed(validateFragmentAssemblyMetadata(blueprint, rank, - 1 + reconciliator.getFragments().size(), + 1 + blueprint.getFragments().size(), operandIndices, sourceOffsets, flatOffsets, @@ -647,7 +647,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low for (int64_t dim = 0; dim < rank; ++dim) { int64_t flatIndex = fragmentIndex * rank + dim; if (flatStrides[flatIndex] != 1) { - reconciliator.emitOpError("fragment assembly lowering only supports unit strides"); + blueprint.emitOpError("fragment assembly lowering only supports unit strides"); return ReturnPathLoweringResult::Failure; } fragmentOffsets.push_back(flatOffsets[flatIndex]); @@ -684,7 +684,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low outputTensor = pim::PimMemCopyDevToHostOp::create(rewriter, - reconciliator.getLoc(), + blueprint.getLoc(), outputTensor.getType(), getOrCreateIndexConstant(rewriter, producerOp, *hostOffset), getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset), @@ -698,7 +698,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low if (failedChunk) return ReturnPathLoweringResult::Failure; } - markOpToRemove(reconciliator.getOperation()); + markOpToRemove(blueprint.getOperation()); } return ReturnPathLoweringResult::Handled; } @@ -813,11 +813,11 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret return; } - if (auto reconciliator = dyn_cast(op)) { - std::optional mode = reconciliator.getMode(); + if (auto blueprint = dyn_cast(op)) { + std::optional mode = blueprint.getMode(); if (mode && *mode == "fragment_assembly") { - markOpToRemove(reconciliator.getOperation()); - for (Value operand : reconciliator->getOperands()) + markOpToRemove(blueprint.getOperation()); + for (Value operand : blueprint->getOperands()) markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); return; } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 2f1102e..d87c459 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -203,7 +203,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { func::FuncOp funcOp = *entryFunc; if (failed(verifyScheduledSpatialInvariants(funcOp))) { funcOp.emitOpError( - "RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim"); + "scheduled Spatial verification failed at the start of SpatialToPim"); signalPassFailure(); return; } diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 2bd92c0..603e626 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -232,8 +232,8 @@ def SpatReluPlanOp : SpatOp<"relu_plan", []> { let hasVerifier = 1; } -def SpatReconciliatorOp : SpatOp<"reconciliator", []> { - let summary = "Logical-to-physical layout record or explicit fragment assembly"; +def SpatBlueprintOp : SpatOp<"blueprint", []> { + let summary = "Blueprint for assembling logical tensors from published fragments"; let arguments = (ins SpatTensor:$input, @@ -256,6 +256,7 @@ def SpatReconciliatorOp : SpatOp<"reconciliator", []> { ); let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> { diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 63c9504..8d1c39f 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -32,6 +32,14 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { return parser.getBuilder().getI32IntegerAttr(value); } +static ParseResult parseBareStringAttr(OpAsmParser& parser, StringAttr& attr) { + StringRef value; + if (parser.parseKeyword(&value)) + return failure(); + attr = parser.getBuilder().getStringAttr(value); + return success(); +} + static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef arguments) { printer << "("; for (auto [index, argument] : llvm::enumerate(arguments)) { @@ -466,6 +474,131 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { return success(); } +void SpatBlueprintOp::print(OpAsmPrinter& printer) { + SmallVector operands {getInput()}; + llvm::append_range(operands, getFragments()); + + printer << " fragments"; + printCompressedValueList(printer, operands, ListDelimiter::Paren); + printer << " layout " << getLogicalLayout(); + printer << " physical " << getPhysicalLayout(); + printer << " offsets "; + printCompressedIntegerList(printer, getFragmentOffsets()); + printer << " sizes "; + printCompressedIntegerList(printer, getFragmentSizes()); + printer << " map " << getIndexMap(); + if (std::optional mode = getMode()) + printer << " mode " << *mode; + if (std::optional> operandIndices = getFragmentOperandIndices()) { + printer << " operandIndices "; + printCompressedIntegerList(printer, *operandIndices); + } + if (std::optional> sourceOffsets = getFragmentSourceOffsets()) { + printer << " sourceOffsets "; + printCompressedIntegerList(printer, *sourceOffsets); + } + if (std::optional> strides = getFragmentStrides()) { + printer << " strides "; + printCompressedIntegerList(printer, *strides); + } + if (std::optional conflictPolicy = getConflictPolicy()) + printer << " conflict " << *conflictPolicy; + if (std::optional coveragePolicy = getCoveragePolicy()) + printer << " coverage " << *coveragePolicy; + + printer.printOptionalAttrDict((*this)->getAttrs(), + {getLogicalLayoutAttrName().getValue(), + getPhysicalLayoutAttrName().getValue(), + getFragmentOffsetsAttrName().getValue(), + getFragmentSizesAttrName().getValue(), + getIndexMapAttrName().getValue(), + getModeAttrName().getValue(), + getFragmentOperandIndicesAttrName().getValue(), + getFragmentSourceOffsetsAttrName().getValue(), + getFragmentStridesAttrName().getValue(), + getConflictPolicyAttrName().getValue(), + getCoveragePolicyAttrName().getValue()}); + printer << " : "; + printCompressedTypeList(printer, TypeRange(operands), ListDelimiter::Paren); + printer << " -> "; + printer.printType(getOutput().getType()); +} + +ParseResult SpatBlueprintOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector operands; + SmallVector operandTypes; + Type outputType; + StringAttr logicalLayout; + StringAttr physicalLayout; + StringAttr indexMap; + StringAttr mode; + StringAttr conflictPolicy; + StringAttr coveragePolicy; + SmallVector fragmentOffsets; + SmallVector fragmentSizes; + SmallVector fragmentOperandIndices; + SmallVector fragmentSourceOffsets; + SmallVector fragmentStrides; + + if (parser.parseKeyword("fragments") + || parseCompressedOperandList(parser, ListDelimiter::Paren, operands) + || parser.parseKeyword("layout") || parseBareStringAttr(parser, logicalLayout) + || parser.parseKeyword("physical") || parseBareStringAttr(parser, physicalLayout) + || parser.parseKeyword("offsets") || parseCompressedIntegerList(parser, fragmentOffsets) + || parser.parseKeyword("sizes") || parseCompressedIntegerList(parser, fragmentSizes) + || parser.parseKeyword("map") || parseBareStringAttr(parser, indexMap)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("mode")) && parseBareStringAttr(parser, mode)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("operandIndices")) + && parseCompressedIntegerList(parser, fragmentOperandIndices)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("sourceOffsets")) + && parseCompressedIntegerList(parser, fragmentSourceOffsets)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("strides")) && parseCompressedIntegerList(parser, fragmentStrides)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("conflict")) && parseBareStringAttr(parser, conflictPolicy)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("coverage")) && parseBareStringAttr(parser, coveragePolicy)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, operandTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parser.parseType(outputType)) + return failure(); + if (operands.empty()) + return parser.emitError(parser.getCurrentLocation(), "spat.blueprint requires at least one fragment operand"); + if (operands.size() != operandTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of fragment operands and types must match"); + + auto& builder = parser.getBuilder(); + result.addAttribute("logicalLayout", logicalLayout); + result.addAttribute("physicalLayout", physicalLayout); + result.addAttribute("fragmentOffsets", builder.getDenseI64ArrayAttr(fragmentOffsets)); + result.addAttribute("fragmentSizes", builder.getDenseI64ArrayAttr(fragmentSizes)); + result.addAttribute("indexMap", indexMap); + if (mode) + result.addAttribute("mode", mode); + if (!fragmentOperandIndices.empty()) + result.addAttribute("fragmentOperandIndices", builder.getDenseI64ArrayAttr(fragmentOperandIndices)); + if (!fragmentSourceOffsets.empty()) + result.addAttribute("fragmentSourceOffsets", builder.getDenseI64ArrayAttr(fragmentSourceOffsets)); + if (!fragmentStrides.empty()) + result.addAttribute("fragmentStrides", builder.getDenseI64ArrayAttr(fragmentStrides)); + if (conflictPolicy) + result.addAttribute("conflictPolicy", conflictPolicy); + if (coveragePolicy) + result.addAttribute("coveragePolicy", coveragePolicy); + + if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputType); + return success(); +} + void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); } ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) { return parseComputeLikeOp(parser, result); diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index d4cbbfe..67561ff 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -436,10 +436,10 @@ LogicalResult SpatReluPlanOp::verify() { return success(); } -LogicalResult SpatReconciliatorOp::verify() { +LogicalResult SpatBlueprintOp::verify() { auto modeAttr = getModeAttr(); bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly"; - if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator"))) + if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.blueprint"))) return failure(); if (!isKnownLogicalLayout(getLogicalLayout())) return emitError("requires a known logical layout"); @@ -482,10 +482,10 @@ LogicalResult SpatReconciliatorOp::verify() { if (failed(verifyBoundsOnly({}))) return failure(); if (!getFragments().empty()) - return emitError("legacy reconciliator does not accept extra fragment operands"); + return emitError("legacy blueprint does not accept extra fragment operands"); if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr() || getCoveragePolicyAttr()) - return emitError("legacy reconciliator does not accept fragment assembly attributes"); + return emitError("legacy blueprint does not accept fragment assembly attributes"); return success(); } @@ -493,11 +493,11 @@ LogicalResult SpatReconciliatorOp::verify() { auto operandIndicesAttr = getFragmentOperandIndicesAttr(); auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr(); if (!operandIndicesAttr) - return emitError("fragment assembly reconciliator requires fragment operand indices"); + return emitError("fragment assembly blueprint requires fragment operand indices"); if (!sourceOffsetsAttr) - return emitError("fragment assembly reconciliator requires fragment source offsets"); + return emitError("fragment assembly blueprint requires fragment source offsets"); if (!stridesAttr) - return emitError("fragment assembly reconciliator requires fragment strides"); + return emitError("fragment assembly blueprint requires fragment strides"); ArrayRef operandIndices = operandIndicesAttr.asArrayRef(); ArrayRef sourceOffsets = sourceOffsetsAttr.asArrayRef(); ArrayRef strides = stridesAttr.asArrayRef(); @@ -506,11 +506,11 @@ LogicalResult SpatReconciliatorOp::verify() { if (sourceOffsets.size() != operandIndices.size()) return emitError("fragment source offset count must match fragment operand index count"); if (!getConflictPolicyAttr() || !getCoveragePolicyAttr()) - return emitError("fragment assembly reconciliator requires conflict and coverage policies"); + return emitError("fragment assembly blueprint requires conflict and coverage policies"); if (getConflictPolicy() != "disjoint") - return emitError("fragment assembly reconciliator currently supports only conflict_policy=\"disjoint\""); + return emitError("fragment assembly blueprint currently supports only conflict_policy=\"disjoint\""); if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial") - return emitError("fragment assembly reconciliator coverage_policy must be \"complete\" or \"partial\""); + return emitError("fragment assembly blueprint coverage_policy must be \"complete\" or \"partial\""); SmallVector operands; operands.push_back(getInput()); @@ -518,7 +518,7 @@ LogicalResult SpatReconciliatorOp::verify() { int64_t operandCount = static_cast(operands.size()); int64_t fragmentCount = static_cast(operandIndices.size()); if (operandCount == 0) - return emitError("fragment assembly reconciliator requires at least one operand"); + return emitError("fragment assembly blueprint requires at least one operand"); if (static_cast(offsets.size()) != fragmentCount * rank) return emitError("fragment assembly metadata count must match operand count * result rank"); if (failed(verifyBoundsOnly(strides))) @@ -544,9 +544,9 @@ LogicalResult SpatReconciliatorOp::verify() { auto operandType = dyn_cast(operands[operandIndex].getType()); if (!operandType || !operandType.hasStaticShape()) - return emitError("fragment assembly reconciliator requires static ranked tensor operands"); + return emitError("fragment assembly blueprint requires static ranked tensor operands"); if (operandType.getRank() != rank) - return emitError("fragment assembly reconciliator requires operand/result rank match"); + return emitError("fragment assembly blueprint requires operand/result rank match"); SmallVector fragmentOffsets; SmallVector fragmentSizes; @@ -583,14 +583,14 @@ LogicalResult SpatReconciliatorOp::verify() { } } if (overlaps) - return emitError("fragment assembly reconciliator requires disjoint static slices"); + return emitError("fragment assembly blueprint requires disjoint static slices"); } slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)}); } for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) { if (fragmentCountsByOperand[static_cast(operandIndex)] == 0) - return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment"); + return emitError("fragment assembly blueprint requires every operand to contribute at least one fragment"); } if (getCoveragePolicy() == "complete") { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 45b57b8..8fe4cd1 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -30,6 +30,7 @@ #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -308,7 +309,7 @@ struct PendingProjectedHostOutputFragment { Value originalOutput; ClassId sourceClass = 0; ProducerKey producerKey; - Value publicationValue; + unsigned publicationResultIndex = 0; int64_t sourceFragmentOrdinal = 0; int64_t sourceElementOffset = 0; SmallVector offsets; @@ -1220,36 +1221,13 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali return std::get<1>(*arg); } -void refreshPendingProjectedHostOutputPublicationValues(MaterializerState& state, - Operation* oldOwner, - Operation* newOwner) { - if (!oldOwner || oldOwner == newOwner) - return; - - for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) { - auto publicationResult = dyn_cast_or_null(fragment.publicationValue); - if (!publicationResult || publicationResult.getOwner() != oldOwner) - publicationResult = OpResult(); - else - fragment.publicationValue = newOwner->getResult(publicationResult.getResultNumber()); - - if (auto originalResult = dyn_cast_or_null(fragment.originalOutput); originalResult - && originalResult.getOwner() == oldOwner) { - fragment.originalOutput = newOwner->getResult(originalResult.getResultNumber()); - } - - if (fragment.producerKey.instance.op == oldOwner) - fragment.producerKey.instance.op = newOwner; - } -} - -FailureOr appendScalarPublicationResult(MaterializerState& state, - MaterializedClass& materializedClass, - Value payload, - Location loc) { +FailureOr appendScalarPublicationResult(MaterializerState& state, + MaterializedClass& materializedClass, + Value payload, + Location loc) { auto existing = materializedClass.publicationOutputToResultIndex.find(payload); if (existing != materializedClass.publicationOutputToResultIndex.end()) - return materializedClass.op->getResult(existing->second); + return existing->second; auto compute = dyn_cast(materializedClass.op); if (!compute) @@ -1264,27 +1242,25 @@ FailureOr appendScalarPublicationResult(MaterializerState& state, if (failed(inserted)) return materializedClass.op->emitError("failed to append scalar publication result"); - Operation* oldOp = materializedClass.op; auto [result, newCompute] = *inserted; materializedClass.op = newCompute.getOperation(); materializedClass.body = &newCompute.getBody().front(); - refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op); materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber(); auto yieldOp = dyn_cast(materializedClass.body->getTerminator()); if (!yieldOp) return materializedClass.op->emitError("expected spat.yield terminator while appending scalar publication result"); state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->insertOperands(yieldOp.getNumOperands(), payload); }); - return result; + return result.getResultNumber(); } -FailureOr appendBatchPublicationResult(MaterializerState& state, - MaterializedClass& materializedClass, - Value payload, - Location loc) { +FailureOr appendBatchPublicationResult(MaterializerState& state, + MaterializedClass& materializedClass, + Value payload, + Location loc) { auto existing = materializedClass.publicationOutputToResultIndex.find(payload); if (existing != materializedClass.publicationOutputToResultIndex.end()) - return materializedClass.op->getResult(existing->second); + return existing->second; auto batch = dyn_cast(materializedClass.op); if (!batch) @@ -1305,11 +1281,9 @@ FailureOr appendBatchPublicationResult(MaterializerState& state, if (failed(inserted)) return materializedClass.op->emitError("failed to append batch publication result"); - Operation* oldOp = materializedClass.op; auto [result, outputArg, newBatch] = *inserted; materializedClass.op = newBatch.getOperation(); materializedClass.body = &newBatch.getBody().front(); - refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op); materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber(); auto inParallelOp = dyn_cast(materializedClass.body->getTerminator()); @@ -1330,7 +1304,7 @@ FailureOr appendBatchPublicationResult(MaterializerState& state, Value firstOffset = scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc); createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset); - return result; + return result.getResultNumber(); } // ----------------------------------------------------------------------------- @@ -1563,7 +1537,7 @@ void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value val void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) { Block& body = *targetClass.body; diagnostic.attachNote(targetClass.op->getLoc()) - << "RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName() + << "target class " << targetClass.id << " op '" << targetClass.op->getName() << "' body has " << body.getNumArguments() << " block arguments and " << std::distance(body.begin(), body.end()) << " top-level operations"; } @@ -1687,7 +1661,7 @@ FailureOr rematerializeIndexValueInClass(MaterializerState& state, if (auto blockArg = dyn_cast(value)) { InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body"); + "cannot rematerialize external block argument in materialized class body"); diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType() << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; if (Operation* owner = blockArg.getOwner()->getParentOp()) { @@ -1709,16 +1683,16 @@ FailureOr rematerializeIndexValueInClass(MaterializerState& state, if (mapperHadOriginalValue && mappedOriginalValue != value) attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value"); if (Operation* owner = blockArg.getOwner()->getParentOp()) { - attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op"); - attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain"); + attachMaterializerOperationPrintNote(diagnostic, owner, "external block argument owner op"); + attachMaterializerParentChainNote(diagnostic, owner, "external block argument owner parent chain"); } - attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); + attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "target materialized op"); attachMaterializedClassBodySummary(diagnostic, targetClass); return failure(); } InFlightDiagnostic diagnostic = - targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body"); + targetClass.op->emitError("cannot rematerialize external index value in materialized class body"); diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); @@ -1793,8 +1767,12 @@ FailureOr rematerializeTensorValueInClass(MaterializerState& state, strides.push_back(*localized); } - return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides) - .getResult(); + auto resultType = dyn_cast(extractSlice.getResult().getType()); + if (!resultType) + return anchor->emitError("expected ranked tensor extract_slice while rematerializing tensor capture"); + + return extractStaticSliceOrIdentity( + state.rewriter, anchor->getLoc(), *localizedSource, resultType, offsets, sizes, strides); } if (auto collapseShape = value.getDefiningOp()) { @@ -2108,8 +2086,10 @@ Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value in if (dim0Size == 1) return index; - Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size); - return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 * dim0Size); + return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, anchor); } FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, @@ -2123,8 +2103,7 @@ FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, if (dim0Size == 1) return *localizedIndex; - Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size); - return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult(); + return scaleIndexByDim0Size(state, targetClass.op, *localizedIndex, dim0Size, loc); } bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) { @@ -3677,10 +3656,13 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, ValueRange {init}, [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value acc = iterArgs.front(); - Value payloadFragmentCount = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); - Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localizedMessageIndex, payloadFragmentCount).getResult(); - Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineExpr d1 = getAffineDimExpr(1, context); + AffineMap flatIndexMap = + AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * descriptor.layout.payloadFragmentCount + d1); + Value flatIndex = createOrFoldAffineApply( + state.rewriter, loc, flatIndexMap, ValueRange {*localizedMessageIndex, fragmentIndex}, targetClass.op); FailureOr> fragmentOffsets = buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc); @@ -5618,8 +5600,8 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat return failure(); } - FailureOr publicationResult = appendScalarPublicationResult(state, sourceClass, packed, loc); - if (failed(publicationResult)) + FailureOr publicationResultIndex = appendScalarPublicationResult(state, sourceClass, packed, loc); + if (failed(publicationResultIndex)) return failure(); int64_t fragmentElementCount = fragmentType.getNumElements(); @@ -5657,7 +5639,7 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat originalOutput, sourceClass.id, ProducerKey {peer, resultIndex}, - *publicationResult, + *publicationResultIndex, static_cast(runIndex), static_cast(runIndex) * fragmentElementCount, SmallVector(*offsets), @@ -5711,8 +5693,8 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt if (fragmentType == originalOutput.getType()) return false; - FailureOr publicationResult = appendBatchPublicationResult(state, sourceClass, packed, loc); - if (failed(publicationResult)) + FailureOr publicationResultIndex = appendBatchPublicationResult(state, sourceClass, packed, loc); + if (failed(publicationResultIndex)) return failure(); if (packedType != fragmentType) { @@ -5764,7 +5746,7 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt originalOutput, sourceClass.id, key, - *publicationResult, + *publicationResultIndex, static_cast(fragmentIndex), static_cast(*publishedLaneIndex) * payloadElementCount + localFragmentOffsetWithinPublishedPayload, SmallVector(*offsets), @@ -5787,18 +5769,26 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { SmallVector outputs; outputs.reserve(byOutput.size()); - for (const auto& entry : byOutput) - outputs.push_back(entry.first); - llvm::sort(outputs, [](Value lhs, Value rhs) { - return reinterpret_cast(lhs.getAsOpaquePointer()) - < reinterpret_cast(rhs.getAsOpaquePointer()); - }); auto returnOp = dyn_cast(state.func.getBody().front().getTerminator()); if (!returnOp) return state.func.emitError("expected func.return terminator while finalizing projected host output fragments"); + DenseSet seenOutputs; + for (Value returned : returnOp.getOperands()) { + if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second) + continue; + outputs.push_back(returned); + } + if (outputs.size() != byOutput.size()) + return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs"); + for (Value originalOutput : outputs) { + if (isa_and_present(originalOutput.getDefiningOp())) { + return state.func.emitError( + "projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result"); + } + auto resultType = dyn_cast(originalOutput.getType()); if (!resultType || !resultType.hasStaticShape()) return state.func.emitError("projected host output must have static ranked tensor type"); @@ -5806,13 +5796,12 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { SmallVector& fragments = byOutput[originalOutput]; llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, const PendingProjectedHostOutputFragment* rhs) { - if (lhs->publicationValue != rhs->publicationValue) - return reinterpret_cast(lhs->publicationValue.getAsOpaquePointer()) - < reinterpret_cast(rhs->publicationValue.getAsOpaquePointer()); - if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal) - return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal; if (lhs->sourceClass != rhs->sourceClass) return lhs->sourceClass < rhs->sourceClass; + if (lhs->publicationResultIndex != rhs->publicationResultIndex) + return lhs->publicationResultIndex < rhs->publicationResultIndex; + if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal) + return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal; return std::lexicographical_compare(lhs->offsets.begin(), lhs->offsets.end(), rhs->offsets.begin(), @@ -5821,7 +5810,7 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { state.rewriter.setInsertionPoint(returnOp); Location loc = fragments.front()->loc; - SmallVector reconciliatorOperands; + SmallVector blueprintOperands; SmallVector fragmentOperandIndices; SmallVector fragmentSourceOffsets; SmallVector flatOffsets; @@ -5830,12 +5819,23 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { DenseMap operandIndicesByValue; for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { - Value operand = fragmentRecord->publicationValue; + if (fragmentRecord->sourceClass >= state.classes.size()) + return state.func.emitError("projected host output fragment references an invalid source class"); + + MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; + if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) { + return sourceClass.op->emitError("projected host output fragment references an invalid publication result") + << " sourceClass=" << sourceClass.id + << " resultIndex=" << fragmentRecord->publicationResultIndex + << " resultCount=" << sourceClass.op->getNumResults(); + } + + Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex); auto [operandIt, inserted] = - operandIndicesByValue.try_emplace(operand, static_cast(reconciliatorOperands.size())); + operandIndicesByValue.try_emplace(operand, static_cast(blueprintOperands.size())); if (inserted) - reconciliatorOperands.push_back(operand); + blueprintOperands.push_back(operand); fragmentOperandIndices.push_back(operandIt->second); fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset); llvm::append_range(flatOffsets, fragmentRecord->offsets); @@ -5847,12 +5847,12 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { return state.func.emitError("projected host output assembly requires static ranked tensor operands"); } - if (reconciliatorOperands.empty()) + if (blueprintOperands.empty()) return state.func.emitError("missing projected host output fragments"); - Value input = reconciliatorOperands.front(); - ValueRange extraFragments = ValueRange(reconciliatorOperands).drop_front(); - auto reconciliator = spatial::SpatReconciliatorOp::create( + Value input = blueprintOperands.front(); + ValueRange extraFragments = ValueRange(blueprintOperands).drop_front(); + auto blueprint = spatial::SpatBlueprintOp::create( state.rewriter, loc, resultType, @@ -5870,7 +5870,7 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { state.rewriter.getStringAttr("disjoint"), state.rewriter.getStringAttr("complete")); - state.hostReplacements[originalOutput] = reconciliator.getOutput(); + state.hostReplacements[originalOutput] = blueprint.getOutput(); } return success(); @@ -6284,6 +6284,32 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state, mapper.map(operand, *localized); } + if (auto extract = dyn_cast(&op)) { + auto remapFoldResult = [&](OpFoldResult value) -> OpFoldResult { + if (auto mappedValue = dyn_cast_if_present(value)) + return mapper.lookupOrDefault(mappedValue); + return value; + }; + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(extract.getMixedOffsets().size()); + sizes.reserve(extract.getMixedSizes().size()); + strides.reserve(extract.getMixedStrides().size()); + + llvm::append_range(offsets, llvm::map_range(extract.getMixedOffsets(), remapFoldResult)); + llvm::append_range(sizes, llvm::map_range(extract.getMixedSizes(), remapFoldResult)); + llvm::append_range(strides, llvm::map_range(extract.getMixedStrides(), remapFoldResult)); + + auto resultType = cast(extract.getType()); + Value localizedSource = mapper.lookupOrDefault(extract.getSource()); + Value localizedExtract = extractStaticSliceOrIdentity( + state.rewriter, extract.getLoc(), localizedSource, resultType, offsets, sizes, strides); + mapper.map(extract.getResult(), localizedExtract); + continue; + } + Operation* cloned = state.rewriter.clone(op, mapper); if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper))) return failure(); @@ -6350,18 +6376,20 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state if (failed(localizedIv)) return failure(); Value iv = *localizedIv; - Value lowerBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]); - Value tripCount = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]); + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineMap normalizedMap = + AffineMap::get(/*dimCount=*/1, + /*symbolCount=*/0, + (d0 - replacement.layout.loopLowerBounds[index]).floorDiv(replacement.layout.loopSteps[index])); + Value normalized = + createOrFoldAffineApply(state.rewriter, extract.getLoc(), normalizedMap, ValueRange {iv}, targetClass.op); - Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult(); - if (replacement.layout.loopSteps[index] != 1) - normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult(); - linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult(); - linearizedIndex = - arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult(); + AffineExpr d1 = getAffineDimExpr(1, context); + AffineMap linearizedMap = AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/0, d0 * replacement.layout.loopTripCounts[index] + d1); + linearizedIndex = createOrFoldAffineApply( + state.rewriter, extract.getLoc(), linearizedMap, ValueRange {linearizedIndex, normalized}, targetClass.op); } return linearizedIndex; }; @@ -6386,12 +6414,16 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state if (failed(localProjectionSlotIndex)) return failure(); - Value fragmentsPerLogicalSlot = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot); - Value base = - arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot) - .getResult(); - return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult(); + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineExpr d1 = getAffineDimExpr(1, context); + AffineMap packedIndexMap = AffineMap::get( + /*dimCount=*/2, /*symbolCount=*/0, d0 * replacement.layout.fragmentsPerLogicalSlot + d1); + return createOrFoldAffineApply(state.rewriter, + extract.getLoc(), + packedIndexMap, + ValueRange {*localProjectionSlotIndex, intraSlotFragmentIndex}, + targetClass.op); }; FailureOr packedFragmentIndex = computeProjectedPayloadFragmentIndex(); @@ -6445,18 +6477,18 @@ LogicalResult localizeCapturesInOperationTree(MaterializerState& state, localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper); if (failed(localized)) { InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand"); + "failed to localize cloned scheduled-body operand"); diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; - attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR"); - attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands"); - attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain"); + attachMaterializerOperationPrintNote(diagnostic, nestedOp, "offending nested operation IR"); + attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "offending nested operation operands"); + attachMaterializerParentChainNote(diagnostic, nestedOp, "offending nested operation parent chain"); attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); - attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); + attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "target materialized op"); attachMaterializedClassBodySummary(diagnostic, targetClass); return WalkResult::interrupt(); } @@ -6505,7 +6537,7 @@ LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, Materia "final scheduled body capture localization found an unsupported external non-tensor operand"); if (failed(localized)) { InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG failed to localize final scheduled-body operand"); + "failed to localize final scheduled-body operand"); diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.bkk b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.bkk deleted file mode 100644 index 82dbf03..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.bkk +++ /dev/null @@ -1,9510 +0,0 @@ -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/RegionUtils.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include -#include -#include -#include - -#include "MaterializeMergeSchedule.hpp" -#include "Scheduling/ComputeInstanceUtils.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { -namespace spatial { -namespace { - -using CpuId = size_t; -using ClassId = size_t; -using SlotId = size_t; - -static FailureOr getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) { - return pim::checkedI32(static_cast(cpu), anchor, fieldName); -} - -static FailureOr> -getCheckedCoreIds(Operation* anchor, ArrayRef cpus, StringRef fieldName) { - SmallVector coreIds; - coreIds.reserve(cpus.size()); - for (CpuId cpu : cpus) { - auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName); - if (failed(checkedCoreId)) - return failure(); - coreIds.push_back(*checkedCoreId); - } - return coreIds; -} - -struct MessageVector { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - size_t size() const { return channelIds.size(); } - bool empty() const { return channelIds.empty(); } - - LogicalResult verify(Operation* anchor) const { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) - return anchor->emitError("message metadata is inconsistent"); - return success(); - } - - void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) { - channelIds.push_back(channelId); - sourceCoreIds.push_back(sourceCoreId); - targetCoreIds.push_back(targetCoreId); - } - - void append(ArrayRef channels, ArrayRef sources, ArrayRef targets) { - assert(channels.size() == sources.size() && "channel/source count mismatch"); - assert(channels.size() == targets.size() && "channel/target count mismatch"); - llvm::append_range(channelIds, channels); - llvm::append_range(sourceCoreIds, sources); - llvm::append_range(targetCoreIds, targets); - } - - MessageVector slice(size_t offset, size_t count) const { - MessageVector result; - result.append(ArrayRef(channelIds).slice(offset, count), - ArrayRef(sourceCoreIds).slice(offset, count), - ArrayRef(targetCoreIds).slice(offset, count)); - return result; - } -}; - -struct ProducerKey { - ComputeInstance instance; - size_t resultIndex = 0; - - bool operator==(const ProducerKey& other) const { - return instance == other.instance && resultIndex == other.resultIndex; - } -}; - -struct ProducerKeyInfo { - static ProducerKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; - } - - static ProducerKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; - } - - static unsigned getHashValue(const ProducerKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.instance), key.resultIndex); - } - - static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } -}; - -struct SameClassConsumerLookupKey { - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - ClassId classId = 0; - - bool operator==(const SameClassConsumerLookupKey& other) const { - return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; - } -}; - -struct SameClassConsumerLookupKeyInfo { - static SameClassConsumerLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static SameClassConsumerLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static unsigned getHashValue(const SameClassConsumerLookupKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); - } - - static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) { - return lhs == rhs; - } -}; - -struct WholeBatchAssemblyLookupKey { - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - ClassId classId = 0; - - bool operator==(const WholeBatchAssemblyLookupKey& other) const { - return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; - } -}; - -struct WholeBatchAssemblyLookupKeyInfo { - static WholeBatchAssemblyLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static WholeBatchAssemblyLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); - } - - static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) { - return lhs == rhs; - } -}; - -using ClassSlotKey = std::pair; - -struct MaterializedClass { - ClassId id = 0; - SmallVector cpus; - Operation* op = nullptr; - Block* body = nullptr; - bool isBatch = false; - - DenseMap cpuToLane; - SmallVector weights; - SmallVector inputs; - SmallVector hostOutputs; - DenseMap weightArgs; - DenseMap inputArgs; - DenseMap hostOutputToResultIndex; -}; - -struct PackedScalarRunSlot { - SmallVector keys; -}; - -enum class PackedScalarRunKind { - Materialized, - DeferredReceive, - DeferredLocalCompute -}; - -struct PackedScalarRunValue { - ClassId targetClass = 0; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - PackedScalarRunKind kind = PackedScalarRunKind::Materialized; - - Value packed; - - RankedTensorType fragmentType; - SmallVector slots; - MessageVector messages; -}; - -struct IndexedBatchRunValue { - ClassId targetClass = 0; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - RankedTensorType fragmentType; - SmallVector slots; - MessageVector messages; -}; - -struct LogicalSlotRange { - SlotId start = 0; - SlotId count = 0; -}; - -struct MaterializationRunSlot { - SmallVector peers; -}; - -using MaterializationRun = SmallVector; - -struct OutputDestinationGroup { - SmallVector resultIndices; - SmallVector destinationClasses; -}; - -struct BatchRunSendPlan { - size_t resultIndex = 0; - ClassId destinationClass = 0; - MessageVector messages; -}; - -struct ProjectedBatchInputKey { - Operation* consumerOp = nullptr; - unsigned inputIndex = 0; - - bool operator==(const ProjectedBatchInputKey& other) const { - return consumerOp == other.consumerOp && inputIndex == other.inputIndex; - } -}; - -struct ProjectedBatchInputKeyInfo { - static ProjectedBatchInputKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; - } - - static ProjectedBatchInputKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; - } - - static unsigned getHashValue(const ProjectedBatchInputKey& key) { - return llvm::hash_combine(key.consumerOp, key.inputIndex); - } - - static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } -}; - -struct ProjectedFragmentLayout { - RankedTensorType fragmentType; - SmallVector fragmentShape; - unsigned fragmentsPerLogicalSlot = 1; - unsigned payloadFragmentCount = 1; - SmallVector loopLowerBounds; - SmallVector loopSteps; - SmallVector loopTripCounts; -}; - -struct ProjectedTransferDescriptor { - ProjectedBatchInputKey inputKey; - Operation* extractOp = nullptr; - - ProjectedFragmentLayout layout; - RankedTensorType payloadType; - SmallVector, 16> fragmentOffsets; - SmallVector, 4> fragmentOffsetsByDim; -}; - -struct ProjectedExtractReplacement { - Value payload; - ProjectedFragmentLayout layout; -}; - -struct CloneIndexingContext { - std::optional runSlotIndex; - std::optional projectionSlotIndex; -}; - -struct StaticProjectedLoopInfo { - BlockArgument iv; - int64_t lowerBound = 0; - int64_t step = 1; - int64_t tripCount = 1; -}; - -struct AffineProjectedInputSliceMatch { - tensor::ExtractSliceOp extract; - RankedTensorType sourceType; - RankedTensorType fragmentType; - SmallVector fragmentShape; - SmallVector offsets; - SmallVector loops; -}; - -struct MaterializerState; - -struct PendingProjectedHostReceiveGroup { - Value originalOutput; - ClassId ownerClassId = 0; - RankedTensorType fragmentType; - SmallVector keys; - MessageVector messages; - Location loc; -}; - -struct PendingScalarReceiveRecord { - PendingScalarReceiveRecord(ArrayRef keys, - ClassId targetClassId, - Type receiveType, - const MessageVector& messages, - Location loc) - : targetClassId(targetClassId), - receiveType(receiveType), - messages(messages), - loc(loc) { - this->keys.append(keys.begin(), keys.end()); - } - - SmallVector keys; - ClassId targetClassId = 0; - Type receiveType; - MessageVector messages; - Location loc; - bool materialized = false; - Value value; -}; - -FailureOr materializeProjectedExtractReplacement(MaterializerState& state, - MaterializedClass& targetClass, - tensor::ExtractSliceOp extract, - const ProjectedExtractReplacement& replacement, - std::optional projectionSlotIndex, - IRMapping* mapper = nullptr); -FailureOr rematerializeTensorValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - IRMapping* mapper = nullptr); -FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - std::optional producer = std::nullopt, - IRMapping* mapper = nullptr); -FailureOr localizeMaterializedClassOperand(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef tensorContext, - StringRef genericContext, - IRMapping* mapper = nullptr); -LogicalResult localizeCapturesInClonedOp(MaterializerState& state, - MaterializedClass& targetClass, - Operation& clonedOp, - IRMapping* mapper = nullptr); -bool requiresConstantProjectionSlotIndex(MaterializerState& state, - MaterializedClass& targetClass, - Operation* sourceOp); -bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, - const AffineProjectedInputSliceMatch& match, - ProducerKey producer, - uint32_t consumerLane); - -class AvailableValueStore { -public: - struct ExactBatchFragmentRecord { - ProducerKey key; - Value value; - }; - - void record(ProducerKey key, ClassId classId, Value value) { - exactValues[key][classId] = value; - - auto batch = dyn_cast_or_null(key.instance.op); - if (!batch || key.instance.laneCount == 0) - return; - - WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId}; - SmallVector& bucket = exactBatchFragmentsByProducerResultClass[lookupKey]; - for (ExactBatchFragmentRecord& record : bucket) { - if (!(record.key == key)) - continue; - record.value = value; - return; - } - bucket.push_back({key, value}); - } - - void recordPackedRun(PackedScalarRunValue run) { - size_t runIndex = packedScalarRuns.size(); - packedScalarRuns.push_back(std::move(run)); - const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex]; - WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass}; - packedRunsByProducerResultClass[lookupKey].push_back(runIndex); - } - void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } - - std::optional lookupExact(ProducerKey key, ClassId classId) const; - - std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); - IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); - - ArrayRef getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const { - auto it = packedRunsByProducerResultClass.find(key); - if (it == packedRunsByProducerResultClass.end()) - return {}; - return it->second; - } - - ArrayRef getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const { - auto it = exactBatchFragmentsByProducerResultClass.find(key); - if (it == exactBatchFragmentsByProducerResultClass.end()) - return {}; - return it->second; - } - - PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; } - -private: - std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); - - DenseMap, ProducerKeyInfo> exactValues; - SmallVector packedScalarRuns; - SmallVector indexedBatchRuns; - DenseMap, WholeBatchAssemblyLookupKeyInfo> - exactBatchFragmentsByProducerResultClass; - DenseMap, WholeBatchAssemblyLookupKeyInfo> - packedRunsByProducerResultClass; -}; - -struct MaterializerState { - func::FuncOp func; - const MergeScheduleResult& schedule; - IRRewriter rewriter; - OperationFolder constantFolder; - int64_t& nextChannelId; - SmallVector classes; - DenseMap cpuToClass; - DenseMap> logicalInstancesByCpu; - DenseMap scheduledInstanceToLogicalSlots; - DenseMap logicalInstanceToScheduledChunk; - DenseSet materializedLogicalSlots; - - DenseMap, ProducerKeyInfo> producerDestClasses; - DenseMap, SameClassConsumerLookupKeyInfo> - sameClassConsumerIndex; - DenseMap projectedInputMatches; - DenseSet nonProjectedInputs; - DenseMap liveExternalUseCache; - DenseMap> batchOutputFragmentTypesCache; - DenseMap, llvm::DenseMapInfo> computeInstanceOutputsCache; - DenseMap, ProducerKeyInfo> projectedTransfers; - DenseMap> projectedExtractReplacements; - AvailableValueStore availableValues; - DenseMap hostReplacements; - DenseMap hostOutputOwners; - SmallVector pendingProjectedHostReceives; - SmallVector pendingScalarReceives; - DenseMap, ProducerKeyInfo> pendingScalarReceiveLookup; - DenseMap firstLateCommunicationOps; - int64_t nextCommunicationTraceId = 0; - DenseSet oldComputeOps; - - MaterializerState(func::FuncOp func, - const MergeScheduleResult& schedule, - int64_t& nextChannelId) - : func(func), - schedule(schedule), - rewriter(func.getContext()), - constantFolder(func.getContext()), - nextChannelId(nextChannelId) {} -}; - -bool isConstantLike(Value value) { - Operation* definingOp = value.getDefiningOp(); - return definingOp && definingOp->hasTrait(); -} - -bool isInsideOldCompute(Operation* op, const DenseSet& oldComputeOps) { - for (Operation* current = op; current; current = current->getParentOp()) - if (oldComputeOps.contains(current)) - return true; - return false; -} - -bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps); -ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance); - -bool hasLiveExternalUseCached(MaterializerState& state, Value value) { - auto cached = state.liveExternalUseCache.find(value); - if (cached != state.liveExternalUseCache.end()) - return cached->second; - bool live = hasLiveExternalUse(value, state.oldComputeOps); - state.liveExternalUseCache[value] = live; - return live; -} - -std::optional getConstantFirstSliceOffset(tensor::ExtractSliceOp extract) { - if (extract.getMixedOffsets().empty()) - return std::nullopt; - - OpFoldResult offset = extract.getMixedOffsets().front(); - if (auto attr = dyn_cast(offset)) { - auto intAttr = dyn_cast(attr); - if (!intAttr || intAttr.getInt() < 0) - return std::nullopt; - return static_cast(intAttr.getInt()); - } - - auto value = cast(offset); - if (auto constantIndex = value.getDefiningOp()) { - if (constantIndex.value() < 0) - return std::nullopt; - return static_cast(constantIndex.value()); - } - - APInt constantValue; - if (matchPattern(value, m_ConstantInt(&constantValue))) { - if (constantValue.isNegative()) - return std::nullopt; - return static_cast(constantValue.getZExtValue()); - } - - return std::nullopt; -} - -ProducerKey -getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) { - return { - {batch.getOperation(), laneStart, laneCount}, - resultIndex - }; -} - -ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) { - return getBatchLaneProducerKey(batch, 0, static_cast(batch.getLaneCount()), resultIndex); -} - -bool isWholeBatchProducerKey(ProducerKey key) { - auto batch = dyn_cast_or_null(key.instance.op); - return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 - && key.instance.laneCount == static_cast(batch.getLaneCount()); -} - -std::optional getContiguousProducerRangeForKeys(ArrayRef keys) { - if (keys.empty()) - return std::nullopt; - - ProducerKey first = keys.front(); - auto batch = dyn_cast_or_null(first.instance.op); - if (!batch) - return std::nullopt; - - SmallVector sorted(keys.begin(), keys.end()); - llvm::sort(sorted, [](ProducerKey lhs, ProducerKey rhs) { - return std::tie(lhs.instance.laneStart, lhs.instance.laneCount, lhs.resultIndex) - < std::tie(rhs.instance.laneStart, rhs.instance.laneCount, rhs.resultIndex); - }); - - uint32_t laneStart = sorted.front().instance.laneStart; - uint32_t nextLane = laneStart; - for (ProducerKey key : sorted) { - if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) - return std::nullopt; - if (key.instance.laneStart != nextLane) - return std::nullopt; - nextLane += key.instance.laneCount; - } - - uint32_t laneCount = nextLane - laneStart; - if (laneStart + laneCount > static_cast(batch.getLaneCount())) - return std::nullopt; - - return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); -} - -std::optional getPhysicallyContiguousProducerRangeForKeys(ArrayRef keys) { - if (keys.empty()) - return std::nullopt; - - ProducerKey first = keys.front(); - auto batch = dyn_cast_or_null(first.instance.op); - if (!batch || first.instance.laneCount == 0) - return std::nullopt; - - uint32_t laneStart = first.instance.laneStart; - uint32_t nextLane = laneStart; - for (ProducerKey key : keys) { - if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) - return std::nullopt; - if (key.instance.laneStart != nextLane) - return std::nullopt; - nextLane += key.instance.laneCount; - } - - uint32_t laneCount = nextLane - laneStart; - if (laneStart + laneCount > static_cast(batch.getLaneCount())) - return std::nullopt; - - return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); -} - -WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) { - return {sourceOp, resultIndex, classId}; -} - -WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) { - return makeWholeBatchAssemblyLookupKey(key.instance.op, key.resultIndex, classId); -} - -FailureOr getPackedBatchTensorType(Type laneType, size_t laneCount) { - auto tensorType = dyn_cast(laneType); - if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) - return failure(); - - SmallVector shape(tensorType.getShape()); - shape[0] *= static_cast(laneCount); - return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); -} - -LogicalResult verifyPackableFragmentType(Operation* anchor, Type fragmentType, size_t count, StringRef message) { - if (failed(getPackedBatchTensorType(fragmentType, count))) - return anchor->emitError(message); - return success(); -} - -ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, ComputeInstance logicalInstance) { - auto it = state.logicalInstanceToScheduledChunk.find(logicalInstance); - if (it != state.logicalInstanceToScheduledChunk.end()) - return it->second; - return logicalInstance; -} - -SmallVector -collectProducerKeysForDestinations(Value value, std::optional logicalConsumer = std::nullopt) { - // Destination collection works in the materializer's logical one-lane key domain. - // Whole-batch resultful producers are expanded into per-lane producer keys here. - SmallVector keys; - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) - return keys; - - while (auto extract = dyn_cast(definingOp)) { - Value source = extract.getSource(); - auto batch = dyn_cast_or_null(source.getDefiningOp()); - if (batch && batch.getNumResults() != 0) { - auto result = dyn_cast(source); - if (!result) - return {}; - - if (std::optional lane = getConstantFirstSliceOffset(extract)) { - if (*lane >= static_cast(batch.getLaneCount())) - return {}; - keys.push_back(getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber())); - return keys; - } - - return {}; - } - - value = source; - definingOp = value.getDefiningOp(); - if (!definingOp) - return {}; - } - - if (auto compute = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return {}; - keys.push_back({ - {compute.getOperation(), 0, 1}, - result.getResultNumber() - }); - return keys; - } - - if (auto batch = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return {}; - - if (batch.getNumResults() != 0) { - for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) - keys.push_back(getBatchLaneProducerKey(batch, lane, 1, result.getResultNumber())); - return keys; - } - - ComputeInstance chunk = getBatchChunkForLane(batch, result.getResultNumber()); - keys.push_back({chunk, static_cast(result.getResultNumber() - chunk.laneStart)}); - return keys; - } - - return keys; -} - -std::optional getInputRequestProducerKey(Value value, - std::optional logicalConsumer = std::nullopt) { - // Input resolution may request a whole-batch key for scalar consumers that read - // a complete resultful compute_batch value. - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) - return std::nullopt; - - while (auto extract = dyn_cast(definingOp)) { - Value source = extract.getSource(); - auto batch = dyn_cast_or_null(source.getDefiningOp()); - if (batch && batch.getNumResults() != 0) { - auto result = dyn_cast(source); - if (!result) - return std::nullopt; - - if (std::optional lane = getConstantFirstSliceOffset(extract)) - return getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber()); - - return std::nullopt; - } - - value = source; - definingOp = value.getDefiningOp(); - if (!definingOp) - return std::nullopt; - } - - if (auto compute = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return std::nullopt; - return ProducerKey { - {compute.getOperation(), 0, 1}, - result.getResultNumber() - }; - } - - if (auto batch = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return std::nullopt; - - if (batch.getNumResults() != 0) - return getWholeBatchProducerKey(batch, result.getResultNumber()); - - return ProducerKey {getBatchChunkForLane(batch, result.getResultNumber()), 0}; - } - - return std::nullopt; -} - -class CpuUnionFind { -public: - void insert(CpuId cpu) { parent.try_emplace(cpu, cpu); } - - CpuId find(CpuId cpu) { - insert(cpu); - CpuId p = parent.lookup(cpu); - if (p == cpu) - return cpu; - CpuId root = find(p); - parent[cpu] = root; - return root; - } - - void unite(CpuId lhs, CpuId rhs) { - CpuId lhsRoot = find(lhs); - CpuId rhsRoot = find(rhs); - if (lhsRoot == rhsRoot) - return; - if (rhsRoot < lhsRoot) - std::swap(lhsRoot, rhsRoot); - parent[rhsRoot] = lhsRoot; - } - -private: - DenseMap parent; -}; - -LogicalResult buildMaterializationWorkStreams(MaterializerState& state) { - DenseMap> scheduledInstancesByCpu; - for (const auto& [instance, cpu] : state.schedule.computeToCpuMap) { - state.oldComputeOps.insert(instance.op); - scheduledInstancesByCpu[cpu].push_back(instance); - state.logicalInstancesByCpu.try_emplace(cpu); - } - - for (auto& [cpu, scheduledInstances] : scheduledInstancesByCpu) { - llvm::sort(scheduledInstances, [&](const ComputeInstance& lhs, const ComputeInstance& rhs) { - auto lhsIt = state.schedule.computeToCpuSlotMap.find(lhs); - auto rhsIt = state.schedule.computeToCpuSlotMap.find(rhs); - assert(lhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); - assert(rhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); - return lhsIt->second < rhsIt->second; - }); - - SmallVector& logicalInstances = state.logicalInstancesByCpu[cpu]; - SlotId logicalSlot = 0; - for (const ComputeInstance& instance : scheduledInstances) { - LogicalSlotRange range {logicalSlot, 1}; - if (isa(instance.op)) - range.count = instance.laneCount; - - state.scheduledInstanceToLogicalSlots[instance] = range; - - if (isa(instance.op)) { - for (uint32_t localLane = 0; localLane < instance.laneCount; ++localLane, ++logicalSlot) { - uint32_t logicalLane = instance.laneStart + localLane; - ComputeInstance logicalInstance {instance.op, logicalLane, 1}; - logicalInstances.push_back(logicalInstance); - state.logicalInstanceToScheduledChunk[logicalInstance] = instance; - } - continue; - } - - logicalInstances.push_back(instance); - ++logicalSlot; - } - } - - return success(); -} - -LogicalResult buildMaterializationClassesFromScheduleEquivalence(MaterializerState& state) { - DenseSet usedCpus; - for (const auto& entry : state.schedule.cpuToLastComputeMap) - usedCpus.insert(entry.first); - for (const auto& entry : state.schedule.computeToCpuMap) - usedCpus.insert(entry.second); - - CpuUnionFind unionFind; - for (CpuId cpu : usedCpus) - unionFind.insert(cpu); - - for (const auto& [cpu, equivalentCpus] : state.schedule.equivalentClass) { - if (!usedCpus.contains(cpu)) - continue; - for (CpuId equivalentCpu : equivalentCpus) - if (usedCpus.contains(equivalentCpu)) - unionFind.unite(cpu, equivalentCpu); - } - - DenseMap> groupsByRoot; - for (CpuId cpu : usedCpus) - groupsByRoot[unionFind.find(cpu)].push_back(cpu); - - SmallVector roots; - roots.reserve(groupsByRoot.size()); - for (const auto& entry : groupsByRoot) - roots.push_back(entry.first); - llvm::sort(roots); - - state.classes.reserve(roots.size()); - for (CpuId root : roots) { - MaterializedClass materializedClass; - materializedClass.id = state.classes.size(); - materializedClass.cpus = groupsByRoot.lookup(root); - llvm::sort(materializedClass.cpus); - materializedClass.isBatch = materializedClass.cpus.size() > 1; - for (auto [lane, cpu] : llvm::enumerate(materializedClass.cpus)) { - materializedClass.cpuToLane[cpu] = static_cast(lane); - state.cpuToClass[cpu] = materializedClass.id; - } - state.classes.push_back(std::move(materializedClass)); - } - - return success(); -} - -LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState& state) { - for (const MaterializedClass& materializedClass : state.classes) { - if (materializedClass.cpus.empty()) - continue; - - auto referenceIt = state.logicalInstancesByCpu.find(materializedClass.cpus.front()); - if (referenceIt == state.logicalInstancesByCpu.end()) - return state.func.emitError("missing logical stream for materialized class reference CPU"); - - ArrayRef referenceStream(referenceIt->second); - for (CpuId cpu : materializedClass.cpus) { - auto streamIt = state.logicalInstancesByCpu.find(cpu); - if (streamIt == state.logicalInstancesByCpu.end()) - return state.func.emitError("missing logical stream for materialized class CPU"); - - ArrayRef stream(streamIt->second); - if (stream.size() != referenceStream.size()) - return state.func.emitError("materialized class CPUs have mismatched logical stream lengths"); - - for (auto [slot, zipped] : llvm::enumerate(llvm::zip(referenceStream, stream))) { - const ComputeInstance& referenceInstance = std::get<0>(zipped); - const ComputeInstance& currentInstance = std::get<1>(zipped); - if (referenceInstance.op != currentInstance.op) - return state.func.emitError("materialized class logical slot source op mismatch"); - if (isa(referenceInstance.op) != isa(currentInstance.op)) - return state.func.emitError("materialized class logical slot batch/scalar mismatch"); - (void) slot; - } - } - } - - return success(); -} - -LogicalResult forEachLogicalConsumerInMaterializationOrder( - MaterializerState& state, - llvm::function_ref - callback) { - for (const ComputeInstance& scheduledInstance : state.schedule.dominanceOrderCompute) { - auto cpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); - if (cpuIt == state.schedule.computeToCpuMap.end()) - return scheduledInstance.op->emitError("missing CPU assignment for scheduled logical-slot iteration"); - - auto rangeIt = state.scheduledInstanceToLogicalSlots.find(scheduledInstance); - if (rangeIt == state.scheduledInstanceToLogicalSlots.end()) - return scheduledInstance.op->emitError("missing logical slot range for scheduled logical-slot iteration"); - - CpuId cpu = cpuIt->second; - ClassId classId = state.cpuToClass.lookup(cpu); - LogicalSlotRange range = rangeIt->second; - auto streamIt = state.logicalInstancesByCpu.find(cpu); - if (streamIt == state.logicalInstancesByCpu.end()) - return scheduledInstance.op->emitError("missing logical stream for CPU"); - for (SlotId logicalSlot = range.start; logicalSlot < range.start + range.count; ++logicalSlot) { - if (logicalSlot >= streamIt->second.size()) - return scheduledInstance.op->emitError("missing logical slot materialization instance"); - if (failed(callback(cpu, classId, scheduledInstance, streamIt->second[logicalSlot], logicalSlot))) - return failure(); - } - } - - return success(); -} - -bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps); - -LogicalResult collectHostOutputs(MaterializerState& state) { - DenseSet seenOutputs; - SmallVector orderedOutputs; - DenseMap preferredOwners; - - for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { - auto cpuIt = state.schedule.computeToCpuMap.find(instance); - if (cpuIt == state.schedule.computeToCpuMap.end()) - return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); - - ClassId classId = state.cpuToClass.lookup(cpuIt->second); - MaterializedClass& materializedClass = state.classes[classId]; - for (Value output : getComputeInstanceOutputValuesCached(state, instance)) { - if (!hasLiveExternalUseCached(state, output)) - continue; - - if (seenOutputs.insert(output).second) { - orderedOutputs.push_back(output); - preferredOwners[output] = classId; - continue; - } - - auto batch = dyn_cast_or_null(output.getDefiningOp()); - if (!batch || batch.getNumResults() == 0) - continue; - - ClassId currentOwner = preferredOwners.lookup(output); - bool terminalHost = isTerminalHostBatchOutput(output, state.oldComputeOps); - if (terminalHost) { - // Terminal resultful batch outputs are still published through scalar - // host-output slots unless the materialized batch class owns the output - // directly. Selecting an arbitrary batch class as the host owner would - // require a projection-aware batch publication path, which the - // materializer does not currently implement. - if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) - preferredOwners[output] = classId; - continue; - } - - if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) - preferredOwners[output] = classId; - } - } - - for (MaterializedClass& materializedClass : state.classes) { - materializedClass.hostOutputs.clear(); - materializedClass.hostOutputToResultIndex.clear(); - } - state.hostOutputOwners.clear(); - - for (Value output : orderedOutputs) { - ClassId ownerClassId = preferredOwners.lookup(output); - MaterializedClass& ownerClass = state.classes[ownerClassId]; - ownerClass.hostOutputToResultIndex[output] = ownerClass.hostOutputs.size(); - ownerClass.hostOutputs.push_back(output); - state.hostOutputOwners[output] = ownerClassId; - } - - return success(); -} - -LogicalResult createEmptyMaterializedOps(MaterializerState& state) { - Location loc = state.func.getLoc(); - Block& funcBlock = state.func.getBody().front(); - - Operation* firstOldCompute = nullptr; - for (Operation& op : funcBlock) { - if (state.oldComputeOps.contains(&op)) { - firstOldCompute = &op; - break; - } - } - - if (firstOldCompute) - state.rewriter.setInsertionPoint(firstOldCompute); - else - state.rewriter.setInsertionPointToStart(&funcBlock); - - for (MaterializedClass& materializedClass : state.classes) { - SmallVector resultTypes; - resultTypes.reserve(materializedClass.hostOutputs.size()); - for (Value output : materializedClass.hostOutputs) - resultTypes.push_back(output.getType()); - - if (!materializedClass.isBatch) { - auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); - compute.getProperties().setOperandSegmentSizes({0, 0}); - auto coreIdAttr = - pim::getCheckedI32Attr(state.rewriter, state.func, materializedClass.cpus.front(), "materialized core id"); - if (failed(coreIdAttr)) - return failure(); - compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr); - Block* body = state.rewriter.createBlock(&compute.getBody()); - state.rewriter.setInsertionPointToEnd(body); - SmallVector placeholderOutputs; - placeholderOutputs.reserve(resultTypes.size()); - for (Type resultType : resultTypes) { - auto tensorType = dyn_cast(resultType); - if (!tensorType || !tensorType.hasStaticShape()) { - compute.emitOpError("host-facing materialized compute results must be static ranked tensors"); - return failure(); - } - placeholderOutputs.push_back( - tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult()); - } - SpatYieldOp::create(state.rewriter, loc, ValueRange(placeholderOutputs)); - materializedClass.op = compute.getOperation(); - materializedClass.body = body; - state.rewriter.setInsertionPointAfter(compute.getOperation()); - continue; - } - - auto batchLaneCountAttr = pim::getCheckedI32Attr( - state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count"); - if (failed(batchLaneCountAttr)) - return failure(); - auto batch = SpatScheduledComputeBatch::create(state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); - batch.getProperties().setOperandSegmentSizes({0, 0}); - auto coreIds = getCheckedCoreIds(state.func, materializedClass.cpus, "materialized batch core id"); - if (failed(coreIds)) - return failure(); - batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(*coreIds)); - - SmallVector blockArgTypes {state.rewriter.getIndexType()}; - SmallVector blockArgLocs {loc}; - llvm::append_range(blockArgTypes, resultTypes); - blockArgLocs.append(resultTypes.size(), loc); - Block* body = - state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - state.rewriter.setInsertionPointToEnd(body); - if (resultTypes.empty()) - SpatYieldOp::create(state.rewriter, loc, ValueRange {}); - else - SpatInParallelOp::create(state.rewriter, loc); - materializedClass.op = batch.getOperation(); - materializedClass.body = body; - state.rewriter.setInsertionPointAfter(batch.getOperation()); - } - - return success(); -} - -BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { - auto it = materializedClass.weightArgs.find(weight); - if (it != materializedClass.weightArgs.end()) - return it->second; - - unsigned weightIndex = materializedClass.weights.size(); - materializedClass.weights.push_back(weight); - - if (auto compute = dyn_cast(materializedClass.op)) { - auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc()); - assert(arg && "expected compute body while inserting a weight"); - materializedClass.weightArgs[weight] = std::get<1>(*arg); - return std::get<1>(*arg); - } - - auto batch = cast(materializedClass.op); - auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc()); - assert(arg && "expected compute_batch body while inserting a weight argument"); - materializedClass.weightArgs[weight] = std::get<1>(*arg); - return std::get<1>(*arg); -} - -BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) { - auto it = materializedClass.inputArgs.find(input); - if (it != materializedClass.inputArgs.end()) - return it->second; - - materializedClass.inputs.push_back(input); - if (auto compute = dyn_cast(materializedClass.op)) { - auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); - assert(arg && "expected compute body while inserting an input"); - materializedClass.inputArgs[input] = std::get<1>(*arg); - return std::get<1>(*arg); - } - if (auto compute = dyn_cast(materializedClass.op)) { - auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); - assert(arg && "expected compute_batch body while inserting an input argument"); - materializedClass.inputArgs[input] = std::get<1>(*arg); - return std::get<1>(*arg); - } - llvm_unreachable("Cannot reach here"); -} - -Region* getParentRegion(Value value) { - if (auto blockArg = dyn_cast(value)) - return blockArg.getOwner()->getParent(); - if (Operation* definingOp = value.getDefiningOp()) - return definingOp->getParentRegion(); - return nullptr; -} - -bool isDefinedInsideRegion(Value value, Region& region) { - Region* parentRegion = getParentRegion(value); - return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); -} - -Operation* getEnclosingSpatialComputeLikeOp(Value value) { - Block* block = nullptr; - if (auto blockArg = dyn_cast(value)) - block = blockArg.getOwner(); - else if (Operation* definingOp = value.getDefiningOp()) - block = definingOp->getBlock(); - - if (!block) - return nullptr; - - for (Operation* current = block->getParentOp(); current; current = current->getParentOp()) - if (isa(current)) - return current; - return nullptr; -} - -bool isTensorValueLocalToMaterializedClass(Value value, const MaterializedClass& targetClass) { - if (!isa(value.getType())) - return true; - if (isConstantLike(value)) - return true; - - Region& targetRegion = *targetClass.body->getParent(); - return isDefinedInsideRegion(value, targetRegion); -} - -bool isTensorValueDefinedInDifferentMaterializedClass(Value value, const MaterializedClass& targetClass) { - if (!isa(value.getType()) || isTensorValueLocalToMaterializedClass(value, targetClass)) - return false; - - Operation* owner = getEnclosingSpatialComputeLikeOp(value); - return owner && owner != targetClass.op; -} - -std::optional getRegionIndexInParentOp(Region* region) { - Operation* parent = region ? region->getParentOp() : nullptr; - if (!parent) - return std::nullopt; - - for (auto [index, candidate] : llvm::enumerate(parent->getRegions())) - if (&candidate == region) - return static_cast(index); - return std::nullopt; -} - -std::optional getBlockIndexInRegion(Block* block) { - Region* region = block ? block->getParent() : nullptr; - if (!region) - return std::nullopt; - - for (auto [index, candidate] : llvm::enumerate(region->getBlocks())) - if (&candidate == block) - return static_cast(index); - return std::nullopt; -} - -Block* getBlockByIndex(Region& region, unsigned blockIndex) { - unsigned index = 0; - for (Block& block : region) { - if (index == blockIndex) - return █ - ++index; - } - return nullptr; -} - -static bool isValueLegalInMaterializedClassBody(Value value, const MaterializedClass& targetClass) { - if (isConstantLike(value)) - return true; - - Region& targetRegion = *targetClass.body->getParent(); - return isDefinedInsideRegion(value, targetRegion); -} - -std::string stringifyOperationForMaterializerDebug(Operation* op) { - if (!op) - return std::string(""); - std::string storage; - llvm::raw_string_ostream stream(storage); - op->print(stream); - return storage; -} - -std::string stringifyValueForMaterializerDebug(Value value) { - std::string storage; - llvm::raw_string_ostream stream(storage); - value.print(stream); - return storage; -} - -std::string truncateMaterializerDebugString(std::string text, size_t limit = 1200) { - for (char& ch : text) - if (ch == '\n' || ch == '\r' || ch == '\t') - ch = ' '; - - if (text.size() <= limit) - return text; - text.resize(limit); - text += "..."; - return text; -} - -std::string formatMaterializerOperandListInline(Operation* op, const MaterializedClass& targetClass) { - if (!op) - return std::string(""); - - std::string storage; - llvm::raw_string_ostream stream(storage); - for (OpOperand& operand : op->getOpOperands()) { - if (operand.getOperandNumber() != 0) - stream << " | "; - Value value = operand.get(); - stream << "operand#" << operand.getOperandNumber() << " type=" << value.getType() - << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) - << " value=" << stringifyValueForMaterializerDebug(value); - if (auto blockArg = dyn_cast(value)) { - stream << " blockArg#" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - stream << " ownerOp='" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { - stream << " definingOp='" << definingOp->getName() << "'"; - } - } - return truncateMaterializerDebugString(stream.str()); -} - -std::string formatMaterializerParentChainInline(Operation* op) { - if (!op) - return std::string(""); - - std::string storage; - llvm::raw_string_ostream stream(storage); - unsigned depth = 0; - for (Operation* current = op; current; current = current->getParentOp()) { - if (depth != 0) - stream << " <- "; - stream << "[" << depth++ << "]" << current->getName(); - } - return truncateMaterializerDebugString(stream.str()); -} - -void attachMaterializerOperationPrintNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { - if (!op) - return; - diagnostic.attachNote(op->getLoc()) << label << ":\n" << stringifyOperationForMaterializerDebug(op); -} - -void attachMaterializerParentChainNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { - if (!op) - return; - - std::string storage; - llvm::raw_string_ostream stream(storage); - unsigned depth = 0; - for (Operation* current = op; current; current = current->getParentOp()) - stream << " [" << depth++ << "] " << current->getName() << "\n"; - - diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); -} - -void attachMaterializerOperandListNote(InFlightDiagnostic& diagnostic, - Operation* op, - const MaterializedClass& targetClass, - StringRef label) { - if (!op) - return; - - std::string storage; - llvm::raw_string_ostream stream(storage); - for (OpOperand& operand : op->getOpOperands()) { - Value value = operand.get(); - stream << " operand#" << operand.getOperandNumber() << " type=" << value.getType() - << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) - << " value=" << stringifyValueForMaterializerDebug(value); - if (auto blockArg = dyn_cast(value)) { - stream << " blockArg#" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - stream << " ownerOp='" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { - stream << " definingOp='" << definingOp->getName() << "'"; - } - stream << "\n"; - } - - diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); -} - -void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value value, StringRef label) { - if (auto blockArg = dyn_cast(value)) { - if (Operation* owner = blockArg.getOwner()->getParentOp()) - diagnostic.attachNote(owner->getLoc()) - << label << " is block argument #" << blockArg.getArgNumber() << " of '" << owner->getName() - << "' with type " << blockArg.getType(); - else - diagnostic.attachNote(UnknownLoc::get(value.getContext())) - << label << " is a top-level block argument #" << blockArg.getArgNumber() - << " with type " << blockArg.getType(); - return; - } - - if (Operation* definingOp = value.getDefiningOp()) { - diagnostic.attachNote(definingOp->getLoc()) - << label << " is defined by '" << definingOp->getName() << "' with result type " << value.getType(); - return; - } - - diagnostic.attachNote(UnknownLoc::get(value.getContext())) - << label << " has no defining operation and is not a block argument, type " << value.getType(); -} - -void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) { - Block& body = *targetClass.body; - diagnostic.attachNote(targetClass.op->getLoc()) - << "RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName() - << "' body has " << body.getNumArguments() << " block arguments and " - << std::distance(body.begin(), body.end()) << " top-level operations"; -} - -FailureOr rematerializeIndexValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Location loc, - IRMapping* mapper = nullptr); - -FailureOr rematerializeIndexOpFoldResultInClass(MaterializerState& state, - MaterializedClass& targetClass, - OpFoldResult value, - Location loc, - IRMapping* mapper = nullptr) { - if (auto attr = dyn_cast(value)) - return OpFoldResult(attr); - - FailureOr rematerialized = rematerializeIndexValueInClass(state, targetClass, cast(value), loc, mapper); - if (failed(rematerialized)) - return failure(); - return OpFoldResult(*rematerialized); -} - -FailureOr rematerializeIndexValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Location loc, - IRMapping* mapper) { - Value originalValue = value; - bool mapperHadOriginalValue = false; - Value mappedOriginalValue; - - if (mapper && mapper->contains(value)) { - mapperHadOriginalValue = true; - Value mapped = mapper->lookup(value); - mappedOriginalValue = mapped; - if (isValueLegalInMaterializedClassBody(mapped, targetClass) || isConstantLike(mapped)) - return mapped; - value = mapped; - } - - if (isValueLegalInMaterializedClassBody(value, targetClass)) - return value; - - if (!value.getType().isIndex()) - return targetClass.op->emitError("cannot rematerialize non-index external value in materialized class body") - << " type=" << value.getType(); - - if (auto constantIndex = value.getDefiningOp()) - return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantIndex.value()); - - APInt constantValue; - if (matchPattern(value, m_ConstantInt(&constantValue))) { - if (!constantValue.isSignedIntN(64)) - return targetClass.op->emitError("cannot rematerialize out-of-range index constant") - << " value=" << llvm::toString(constantValue, 10, /*Signed=*/true); - return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantValue.getSExtValue()); - } - - if (auto affineApply = value.getDefiningOp()) { - SmallVector remappedOperands; - remappedOperands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, operand, loc, mapper); - if (failed(remapped)) - return failure(); - remappedOperands.push_back(*remapped); - } - return createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func); - } - - if (auto addOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, addOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, addOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::AddIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto subOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, subOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, subOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::SubIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto mulOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::MulIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto divOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, divOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, divOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::DivUIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto extractOp = value.getDefiningOp()) { - SmallVector remappedIndices; - remappedIndices.reserve(extractOp.getIndices().size()); - for (Value index : extractOp.getIndices()) { - FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, index, loc, mapper); - if (failed(remapped)) - return failure(); - remappedIndices.push_back(*remapped); - } - - Value tensor = extractOp.getTensor(); - if (!isConstantLike(tensor) && !isValueLegalInMaterializedClassBody(tensor, targetClass)) - return targetClass.op->emitError("cannot rematerialize indexed table lookup from external non-constant tensor") - << " tensorType=" << tensor.getType(); - return tensor::ExtractOp::create(state.rewriter, loc, tensor, remappedIndices).getResult(); - } - - if (auto blockArg = dyn_cast(value)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body"); - diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType() - << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; - if (Operation* owner = blockArg.getOwner()->getParentOp()) { - diagnostic << " ownerOp='" << owner->getName() << "'"; - diagnostic << " ownerIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(owner)) << "\""; - diagnostic << " ownerChain=\"" << formatMaterializerParentChainInline(owner) << "\""; - } - diagnostic << " targetIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(targetClass.op)) << "\""; - if (mapper) { - diagnostic << " mapperPresent=1 mapperHadOriginal=" << (mapperHadOriginalValue ? 1 : 0); - if (mapperHadOriginalValue) - diagnostic << " mappedType=" << mappedOriginalValue.getType(); - } else { - diagnostic << " mapperPresent=0"; - } - attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); - if (value != originalValue) - attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); - if (mapperHadOriginalValue && mappedOriginalValue != value) - attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value"); - if (Operation* owner = blockArg.getOwner()->getParentOp()) { - attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op"); - attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain"); - } - attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return failure(); - } - - InFlightDiagnostic diagnostic = - targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body"); - diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='" - << targetClass.op->getName() << "'"; - attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); - if (value != originalValue) - attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return failure(); -} - -InFlightDiagnostic emitNonLocalMaterializedClassValueDiagnostic(Operation* anchor, - const MaterializedClass& targetClass, - StringRef context, - Value value, - std::optional producer = std::nullopt) { - InFlightDiagnostic diagnostic = anchor->emitError(context) << " into target class " << targetClass.id; - - if (producer) { - diagnostic << " from '" << producer->instance.op->getName() << "' resultIndex=" << producer->resultIndex - << " laneStart=" << producer->instance.laneStart << " laneCount=" << producer->instance.laneCount; - } else if (auto result = dyn_cast(value)) { - diagnostic << " from '" << result.getOwner()->getName() << "' resultIndex=" << result.getResultNumber(); - } else if (auto blockArg = dyn_cast(value)) { - diagnostic << " from block argument #" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - diagnostic << " of '" << owner->getName() << "'"; - } - - if (Operation* definingOp = value.getDefiningOp()) - diagnostic.attachNote(definingOp->getLoc()) << "offending tensor producer is '" << definingOp->getName() << "'"; - return diagnostic; -} - -FailureOr rematerializeTensorValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - IRMapping* mapper) { - auto extractSlice = value.getDefiningOp(); - if (extractSlice) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, targetClass, extractSlice.getSource(), anchor, context, std::nullopt, mapper); - if (failed(localizedSource)) - return failure(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(extractSlice.getMixedOffsets().size()); - sizes.reserve(extractSlice.getMixedSizes().size()); - strides.reserve(extractSlice.getMixedStrides().size()); - - for (OpFoldResult offset : extractSlice.getMixedOffsets()) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, offset, anchor->getLoc(), mapper); - if (failed(localized)) - return failure(); - offsets.push_back(*localized); - } - for (OpFoldResult size : extractSlice.getMixedSizes()) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, size, anchor->getLoc(), mapper); - if (failed(localized)) - return failure(); - sizes.push_back(*localized); - } - for (OpFoldResult stride : extractSlice.getMixedStrides()) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, stride, anchor->getLoc(), mapper); - if (failed(localized)) - return failure(); - strides.push_back(*localized); - } - - return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides) - .getResult(); - } - - if (auto collapseShape = value.getDefiningOp()) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, targetClass, collapseShape.getSrc(), anchor, context, std::nullopt, mapper); - if (failed(localizedSource)) - return failure(); - return tensor::CollapseShapeOp::create( - state.rewriter, anchor->getLoc(), *localizedSource, collapseShape.getReassociationIndices()) - .getResult(); - } - - return failure(); -} - -FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - std::optional producer, - IRMapping* mapper) { - if (mapper && mapper->contains(value)) - value = mapper->lookup(value); - - if (!isa(value.getType()) || isConstantLike(value) || isTensorValueLocalToMaterializedClass(value, targetClass)) - return value; - - if (value.getDefiningOp() || value.getDefiningOp()) { - FailureOr rematerialized = rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); - if (failed(rematerialized)) - return failure(); - return *rematerialized; - } - - if (isTensorValueDefinedInDifferentMaterializedClass(value, targetClass)) { - emitNonLocalMaterializedClassValueDiagnostic(anchor, targetClass, context, value, producer); - return failure(); - } - - return appendInput(state, targetClass, value); -} - -std::optional mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass, - Operation* anchor, - BlockArgument externalArg) { - Block* sourceBlock = externalArg.getOwner(); - Region* sourceRegion = sourceBlock ? sourceBlock->getParent() : nullptr; - Operation* sourceParent = sourceRegion ? sourceRegion->getParentOp() : nullptr; - if (!sourceParent || !anchor) - return std::nullopt; - - std::optional sourceRegionIndex = getRegionIndexInParentOp(sourceRegion); - std::optional sourceBlockIndex = getBlockIndexInRegion(sourceBlock); - if (!sourceRegionIndex || !sourceBlockIndex) - return std::nullopt; - - for (Operation* current = anchor->getParentOp(); current && current != targetClass.op; - current = current->getParentOp()) { - if (current->getName() != sourceParent->getName()) - continue; - if (current->getNumRegions() <= *sourceRegionIndex) - continue; - - Region& localRegion = current->getRegion(*sourceRegionIndex); - Block* localBlock = getBlockByIndex(localRegion, *sourceBlockIndex); - if (!localBlock || localBlock->getNumArguments() <= externalArg.getArgNumber()) - continue; - - BlockArgument localArg = localBlock->getArgument(externalArg.getArgNumber()); - if (localArg.getType() != externalArg.getType()) - continue; - if (!isValueLegalInMaterializedClassBody(localArg, targetClass)) - continue; - return localArg; - } - - return std::nullopt; -} - -FailureOr localizeMaterializedClassOperand(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef tensorContext, - StringRef genericContext, - IRMapping* mapper) { - if (mapper && mapper->contains(value)) - value = mapper->lookup(value); - - if (auto blockArg = dyn_cast(value)) - if (std::optional localArg = mapExternalRegionBlockArgumentToLocalClone(targetClass, anchor, blockArg)) - return *localArg; - - if (isa(value.getType())) - return materializeTensorValueForMaterializedClassUse(state, targetClass, value, anchor, tensorContext, std::nullopt, mapper); - - if (isValueLegalInMaterializedClassBody(value, targetClass)) - return value; - - if (value.getType().isIndex()) - return rematerializeIndexValueInClass(state, targetClass, value, anchor->getLoc(), mapper); - - InFlightDiagnostic diagnostic = anchor->emitError(genericContext); - diagnostic << " type=" << value.getType(); - if (auto blockArg = dyn_cast(value)) { - diagnostic << " blockArg#" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - diagnostic.attachNote(owner->getLoc()) << "block argument belongs to '" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { - diagnostic.attachNote(definingOp->getLoc()) << "unsupported external operand producer is '" << definingOp->getName() - << "'"; - } - return failure(); -} - -// ----------------------------------------------------------------------------- -// Tensor packing helpers. -// ----------------------------------------------------------------------------- - -struct Dim0SliceParams { - SmallVector offsets; - SmallVector sizes; - SmallVector strides; -}; - -Dim0SliceParams -buildDim0SliceParams(OpBuilder& builder, RankedTensorType referenceType, OpFoldResult firstOffset, int64_t firstSize) { - Dim0SliceParams params; - params.offsets.reserve(referenceType.getRank()); - params.sizes.reserve(referenceType.getRank()); - params.strides.reserve(referenceType.getRank()); - - params.offsets.push_back(firstOffset); - params.sizes.push_back(builder.getIndexAttr(firstSize)); - params.strides.push_back(builder.getIndexAttr(1)); - - for (int64_t dim = 1; dim < referenceType.getRank(); ++dim) { - params.offsets.push_back(builder.getIndexAttr(0)); - params.sizes.push_back(builder.getIndexAttr(referenceType.getDimSize(dim))); - params.strides.push_back(builder.getIndexAttr(1)); - } - - return params; -} - -Value createDim0ExtractSlice( - MaterializerState& state, Location loc, Value source, OpFoldResult firstOffset, int64_t firstSize) { - auto sourceType = cast(source.getType()); - Dim0SliceParams params = buildDim0SliceParams(state.rewriter, sourceType, firstOffset, firstSize); - return tensor::ExtractSliceOp::create(state.rewriter, loc, source, params.offsets, params.sizes, params.strides) - .getResult(); -} - -FailureOr createDim0ExtractSliceInClass(MaterializerState& state, - MaterializedClass& targetClass, - Location loc, - Value source, - OpFoldResult firstOffset, - int64_t firstSize) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - source, - targetClass.op, - "createDim0ExtractSliceInClass tried to reuse a tensor from another materialized class"); - if (failed(localizedSource)) - return failure(); - FailureOr localizedOffset = - rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); - if (failed(localizedOffset)) - return failure(); - return createDim0ExtractSlice(state, loc, *localizedSource, *localizedOffset, firstSize); -} - -Value createStaticExtractSlice(MaterializerState& state, - Location loc, - Value source, - ArrayRef sliceOffsets, - ArrayRef resultShape) { - auto sourceType = cast(source.getType()); - assert(sliceOffsets.size() == static_cast(sourceType.getRank()) && "offset rank mismatch"); - assert(resultShape.size() == static_cast(sourceType.getRank()) && "result rank mismatch"); - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(sourceType.getRank()); - sizes.reserve(sourceType.getRank()); - strides.reserve(sourceType.getRank()); - - for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { - offsets.push_back(sliceOffsets[dim]); - sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim])); - strides.push_back(state.rewriter.getIndexAttr(1)); - } - - return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult(); -} - -FailureOr createStaticExtractSliceInClass(MaterializerState& state, - MaterializedClass& targetClass, - Location loc, - Value source, - ArrayRef sliceOffsets, - ArrayRef resultShape) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - source, - targetClass.op, - "createStaticExtractSliceInClass tried to reuse a tensor from another materialized class"); - if (failed(localizedSource)) - return failure(); - - SmallVector localizedOffsets; - localizedOffsets.reserve(sliceOffsets.size()); - for (OpFoldResult offset : sliceOffsets) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, offset, loc); - if (failed(localized)) - return failure(); - localizedOffsets.push_back(*localized); - } - return createStaticExtractSlice(state, loc, *localizedSource, localizedOffsets, resultShape); -} - -Value createIndexedIndexValue(MaterializerState& state, - Operation* anchor, - ArrayRef values, - Value index, - Location loc, - std::optional preferredPeriod = std::nullopt, - bool allowExhaustiveTiledSearch = true); - -FailureOr> buildProjectedFragmentOffsetsInClass(MaterializerState& state, - MaterializedClass& targetClass, - const ProjectedTransferDescriptor& descriptor, - Value flatFragmentIndex, - Location loc) { - FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, flatFragmentIndex, loc); - if (failed(localizedIndex)) - return failure(); - SmallVector fragmentOffsets; - fragmentOffsets.reserve(descriptor.layout.fragmentShape.size()); - for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) - fragmentOffsets.push_back(createIndexedIndexValue(state, - targetClass.op, - dimOffsets, - *localizedIndex, - loc, - static_cast(descriptor.layout.payloadFragmentCount), - /*allowExhaustiveTiledSearch=*/false)); - return fragmentOffsets; -} - -Value createDim0InsertSlice( - MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { - auto fragmentType = cast(fragment.getType()); - Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); - return tensor::InsertSliceOp::create( - state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides) - .getResult(); -} - -FailureOr createDim0InsertSliceInClass(MaterializerState& state, - MaterializedClass& targetClass, - Location loc, - Value fragment, - Value destination, - OpFoldResult firstOffset) { - FailureOr localizedFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fragment, - targetClass.op, - "createDim0InsertSliceInClass tried to reuse a fragment tensor from another materialized class"); - if (failed(localizedFragment)) - return failure(); - FailureOr localizedDestination = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - destination, - targetClass.op, - "createDim0InsertSliceInClass tried to reuse a destination tensor from another materialized class"); - if (failed(localizedDestination)) - return failure(); - FailureOr localizedOffset = - rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); - if (failed(localizedOffset)) - return failure(); - return createDim0InsertSlice(state, loc, *localizedFragment, *localizedDestination, *localizedOffset); -} - -void createDim0ParallelInsertSlice( - MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { - auto fragmentType = cast(fragment.getType()); - Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); - tensor::ParallelInsertSliceOp::create( - state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides); -} - -Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc) { - if (dim0Size == 1) - return index; - - Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size); - return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); -} - -FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value index, - int64_t dim0Size, - Location loc) { - FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, index, loc); - if (failed(localizedIndex)) - return failure(); - if (dim0Size == 1) - return *localizedIndex; - - Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size); - return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult(); -} - -bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) { - return lhs.instance.op == rhs.instance.op && lhs.resultIndex == rhs.resultIndex; -} - -bool containsProducerKey(ProducerKey outer, ProducerKey inner) { - if (!sameProducerResult(outer, inner)) - return false; - if (!isa(outer.instance.op)) - return false; - if (outer.instance.laneCount == 0 || inner.instance.laneCount == 0) - return false; - - uint32_t outerStart = outer.instance.laneStart; - uint32_t outerEnd = outerStart + outer.instance.laneCount; - uint32_t innerStart = inner.instance.laneStart; - uint32_t innerEnd = innerStart + inner.instance.laneCount; - - return outerStart <= innerStart && innerEnd <= outerEnd; -} - -std::optional extractPackedProducerSlice(MaterializerState& state, - MaterializedClass& materializedClass, - ProducerKey packedKey, - Value packed, - ProducerKey requestedKey) { - if (!containsProducerKey(packedKey, requestedKey)) - return std::nullopt; - - auto packedType = dyn_cast(packed.getType()); - if (!packedType || !packedType.hasStaticShape() || packedType.getRank() == 0) - return std::nullopt; - - if (packedKey.instance.laneCount == 0) - return std::nullopt; - - int64_t packedRows = packedType.getDimSize(0); - if (packedRows % static_cast(packedKey.instance.laneCount) != 0) - return std::nullopt; - - int64_t rowsPerLane = packedRows / static_cast(packedKey.instance.laneCount); - int64_t rowOffset = - static_cast(requestedKey.instance.laneStart - packedKey.instance.laneStart) * rowsPerLane; - int64_t rowCount = static_cast(requestedKey.instance.laneCount) * rowsPerLane; - - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); - - Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); - FailureOr slice = - createDim0ExtractSliceInClass(state, materializedClass, materializedClass.op->getLoc(), packed, firstOffset, rowCount); - if (failed(slice)) - return std::nullopt; - return *slice; -} - -std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { - auto producerIt = exactValues.find(key); - if (producerIt == exactValues.end()) - return std::nullopt; - - auto valueIt = producerIt->second.find(classId); - if (valueIt == producerIt->second.end()) - return std::nullopt; - - return valueIt->second; -} - -FailureOr getPackedSliceForRunIndex(MaterializerState& state, - MaterializedClass& targetClass, - Value packed, - RankedTensorType fragmentType, - size_t index, - Location loc) { - int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); - Value firstOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset); - return createDim0ExtractSliceInClass(state, targetClass, loc, packed, firstOffset, fragmentType.getDimSize(0)); -} - -FailureOr createReceiveConcatLoop(MaterializerState& state, - MaterializedClass& targetClass, - RankedTensorType concatType, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc); - -using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; -using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; - -FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc); - -SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run); - -bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { - return run.kind == PackedScalarRunKind::DeferredLocalCompute; -} - -size_t getPackedScalarRunReceiveCount(const PackedScalarRunValue& run) { - size_t count = 0; - for (const PackedScalarRunSlot& slot : run.slots) - count += slot.keys.size(); - return count; -} - -LogicalResult validatePackedScalarRunMetadata(Operation* anchor, const PackedScalarRunValue& run) { - if (run.kind == PackedScalarRunKind::DeferredLocalCompute) - return success(); - - size_t receiveCount = getPackedScalarRunReceiveCount(run); - - if (receiveCount == 0) - return anchor->emitError("packed scalar run has no receives"); - - if (failed(run.messages.verify(anchor))) - return failure(); - - if (run.messages.size() != receiveCount) - return anchor->emitError("packed scalar run receive metadata count is inconsistent"); - - return success(); -} - -FailureOr materializePackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc) { - if (run.packed) - return run.packed; - - if (run.kind == PackedScalarRunKind::Materialized) - return targetClass.op->emitError("materialized packed scalar run has no packed value"); - - if (isDeferredLocalPackedScalarRun(run)) - return materializeDeferredLocalPackedScalarRunValue(state, targetClass, run, loc); - - if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) - return failure(); - - FailureOr fullPackedType = - getPackedBatchTensorType(run.fragmentType, getPackedScalarRunReceiveCount(run)); - if (failed(fullPackedType)) - return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); - - auto packed = createReceiveConcatLoop(state, targetClass, *fullPackedType, run.fragmentType, run.messages, loc); - if (failed(packed)) - return failure(); - run.packed = *packed; - return run.packed; -} - -std::optional AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { - for (PackedScalarRunValue& run : packedScalarRuns) { - if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) - continue; - - size_t flattenedIndexBase = 0; - for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { - std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); - if (contiguousKey && containsProducerKey(*contiguousKey, key)) { - FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); - if (failed(slotPackedType)) - return std::nullopt; - - MaterializedClass& materializedClass = state.classes[classId]; - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); - - FailureOr packed = - materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); - if (failed(packed)) - return std::nullopt; - FailureOr slotPacked = - getPackedSliceForRunIndex(state, materializedClass, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); - if (failed(slotPacked)) - return std::nullopt; - - if (*contiguousKey == key) { - record(key, classId, *slotPacked); - return *slotPacked; - } - - std::optional sliced = - extractPackedProducerSlice(state, materializedClass, *contiguousKey, *slotPacked, key); - if (!sliced) - return std::nullopt; - - record(key, classId, *sliced); - return *sliced; - } - - auto keyIt = llvm::find(slot.keys, key); - if (keyIt == slot.keys.end()) { - flattenedIndexBase += slot.keys.size(); - continue; - } - - MaterializedClass& materializedClass = state.classes[classId]; - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); - - FailureOr packed = - materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); - if (failed(packed)) - return std::nullopt; - size_t flattenedIndex = flattenedIndexBase + static_cast(std::distance(slot.keys.begin(), keyIt)); - FailureOr sliced = - getPackedSliceForRunIndex(state, materializedClass, *packed, run.fragmentType, flattenedIndex, (*packed).getLoc()); - if (failed(sliced)) - return std::nullopt; - record(key, classId, *sliced); - return *sliced; - } - } - - return std::nullopt; -} - -IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) { - for (IndexedBatchRunValue& run : indexedBatchRuns) { - if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) - continue; - for (const PackedScalarRunSlot& slot : run.slots) { - if (!llvm::is_contained(slot.keys, key)) - continue; - return &run; - } - } - return nullptr; -} - -std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { - - if (std::optional exact = lookupExact(key, classId)) { - return exact; - } - - if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) - return packedRunValue; - - MaterializedClass& materializedClass = state.classes[classId]; - - for (const auto& [candidateKey, classValues] : exactValues) { - if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key)) - continue; - - auto valueIt = classValues.find(classId); - if (valueIt == classValues.end()) - continue; - std::optional slice = - extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key); - if (!slice) - return std::nullopt; - - record(key, classId, *slice); - return *slice; - } - return std::nullopt; -} - -Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { - SmallVector elements; - elements.reserve(values.size()); - for (int64_t value : values) - elements.push_back(APInt(64, value)); - - auto type = RankedTensorType::get({static_cast(values.size())}, state.rewriter.getIndexType()); - auto attr = DenseIntElementsAttr::get(type, elements); - return getOrCreateConstant(state.constantFolder, anchor, attr, type); -} - -bool allEqual(ArrayRef values) { - assert(!values.empty() && "expected at least one value"); - for (int64_t value : values.drop_front()) - if (value != values.front()) - return false; - return true; -} - -struct IndexedIndexPattern { - int64_t base = 0; - int64_t step = 0; - int64_t period = 1; - int64_t innerStep = 0; - int64_t outerStep = 0; - bool isTiled = false; -}; - -bool matchAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { - assert(!values.empty() && "expected at least one value"); - - pattern.base = values.front(); - pattern.step = values.size() == 1 ? 0 : values[1] - values[0]; - pattern.isTiled = false; - - for (auto [index, value] : llvm::enumerate(values)) { - int64_t expected = pattern.base + pattern.step * static_cast(index); - if (value != expected) - return false; - } - - return true; -} - -bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern, int64_t period) { - assert(!values.empty() && "expected at least one value"); - if (period < 2 || period > static_cast(values.size() / 2)) - return false; - - int64_t base = values.front(); - int64_t innerStep = values[1] - values[0]; - int64_t outerStep = values[period] - values[0]; - - for (auto [index, value] : llvm::enumerate(values)) { - int64_t i = static_cast(index); - int64_t expected = base + outerStep * (i / period) + innerStep * (i % period); - if (value != expected) - return false; - } - - pattern.base = base; - pattern.period = period; - pattern.innerStep = innerStep; - pattern.outerStep = outerStep; - pattern.isTiled = true; - return true; -} - -bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { - assert(!values.empty() && "expected at least one value"); - - for (int64_t period = 2; period <= static_cast(values.size() / 2); ++period) - if (matchTiledAffineSequence(values, pattern, period)) - return true; - - return false; -} - -std::optional getIndexedIndexPattern(ArrayRef values, - std::optional preferredPeriod = std::nullopt, - bool allowExhaustiveTiledSearch = true) { - assert(!values.empty() && "expected at least one value"); - - IndexedIndexPattern pattern; - if (matchAffineSequence(values, pattern)) - return pattern; - if (preferredPeriod && matchTiledAffineSequence(values, pattern, *preferredPeriod)) - return pattern; - if (allowExhaustiveTiledSearch && values.size() <= 256 && matchTiledAffineSequence(values, pattern)) - return pattern; - - return std::nullopt; -} - -Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) { - MLIRContext* context = state.func.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - - AffineExpr expr; - if (!pattern.isTiled) { - expr = getAffineConstantExpr(pattern.base, context) + d0 * pattern.step; - } - else { - expr = getAffineConstantExpr(pattern.base, context) + d0.floorDiv(pattern.period) * pattern.outerStep - + (d0 % pattern.period) * pattern.innerStep; - } - - AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, state.func); -} - -Value createIndexedIndexValue(MaterializerState& state, - Operation* anchor, - ArrayRef values, - Value index, - Location loc, - std::optional preferredPeriod, - bool allowExhaustiveTiledSearch) { - assert(!values.empty() && "expected at least one indexed value"); - - if (allEqual(values)) { - return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); - } - - if (std::optional pattern = - getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch)) - return createAffineIndexValue(state, *pattern, index, loc); - Value table = createIndexTensorConstant(state, anchor, values); - return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); -} - -Value createIndexedIndexValue( - MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { - assert(!values.empty() && "expected at least one indexed value"); - - SmallVector widened; - widened.reserve(values.size()); - for (int32_t value : values) - widened.push_back(value); - - return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, std::nullopt, true); -} - -OpFoldResult createIndexedOrStaticIndex(MaterializerState& state, - Operation* anchor, - ArrayRef values, - Value index, - Location loc) { - assert(!values.empty() && "expected at least one indexed value"); - if (allEqual(values)) - return state.rewriter.getIndexAttr(values.front()); - return createIndexedIndexValue(state, anchor, values, index, loc); -} - -Value createIndexedChannelId( - MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { - return createIndexedIndexValue(state, anchor, ArrayRef(messages.channelIds), index, loc); -} - -Value createIndexedChannelId(MaterializerState& state, - Operation* anchor, - const MessageVector& messages, - Value index, - Location loc, - std::optional preferredPeriod) { - return createIndexedIndexValue( - state, anchor, ArrayRef(messages.channelIds), index, loc, preferredPeriod, true); -} - -Value createIndexedSourceCoreId( - MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { - return createIndexedIndexValue(state, anchor, ArrayRef(messages.sourceCoreIds), index, loc); -} - -Value createIndexedSourceCoreId(MaterializerState& state, - Operation* anchor, - const MessageVector& messages, - Value index, - Location loc, - std::optional preferredPeriod) { - SmallVector widened(messages.sourceCoreIds.begin(), messages.sourceCoreIds.end()); - return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); -} - -Value createIndexedTargetCoreId( - MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { - return createIndexedIndexValue(state, anchor, ArrayRef(messages.targetCoreIds), index, loc); -} - -Value createIndexedTargetCoreId(MaterializerState& state, - Operation* anchor, - const MessageVector& messages, - Value index, - Location loc, - std::optional preferredPeriod) { - SmallVector widened(messages.targetCoreIds.begin(), messages.targetCoreIds.end()); - return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); -} - -Value createLaneIndexedIndexValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef values, - Location loc) { - assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); - assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); - - auto batch = cast(materializedClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected compute_batch lane argument"); - - return createIndexedIndexValue(state, materializedClass.op, values, *laneArg, loc); -} - -Value createLaneIndexedIndexValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef values, - Location loc) { - assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); - assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); - - SmallVector widened; - widened.reserve(values.size()); - for (int32_t value : values) - widened.push_back(value); - - return createLaneIndexedIndexValue(state, materializedClass, ArrayRef(widened), loc); -} - -FailureOr remapProjectionIndexLike(MaterializerState& state, - Operation* anchor, - OpFoldResult value, - Value sourceLaneArg, - Value mappedLaneValue, - Location loc) { - if (auto attr = dyn_cast(value)) - return value; - - Value operand = cast(value); - if (operand == sourceLaneArg) - return OpFoldResult(mappedLaneValue); - - if (matchPattern(operand, m_Constant())) - return getAsOpFoldResult(operand); - - auto affineApply = operand.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return failure(); - - SmallVector remappedOperands; - remappedOperands.reserve(affineApply.getMapOperands().size()); - for (Value mapOperand : affineApply.getMapOperands()) { - FailureOr remapped = - remapProjectionIndexLike(state, anchor, OpFoldResult(mapOperand), sourceLaneArg, mappedLaneValue, loc); - if (failed(remapped)) - return failure(); - remappedOperands.push_back(getValueOrCreateConstantIndexOp(state.rewriter, loc, *remapped)); - } - - return getAsOpFoldResult( - createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func)); -} - -FailureOr createProjectionLaneValueForKeys(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Location loc) { - if (!sourceClass.isBatch) - return sourceClass.op->emitError("projection lane mapping expects a batch materialized class"); - - auto batch = cast(sourceClass.op); - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("missing lane argument for projected batch host publication"); - - if (keys.size() == 1) { - if (keys.front().instance.laneCount != 1) - return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); - return getOrCreateIndexConstant(state.constantFolder, sourceClass.op, keys.front().instance.laneStart); - } - - if (keys.size() != sourceClass.cpus.size()) - return batch.emitOpError("projected batch host publication expected one producer key per materialized batch lane"); - - SmallVector sourceLanes; - sourceLanes.reserve(keys.size()); - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); - sourceLanes.push_back(key.instance.laneStart); - } - - return createIndexedIndexValue(state, sourceClass.op, sourceLanes, *laneArg, loc, std::nullopt, true); -} - -FailureOr> -getPeerLogicalInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId logicalSlot) { - SmallVector peers; - peers.reserve(materializedClass.cpus.size()); - for (CpuId cpu : materializedClass.cpus) { - auto streamIt = state.logicalInstancesByCpu.find(cpu); - if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) - return failure(); - peers.push_back(streamIt->second[logicalSlot]); - } - return peers; -} - -Value createOriginalLaneValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef peers, - Location loc) { - assert(!peers.empty() && "expected at least one peer instance"); - if (!materializedClass.isBatch) - return getOrCreateIndexConstant(state.constantFolder, materializedClass.op, peers.front().laneStart); - - auto batch = cast(materializedClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected materialized compute_batch lane argument"); - - SmallVector laneValues; - laneValues.reserve(peers.size()); - for (const ComputeInstance& peer : peers) - laneValues.push_back(peer.laneStart); - - return createIndexedIndexValue(state, materializedClass.op, ArrayRef(laneValues), *laneArg, loc); -} - -bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) { - SmallVector worklist {value}; - DenseSet visited; - - while (!worklist.empty()) { - Value current = worklist.pop_back_val(); - if (!visited.insert(current).second) - continue; - - for (OpOperand& use : current.getUses()) { - Operation* owner = use.getOwner(); - if (isInsideOldCompute(owner, oldComputeOps)) - continue; - if (isa(owner)) { - for (Value result : owner->getResults()) - worklist.push_back(result); - continue; - } - return true; - } - } - - return false; -} - -bool hasRealComputeConsumer(Value value, const DenseSet& oldComputeOps) { - SmallVector worklist {value}; - DenseSet visited; - - while (!worklist.empty()) { - Value current = worklist.pop_back_val(); - if (!visited.insert(current).second) - continue; - - for (OpOperand& use : current.getUses()) { - Operation* owner = use.getOwner(); - if (isInsideOldCompute(owner, oldComputeOps)) - continue; - if (isa(owner)) { - for (Value result : owner->getResults()) - worklist.push_back(result); - continue; - } - if (isa(owner)) - continue; - return true; - } - } - - return false; -} - -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); - -bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps) { - auto batch = dyn_cast_or_null(output.getDefiningOp()); - if (!batch || batch.getNumResults() == 0) - return false; - if (!hasLiveExternalUse(output, oldComputeOps)) - return false; - return !hasRealComputeConsumer(output, oldComputeOps); -} - -bool isProjectedTerminalBatchHostOutput(Value output, const DenseSet& oldComputeOps) { - if (!isTerminalHostBatchOutput(output, oldComputeOps)) - return false; - - auto batch = dyn_cast_or_null(output.getDefiningOp()); - auto originalResult = dyn_cast(output); - if (!batch || !originalResult) - return false; - - FailureOr projection = - getBatchResultProjectionInsert(batch, originalResult.getResultNumber()); - if (failed(projection)) - return false; - - return projection->getSource().getType() != output.getType(); -} - -LogicalResult emitBatchToScalarDestinationDiagnostic(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value originalOutput) { - auto diag = sourceClass.op->emitError("resultful compute_batch output would enter batch-to-scalar class fanout"); - diag << " sourceClassId=" << sourceClass.id << " sourceKind=" << (sourceClass.isBatch ? "batch" : "scalar"); - diag << " liveExternalUse=" << (hasLiveExternalUseCached(state, originalOutput) ? "true" : "false"); - diag << " terminalHostBatch=" << (isTerminalHostBatchOutput(originalOutput, state.oldComputeOps) ? "true" : "false"); - diag << " originalDef=" - << (originalOutput.getDefiningOp() ? originalOutput.getDefiningOp()->getName().getStringRef() : StringRef("")); - - bool first = true; - diag << " destinationClasses=["; - auto destIt = state.producerDestClasses.find(keys.front()); - ArrayRef destinations = destIt == state.producerDestClasses.end() ? ArrayRef {} : ArrayRef(destIt->second); - for (ClassId classId : destinations) { - if (!first) - diag << ", "; - first = false; - const MaterializedClass& destClass = state.classes[classId]; - diag << classId << ":" << (destClass.isBatch ? "batch" : "scalar"); - } - diag << "]"; - - diag << " producerKeys=["; - first = true; - for (ProducerKey key : keys) { - if (!first) - diag << ", "; - first = false; - diag << key.instance.op->getName().getStringRef() << ":r" << key.resultIndex << ":laneStart=" << key.instance.laneStart - << ":laneCount=" << key.instance.laneCount; - } - diag << "]"; - return failure(); -} - -void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) { - SmallVector& destinations = state.producerDestClasses[key]; - if (!llvm::is_contained(destinations, classId)) - destinations.push_back(classId); -} - -void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet& oldComputeOps) { - SmallVector uses; - for (OpOperand& use : oldValue.getUses()) - uses.push_back(&use); - - for (OpOperand* use : uses) { - Operation* owner = use->getOwner(); - if (isInsideOldCompute(owner, oldComputeOps)) - continue; - use->set(replacement); - } -} - -LogicalResult collectProducerDestinations(MaterializerState& state) { - return forEachLogicalConsumerInMaterializationOrder( - state, - [&](CpuId, ClassId targetClass, ComputeInstance scheduledConsumer, ComputeInstance logicalConsumer, SlotId) - -> LogicalResult { - for (Value input : getComputeInstanceInputs(scheduledConsumer)) { - for (ProducerKey producerKey : collectProducerKeysForDestinations(input, logicalConsumer)) { - ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producerKey.instance); - auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); - if (producerCpuIt == state.schedule.computeToCpuMap.end()) - return logicalConsumer.op->emitError( - "schedule materialization found an input produced by an unscheduled compute"); - - ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); - if (sourceClass == targetClass) { - SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass}; - SmallVector& bucket = state.sameClassConsumerIndex[lookupKey]; - if (!llvm::is_contained(bucket, producerKey)) - bucket.push_back(producerKey); - continue; - } - - appendDestinationClass(state, producerKey, targetClass); - } - } - - return success(); - }); -} - -bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceType, RankedTensorType fragmentType) { - if (offsets.size() != static_cast(sourceType.getRank()) - || offsets.size() != static_cast(fragmentType.getRank())) - return false; - - for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { - int64_t offset = offsets[dim]; - if (offset < 0) - return false; - - int64_t sourceDimSize = sourceType.getDimSize(dim); - int64_t fragmentDimSize = fragmentType.getDimSize(dim); - if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize) - return false; - if (offset > sourceDimSize - fragmentDimSize) - return false; - } - - return true; -} - - -bool isStaticSliceContainedIn(ArrayRef innerOffsets, - ArrayRef innerSizes, - ArrayRef outerOffsets, - ArrayRef outerSizes) { - if (innerOffsets.size() != innerSizes.size() || outerOffsets.size() != outerSizes.size() - || innerOffsets.size() != outerOffsets.size()) - return false; - - for (size_t dim = 0; dim < innerOffsets.size(); ++dim) { - if (innerSizes[dim] < 0 || outerSizes[dim] < 0) - return false; - - int64_t innerBegin = innerOffsets[dim]; - int64_t innerEnd = innerBegin + innerSizes[dim]; - int64_t outerBegin = outerOffsets[dim]; - int64_t outerEnd = outerBegin + outerSizes[dim]; - if (innerBegin < outerBegin || innerEnd > outerEnd) - return false; - } - - return true; -} - -bool areAllUnitStrides(ArrayRef strides) { - return llvm::all_of(strides, [](int64_t stride) { return stride == 1; }); -} - -static std::optional getStaticForTripCount(scf::ForOp loop) { - std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); - std::optional upperBound = matchConstantIndexValue(loop.getUpperBound()); - std::optional step = matchConstantIndexValue(loop.getStep()); - if (!lowerBound || !upperBound || !step || *step <= 0 || *upperBound < *lowerBound) - return std::nullopt; - - int64_t distance = *upperBound - *lowerBound; - return (distance + *step - 1) / *step; -} - -static SmallVector collectEnclosingStaticProjectedLoops(Operation* op) { - SmallVector loops; - SmallVector reversedLoops; - for (Operation* current = op->getParentOp(); current; current = current->getParentOp()) - if (auto loop = dyn_cast(current)) - reversedLoops.push_back(loop); - - for (scf::ForOp loop : llvm::reverse(reversedLoops)) { - std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); - std::optional step = matchConstantIndexValue(loop.getStep()); - std::optional tripCount = getStaticForTripCount(loop); - if (!lowerBound || !step || !tripCount) - return {}; - loops.push_back(StaticProjectedLoopInfo {.iv = cast(loop.getInductionVar()), - .lowerBound = *lowerBound, - .step = *step, - .tripCount = *tripCount}); - } - return loops; -} - -static bool -isProjectedOffsetValue(Value value, Value laneArg, ArrayRef loops, bool& usesDynamicBinding) { - if (value == laneArg) { - usesDynamicBinding = true; - return true; - } - - for (const StaticProjectedLoopInfo& loop : loops) { - if (value == loop.iv) { - usesDynamicBinding = true; - return true; - } - } - - if (matchPattern(value, m_Constant())) - return true; - - auto affineApply = value.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return false; - - bool nestedUsesDynamicBinding = false; - for (Value operand : affineApply.getMapOperands()) { - bool operandUsesDynamicBinding = false; - if (!isProjectedOffsetValue(operand, laneArg, loops, operandUsesDynamicBinding)) - return false; - nestedUsesDynamicBinding = nestedUsesDynamicBinding || operandUsesDynamicBinding; - } - - usesDynamicBinding = usesDynamicBinding || nestedUsesDynamicBinding; - return true; -} - -static std::optional getConstantIndex(OpFoldResult value); - -static unsigned getProjectedFragmentsPerLogicalSlot(ArrayRef loopTripCounts) { - unsigned fragmentsPerLogicalSlot = 1; - for (int64_t tripCount : loopTripCounts) { - assert(tripCount > 0 && "projected loop trip counts must be positive"); - fragmentsPerLogicalSlot *= static_cast(tripCount); - } - return fragmentsPerLogicalSlot; -} - -LogicalResult verifyProjectedFragmentLayout(Operation* anchor, const ProjectedFragmentLayout& layout) { - if (!layout.fragmentType || layout.fragmentShape.empty()) - return anchor->emitError("projected fragment layout is missing fragment type metadata"); - if (layout.fragmentShape.size() != static_cast(layout.fragmentType.getRank())) - return anchor->emitError("projected fragment layout rank does not match fragment type"); - if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0) - return anchor->emitError("projected fragment layout has an invalid fragment count"); - if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0) - return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots"); - return success(); -} - -FailureOr -getProjectedPayloadType(Operation* anchor, RankedTensorType fragmentType, unsigned payloadFragmentCount) { - if (failed( - verifyPackableFragmentType(anchor, fragmentType, payloadFragmentCount, "cannot create projected payload type"))) - return failure(); - return getPackedBatchTensorType(fragmentType, payloadFragmentCount); -} - -SmallVector, 4> -buildProjectedFragmentOffsetsByDim(ArrayRef> fragmentOffsets, size_t rank) { - SmallVector, 4> fragmentOffsetsByDim(rank); - for (ArrayRef offsets : fragmentOffsets) { - assert(offsets.size() == rank && "projected offset rank mismatch"); - for (size_t dim = 0; dim < rank; ++dim) - fragmentOffsetsByDim[dim].push_back(offsets[dim]); - } - return fragmentOffsetsByDim; -} - -LogicalResult verifyProjectedTransferDescriptor(Operation* anchor, const ProjectedTransferDescriptor& descriptor) { - if (failed(verifyProjectedFragmentLayout(anchor, descriptor.layout))) - return failure(); - if (!descriptor.payloadType) - return anchor->emitError("projected transfer descriptor is missing payload type"); - if (descriptor.fragmentOffsets.empty()) - return anchor->emitError("projected transfer descriptor expected at least one fragment offset"); - if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size()) - return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); - for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) - if (dimOffsets.size() != descriptor.fragmentOffsets.size()) - return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); - for (ArrayRef offsets : descriptor.fragmentOffsets) - if (offsets.size() != descriptor.layout.fragmentShape.size()) - return anchor->emitError("projected transfer offset rank does not match fragment rank"); - return success(); -} - -LogicalResult verifyProjectedSendDescriptor(Operation* anchor, - const ProjectedTransferDescriptor& descriptor, - const MessageVector& messages) { - if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) - return failure(); - if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size()) - return anchor->emitError("projected send descriptor metadata is inconsistent"); - return success(); -} - -LogicalResult finalizeProjectedTransferDescriptor(Operation* anchor, ProjectedTransferDescriptor& descriptor) { - descriptor.fragmentOffsetsByDim = - buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size()); - - FailureOr payloadType = - getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount); - if (failed(payloadType)) - return failure(); - if (descriptor.payloadType && descriptor.payloadType != *payloadType) - return anchor->emitError("projected transfer descriptor payload type does not match projected layout"); - descriptor.payloadType = *payloadType; - - return verifyProjectedTransferDescriptor(anchor, descriptor); -} - -static FailureOr evaluateProjectedOffsetValue(OpFoldResult value, - Value laneArg, - uint32_t lane, - ArrayRef loops, - ArrayRef loopIterationIndices) { - if (std::optional constant = getConstantIndex(value)) - return *constant; - - Value current = dyn_cast(value); - if (!current) - return failure(); - if (current == laneArg) - return static_cast(lane); - - for (auto [index, loop] : llvm::enumerate(loops)) { - if (current != loop.iv) - continue; - if (index >= loopIterationIndices.size()) - return failure(); - return loop.lowerBound + loopIterationIndices[index] * loop.step; - } - - if (auto affineApply = current.getDefiningOp()) { - return evaluateAffineApply(affineApply, [&](Value operand) { - return evaluateProjectedOffsetValue(operand, laneArg, lane, loops, loopIterationIndices); - }); - } - - return failure(); -} - -static std::optional getConstantIndex(OpFoldResult value) { - if (auto attr = dyn_cast(value)) { - auto intAttr = dyn_cast(attr); - if (!intAttr) - return std::nullopt; - return intAttr.getInt(); - } - - Value operand = dyn_cast(value); - if (!operand) - return std::nullopt; - - if (auto constantIndex = operand.getDefiningOp()) - return constantIndex.value(); - - APInt apInt; - if (matchPattern(operand, m_ConstantInt(&apInt))) { - if (apInt.isNegative()) - return std::nullopt; - return static_cast(apInt.getSExtValue()); - } - - return std::nullopt; -} - -static std::optional matchAffineProjectedInputSlice(SpatComputeBatch batch, - unsigned inputIndex) { - const auto fail = [&](StringRef) -> std::optional { return std::nullopt; }; - - std::optional inputArg = batch.getInputArgument(inputIndex); - std::optional laneArg = batch.getLaneArgument(); - if (!inputArg || !laneArg) - return fail("missing-input-or-lane-arg"); - - if (!inputArg->hasOneUse()) - return fail("input-arg-not-one-use"); - - Operation* user = *inputArg->getUsers().begin(); - auto extract = dyn_cast(user); - if (!extract || extract.getSource() != *inputArg) - return fail("input-user-is-not-direct-extract-slice"); - - auto inputType = dyn_cast(inputArg->getType()); - auto fragmentType = dyn_cast(extract.getResult().getType()); - if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape()) - return fail("non-static-ranked-input-or-fragment"); - - if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank()) - return fail("rank-mismatch-or-rank-zero"); - - SmallVector offsets = extract.getMixedOffsets(); - SmallVector sizes = extract.getMixedSizes(); - SmallVector strides = extract.getMixedStrides(); - - if (offsets.size() != static_cast(inputType.getRank()) - || sizes.size() != static_cast(inputType.getRank()) - || strides.size() != static_cast(inputType.getRank())) - return fail("slice-rank-mismatch"); - - SmallVector loops = collectEnclosingStaticProjectedLoops(extract.getOperation()); - if (extract->getParentOfType() && loops.empty()) - return fail("unsupported-enclosing-loop"); - - bool hasDynamicProjection = false; - for (auto [dim, offset] : llvm::enumerate(offsets)) { - bool usesDynamicBinding = false; - if (auto value = dyn_cast(offset)) { - if (!isProjectedOffsetValue(value, *laneArg, loops, usesDynamicBinding)) - return std::nullopt; - } - else if (!isa(offset)) - return std::nullopt; - if (std::optional stride = getConstantIndex(strides[dim]); !stride || *stride != 1) - return std::nullopt; - std::optional size = getConstantIndex(sizes[dim]); - if (!size || *size != fragmentType.getDimSize(dim)) - return std::nullopt; - hasDynamicProjection = hasDynamicProjection || usesDynamicBinding; - } - - if (!hasDynamicProjection) - return fail("no-dynamic-projection"); - - for (int64_t dim = 0; dim < inputType.getRank(); ++dim) - if (fragmentType.getDimSize(dim) <= 0 || fragmentType.getDimSize(dim) > inputType.getDimSize(dim)) - return std::nullopt; - - AffineProjectedInputSliceMatch match; - match.extract = extract; - match.sourceType = inputType; - match.fragmentType = fragmentType; - match.offsets.assign(offsets.begin(), offsets.end()); - match.fragmentShape.assign(fragmentType.getShape().begin(), fragmentType.getShape().end()); - match.loops = std::move(loops); - return match; -} - -std::optional -getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex) { - ProjectedBatchInputKey key {batch.getOperation(), inputIndex}; - auto cached = state.projectedInputMatches.find(key); - if (cached != state.projectedInputMatches.end()) - return cached->second; - if (state.nonProjectedInputs.contains(key)) - return std::nullopt; - - std::optional match = matchAffineProjectedInputSlice(batch, inputIndex); - if (!match) { - state.nonProjectedInputs.insert(key); - return std::nullopt; - } - - state.projectedInputMatches.insert({key, *match}); - return match; -} - -LogicalResult collectProjectedTransfers(MaterializerState& state) { - struct PendingProjectedTransferDescriptor { - ProjectedBatchInputKey inputKey; - Operation* extractOp = nullptr; - RankedTensorType sourceType; - RankedTensorType fragmentType; - SmallVector fragmentShape; - SmallVector, 16>, 8> fragmentOffsetsByLane; - SmallVector loopLowerBounds; - SmallVector loopSteps; - SmallVector loopTripCounts; - bool invalid = false; - }; - - DenseMap, ProducerKeyInfo> pending; - - const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) { - if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType) - return false; - - if (descriptor.fragmentOffsetsByLane.size() != 1) - return false; - - ArrayRef> fragments = descriptor.fragmentOffsetsByLane.front(); - if (fragments.size() != 1) - return false; - - return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; }); - }; - - const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor, - unsigned targetLane, - const AffineProjectedInputSliceMatch& match, - Value laneArg, - uint32_t lane) -> LogicalResult { - SmallVector loopIterationIndices; - loopIterationIndices.resize(match.loops.size(), 0); - - const auto appendOneFragment = [&]() -> LogicalResult { - SmallVector evaluatedOffsets; - evaluatedOffsets.reserve(match.offsets.size()); - for (OpFoldResult offset : match.offsets) { - FailureOr evaluated = - evaluateProjectedOffsetValue(offset, laneArg, lane, match.loops, loopIterationIndices); - if (failed(evaluated)) - return failure(); - evaluatedOffsets.push_back(*evaluated); - } - - if (!isStaticSliceInBounds(evaluatedOffsets, match.sourceType, match.fragmentType)) - return failure(); - - descriptor.fragmentOffsetsByLane[targetLane].push_back(std::move(evaluatedOffsets)); - return success(); - }; - - if (match.loops.empty()) - return appendOneFragment(); - - const auto recurse = [&](auto&& self, size_t loopIndex) -> LogicalResult { - if (loopIndex == match.loops.size()) - return appendOneFragment(); - - for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { - loopIterationIndices[loopIndex] = iteration; - if (failed(self(self, loopIndex + 1))) - return failure(); - } - return success(); - }; - - return recurse(recurse, 0); - }; - - if (failed(forEachLogicalConsumerInMaterializationOrder( - state, - [&](CpuId cpu, - ClassId targetClassId, - ComputeInstance consumer, - ComputeInstance logicalConsumer, - SlotId logicalSlot) -> LogicalResult { - auto batch = dyn_cast(consumer.op); - if (!batch) - return success(); - - MaterializedClass& targetClass = state.classes[targetClassId]; - unsigned targetLane = 0; - if (targetClass.isBatch) { - auto targetLaneIt = targetClass.cpuToLane.find(cpu); - if (targetLaneIt == targetClass.cpuToLane.end()) - return consumer.op->emitError("projected transfer collection could not recover target lane"); - targetLane = targetLaneIt->second; - } - - for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { - SmallVector producers = collectProducerKeysForDestinations(input, logicalConsumer); - if (producers.size() != 1) - continue; - ProducerKey producer = producers.front(); - - ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producer.instance); - auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); - if (producerCpuIt == state.schedule.computeToCpuMap.end()) - continue; - - ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second); - if (sourceClassId == targetClassId) - continue; - - std::optional match = - getProjectedInputSliceMatch(state, batch, static_cast(inputIndex)); - if (!match) - continue; - if (!isProjectedInputSliceCompatibleWithProducerFragments( - batch, *match, producer, logicalConsumer.laneStart)) - continue; - - PendingProjectedTransferDescriptor& descriptor = pending[producer][targetClassId]; - if (descriptor.fragmentOffsetsByLane.empty()) { - descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; - descriptor.extractOp = match->extract.getOperation(); - descriptor.sourceType = match->sourceType; - descriptor.fragmentType = match->fragmentType; - descriptor.fragmentShape = match->fragmentShape; - descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1); - descriptor.loopLowerBounds.reserve(match->loops.size()); - descriptor.loopSteps.reserve(match->loops.size()); - descriptor.loopTripCounts.reserve(match->loops.size()); - for (const StaticProjectedLoopInfo& loop : match->loops) { - descriptor.loopLowerBounds.push_back(loop.lowerBound); - descriptor.loopSteps.push_back(loop.step); - descriptor.loopTripCounts.push_back(loop.tripCount); - } - } - - ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; - if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation() - || descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType - || descriptor.fragmentShape != match->fragmentShape - || descriptor.loopLowerBounds.size() != match->loops.size()) { - descriptor.invalid = true; - continue; - } - for (auto [index, loop] : llvm::enumerate(match->loops)) { - if (descriptor.loopLowerBounds[index] != loop.lowerBound || descriptor.loopSteps[index] != loop.step - || descriptor.loopTripCounts[index] != loop.tripCount) { - descriptor.invalid = true; - break; - } - } - if (descriptor.invalid) - continue; - - if (targetLane >= descriptor.fragmentOffsetsByLane.size()) { - descriptor.invalid = true; - continue; - } - - if (failed(appendEvaluatedFragments( - descriptor, targetLane, *match, *batch.getLaneArgument(), logicalConsumer.laneStart))) { - descriptor.invalid = true; - continue; - } - - (void) logicalSlot; - } - - return success(); - }))) - return failure(); - - for (auto& producerEntry : pending) { - ProducerKey producer = producerEntry.first; - for (auto& classEntry : producerEntry.second) { - ClassId targetClassId = classEntry.first; - PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second; - - if (pendingDescriptor.invalid) - continue; - if (pendingDescriptor.fragmentOffsetsByLane.empty()) - continue; - if (isIdentityProjectedTransfer(pendingDescriptor)) - continue; - - MaterializedClass& targetClass = state.classes[targetClassId]; - ProjectedTransferDescriptor descriptor; - descriptor.inputKey = pendingDescriptor.inputKey; - descriptor.extractOp = pendingDescriptor.extractOp; - descriptor.layout.fragmentType = pendingDescriptor.fragmentType; - descriptor.layout.fragmentShape = pendingDescriptor.fragmentShape; - descriptor.layout.loopLowerBounds = pendingDescriptor.loopLowerBounds; - descriptor.layout.loopSteps = pendingDescriptor.loopSteps; - descriptor.layout.loopTripCounts = pendingDescriptor.loopTripCounts; - descriptor.layout.fragmentsPerLogicalSlot = getProjectedFragmentsPerLogicalSlot(descriptor.layout.loopTripCounts); - if (targetClass.isBatch) { - unsigned payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); - if (payloadFragmentCount == 0) - continue; - - // Batch-target projected replacements currently select fragments with the - // local materialization-run slot index. That is only unambiguous when each - // target lane receives one projected fragment. Multi-fragment payloads - // need an explicit producer-key to payload-slot mapping; otherwise two - // independently materialized runs can both select fragment 0 from the same - // packed receive and duplicate rows. - if (payloadFragmentCount != 1) - continue; - - bool uniform = true; - for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) { - if (laneFragments.size() != payloadFragmentCount) { - uniform = false; - break; - } - } - if (!uniform) - continue; - - descriptor.layout.payloadFragmentCount = payloadFragmentCount; - descriptor.fragmentOffsets.reserve(pendingDescriptor.fragmentOffsetsByLane.size() * payloadFragmentCount); - for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) - llvm::append_range(descriptor.fragmentOffsets, laneFragments); - } - else { - if (pendingDescriptor.fragmentOffsetsByLane.size() != 1) - return targetClass.op->emitError("scalar projected transfer descriptor expected one local offset stream"); - if (pendingDescriptor.fragmentOffsetsByLane.front().empty()) - continue; - - descriptor.layout.payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); - llvm::append_range(descriptor.fragmentOffsets, pendingDescriptor.fragmentOffsetsByLane.front()); - if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) - return targetClass.op->emitError("scalar projected transfer offset count does not match the local run"); - } - if (failed(finalizeProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - - state.projectedTransfers[producer][targetClassId] = std::move(descriptor); - } - } - - return success(); -} - -static std::optional -collectScalarTargetProjectedDescriptor(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef keys, - bool requirePackedRunOffsetCountMatch) { - assert(!targetClass.isBatch && "scalar target projected descriptor helper expects a scalar class"); - - std::optional combined; - for (ProducerKey key : keys) { - auto producerIt = state.projectedTransfers.find(key); - if (producerIt == state.projectedTransfers.end()) - return std::nullopt; - - auto descriptorIt = producerIt->second.find(targetClass.id); - if (descriptorIt == producerIt->second.end()) - return std::nullopt; - - const ProjectedTransferDescriptor& descriptor = descriptorIt->second; - if (descriptor.fragmentOffsets.empty()) - return std::nullopt; - if (descriptor.layout.payloadFragmentCount == 0 || descriptor.layout.fragmentsPerLogicalSlot == 0) - return std::nullopt; - if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) - return std::nullopt; - if (descriptor.layout.payloadFragmentCount % descriptor.layout.fragmentsPerLogicalSlot != 0) - return std::nullopt; - - if (!combined) { - combined = descriptor; - continue; - } - - if (!(combined->inputKey == descriptor.inputKey) || combined->extractOp != descriptor.extractOp - || combined->layout.fragmentType != descriptor.layout.fragmentType - || combined->layout.fragmentShape != descriptor.layout.fragmentShape - || combined->layout.loopLowerBounds != descriptor.layout.loopLowerBounds - || combined->layout.loopSteps != descriptor.layout.loopSteps - || combined->layout.loopTripCounts != descriptor.layout.loopTripCounts - || combined->layout.fragmentsPerLogicalSlot != descriptor.layout.fragmentsPerLogicalSlot) - return std::nullopt; - - combined->layout.payloadFragmentCount += descriptor.layout.payloadFragmentCount; - llvm::append_range(combined->fragmentOffsets, descriptor.fragmentOffsets); - } - - if (!combined) - return std::nullopt; - - if (combined->fragmentOffsets.size() != combined->layout.payloadFragmentCount) - return std::nullopt; - - if (requirePackedRunOffsetCountMatch) { - if (combined->layout.payloadFragmentCount != keys.size() * combined->layout.fragmentsPerLogicalSlot) - return std::nullopt; - } - if (failed(finalizeProjectedTransferDescriptor(targetClass.op, *combined))) - return std::nullopt; - return combined; -} - -bool haveSameDestinationClasses(MaterializerState& state, ArrayRef keys) { - if (keys.empty()) - return true; - - auto firstIt = state.producerDestClasses.find(keys.front()); - ArrayRef first = firstIt == state.producerDestClasses.end() ? ArrayRef() : firstIt->second; - for (ProducerKey key : keys.drop_front()) { - auto it = state.producerDestClasses.find(key); - ArrayRef current = it == state.producerDestClasses.end() ? ArrayRef() : it->second; - if (first.size() != current.size()) - return false; - for (auto [lhs, rhs] : llvm::zip(first, current)) - if (lhs != rhs) - return false; - } - return true; -} - -ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey key) { - auto it = state.producerDestClasses.find(key); - if (it == state.producerDestClasses.end()) - return {}; - return it->second; -} - -std::optional getKnownMinimumIndexValue(Value value) { - if (std::optional constant = matchConstantIndexValue(value)) - return *constant; - - if (auto blockArg = dyn_cast(value)) { - if (blockArg.getArgNumber() == 0) { - if (auto loop = dyn_cast_or_null(blockArg.getOwner()->getParentOp())) - return matchConstantIndexValue(loop.getLowerBound()); - } - return std::nullopt; - } - - if (auto add = value.getDefiningOp()) { - std::optional lhs = getKnownMinimumIndexValue(add.getLhs()); - std::optional rhs = getKnownMinimumIndexValue(add.getRhs()); - if (lhs && rhs) - return *lhs + *rhs; - return std::nullopt; - } - - if (auto mul = value.getDefiningOp()) { - std::optional lhs = getKnownMinimumIndexValue(mul.getLhs()); - std::optional rhs = getKnownMinimumIndexValue(mul.getRhs()); - if (!lhs || !rhs) - return std::nullopt; - if (*lhs >= 0 && *rhs >= 0) - return *lhs * *rhs; - return std::nullopt; - } - - auto affineApply = value.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return std::nullopt; - - SmallVector operands; - operands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - std::optional minimum = getKnownMinimumIndexValue(operand); - if (!minimum) - return std::nullopt; - operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *minimum)); - } - - SmallVector results; - if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) - return std::nullopt; - - auto intAttr = dyn_cast(results.front()); - if (!intAttr) - return std::nullopt; - return intAttr.getInt(); -} - -std::optional getKnownMinimumCommunicationChannelId(Operation* op) { - if (auto send = dyn_cast(op)) - return getKnownMinimumIndexValue(send.getChannelId()); - if (auto receive = dyn_cast(op)) - return getKnownMinimumIndexValue(receive.getChannelId()); - - std::optional minimum; - op->walk([&](Operation* nested) { - if (nested == op) - return; - std::optional nestedMinimum = getKnownMinimumCommunicationChannelId(nested); - if (!nestedMinimum) - return; - if (!minimum || *nestedMinimum < *minimum) - minimum = *nestedMinimum; - }); - return minimum; -} - -void setInsertionPointForScalarReceive(MaterializerState& state, - MaterializedClass& targetClass, - int64_t channelId) { - assert(!targetClass.isBatch && "scalar receive ordering expects a scalar target class"); - - for (Operation& op : *targetClass.body) { - if (op.hasTrait()) - break; - - std::optional existingChannel = getKnownMinimumCommunicationChannelId(&op); - if (existingChannel && *existingChannel > channelId) { - state.rewriter.setInsertionPoint(&op); - return; - } - } - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); -} - -// ----------------------------------------------------------------------------- -// Communication materialization helpers. -// ----------------------------------------------------------------------------- - -constexpr const char* kRaptorMinChannelIdAttr = "raptor.min_channel_id"; -constexpr const char* kRaptorMaterializerAttr = "raptor.materializer"; -constexpr const char* kRaptorCommTraceIdAttr = "raptor.comm_trace_id"; -constexpr const char* kRaptorCommTraceKindAttr = "raptor.comm_trace_kind"; -constexpr const char* kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase"; -constexpr const char* kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id"; -constexpr const char* kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind"; -constexpr const char* kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal"; -constexpr const char* kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload"; -constexpr const char* kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages"; -constexpr const char* kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op"; -constexpr const char* kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op"; - -int64_t getMinimumChannelId(ArrayRef channelIds) { - assert(!channelIds.empty() && "expected at least one channel id"); - int64_t minChannelId = channelIds.front(); - for (int64_t channelId : channelIds.drop_front()) - if (channelId < minChannelId) - minChannelId = channelId; - return minChannelId; -} - -SmallVector getScalarSendChannelOrder(const MessageVector& messages) { - SmallVector order; - order.reserve(messages.size()); - for (size_t i = 0, e = messages.size(); i < e; ++i) - order.push_back(i); - - llvm::sort(order, [&](size_t lhs, size_t rhs) { - if (messages.channelIds[lhs] != messages.channelIds[rhs]) - return messages.channelIds[lhs] < messages.channelIds[rhs]; - if (messages.sourceCoreIds[lhs] != messages.sourceCoreIds[rhs]) - return messages.sourceCoreIds[lhs] < messages.sourceCoreIds[rhs]; - return messages.targetCoreIds[lhs] < messages.targetCoreIds[rhs]; - }); - return order; -} - -MessageVector reorderMessages(const MessageVector& messages, ArrayRef order) { - MessageVector reordered; - reordered.channelIds.reserve(messages.size()); - reordered.sourceCoreIds.reserve(messages.size()); - reordered.targetCoreIds.reserve(messages.size()); - for (size_t index : order) - reordered.append(messages.channelIds[index], messages.sourceCoreIds[index], messages.targetCoreIds[index]); - return reordered; -} - -MessageVector reorderScalarSendMessagesByChannel(const MessageVector& messages) { - return reorderMessages(messages, getScalarSendChannelOrder(messages)); -} - -ProjectedTransferDescriptor reorderProjectedDescriptorByMessageOrder(const ProjectedTransferDescriptor& descriptor, - ArrayRef order) { - ProjectedTransferDescriptor reordered = descriptor; - size_t payloadFragmentCount = static_cast(descriptor.layout.payloadFragmentCount); - reordered.fragmentOffsets.clear(); - reordered.fragmentOffsets.reserve(descriptor.fragmentOffsets.size()); - for (size_t messageIndex : order) { - size_t offset = messageIndex * payloadFragmentCount; - for (size_t fragmentIndex = 0; fragmentIndex < payloadFragmentCount; ++fragmentIndex) - reordered.fragmentOffsets.push_back(descriptor.fragmentOffsets[offset + fragmentIndex]); - } - reordered.fragmentOffsetsByDim.clear(); - return reordered; -} - - -Operation* getPayloadDefiningOpInClassBlock(Value payload, MaterializedClass& materializedClass) { - Operation* definingOp = payload.getDefiningOp(); - if (!definingOp || definingOp->getBlock() != materializedClass.body) - return nullptr; - return definingOp; -} - -Operation* findScalarCommunicationInsertionPoint(MaterializedClass& materializedClass, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - Operation* terminator = materializedClass.body->getTerminator(); - bool afterLowerBound = lowerBound == nullptr; - - for (Operation& op : *materializedClass.body) { - if (&op == terminator) - break; - - if (!afterLowerBound) { - if (&op == lowerBound) - afterLowerBound = true; - continue; - } - - if (&op == lowerBound) - continue; - - auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); - if (existingMinChannel && existingMinChannel.getInt() > minChannelId) - return &op; - } - - return terminator; -} - -void setInsertionPointForScalarCommunication(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - state.rewriter.setInsertionPoint( - findScalarCommunicationInsertionPoint(materializedClass, minChannelId, lowerBound)); -} - -constexpr const char kRaptorCommOrderAttr[] = "raptor.comm_order"; - -int64_t computeBlockingCommunicationOrderKey(int32_t sourceCoreId, int32_t targetCoreId, int64_t channelId) { - int64_t lowCore = std::min(sourceCoreId, targetCoreId); - int64_t highCore = std::max(sourceCoreId, targetCoreId); - int64_t directionPhase = sourceCoreId <= targetCoreId ? 0 : 1; - return (((lowCore * 1000000LL + highCore) * 2LL + directionPhase) * 1000000000LL) + channelId; -} - -int64_t getMinimumBlockingCommunicationOrderKey(const MessageVector& messages) { - assert(!messages.empty() && "expected at least one message"); - int64_t best = computeBlockingCommunicationOrderKey( - messages.sourceCoreIds.front(), messages.targetCoreIds.front(), messages.channelIds.front()); - for (size_t index = 1, end = messages.size(); index < end; ++index) { - best = std::min(best, computeBlockingCommunicationOrderKey( - messages.sourceCoreIds[index], messages.targetCoreIds[index], messages.channelIds[index])); - } - return best; -} - -Operation* findScalarCommunicationInsertionPointByOrder(MaterializedClass& materializedClass, - int64_t orderKey, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - Operation* terminator = materializedClass.body->getTerminator(); - bool afterLowerBound = lowerBound == nullptr; - - for (Operation& op : *materializedClass.body) { - if (&op == terminator) - break; - - if (!afterLowerBound) { - if (&op == lowerBound) - afterLowerBound = true; - continue; - } - - if (&op == lowerBound) - continue; - - if (auto existingOrder = op.getAttrOfType(kRaptorCommOrderAttr)) { - if (existingOrder.getInt() > orderKey) - return &op; - continue; - } - - auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); - if (existingMinChannel && existingMinChannel.getInt() > minChannelId) - return &op; - } - - return terminator; -} - -void setInsertionPointForScalarCommunicationOrder(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t orderKey, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - if (!pimMaterializeScalarFanoutGlobalOrder) { - setInsertionPointForScalarCommunication(state, materializedClass, minChannelId, lowerBound); - return; - } - - state.rewriter.setInsertionPoint( - findScalarCommunicationInsertionPointByOrder(materializedClass, orderKey, minChannelId, lowerBound)); -} - -void markScalarCommunication(Operation* op, int64_t minChannelId, StringRef materializer = StringRef()) { - if (!op) - return; - op->setAttr(kRaptorMinChannelIdAttr, - IntegerAttr::get(IndexType::get(op->getContext()), minChannelId)); - if (!materializer.empty()) - op->setAttr(kRaptorMaterializerAttr, StringAttr::get(op->getContext(), materializer)); -} - -void markScalarCommunicationOrder(Operation* op, int64_t orderKey) { - if (!op) - return; - op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(IndexType::get(op->getContext()), orderKey)); -} - -std::optional getOperationOrdinalInBlock(Operation* op) { - if (!op || !op->getBlock()) - return std::nullopt; - - int64_t ordinal = 0; - for (Operation& candidate : *op->getBlock()) { - if (&candidate == op) - return ordinal; - ++ordinal; - } - return std::nullopt; -} - -std::string formatOperationForTrace(Operation* op) { - if (!op) - return ""; - - std::string text; - llvm::raw_string_ostream os(text); - os << op->getName().getStringRef(); - if (auto ordinal = getOperationOrdinalInBlock(op)) - os << "@" << *ordinal; - return os.str(); -} - -std::string formatValueForTrace(Value value, Block* localBody) { - if (!value) - return ""; - - std::string text; - llvm::raw_string_ostream os(text); - if (auto arg = dyn_cast(value)) { - os << "block_arg#" << arg.getArgNumber(); - return os.str(); - } - - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) { - os << "external"; - return os.str(); - } - - os << definingOp->getName().getStringRef(); - if (definingOp->getBlock() == localBody) { - if (auto ordinal = getOperationOrdinalInBlock(definingOp)) - os << "@" << *ordinal; - } - else { - os << "@external-block"; - } - return os.str(); -} - -std::string formatClassForTrace(const MaterializedClass& materializedClass) { - std::string text; - llvm::raw_string_ostream os(text); - os << (materializedClass.isBatch ? "batch" : "scalar") << " class " << materializedClass.id << " cpus=["; - for (auto [index, cpu] : llvm::enumerate(materializedClass.cpus)) { - if (index) - os << ","; - os << cpu; - } - os << "]"; - return os.str(); -} - -std::string formatMessagesForTrace(const MessageVector& messages, unsigned maxMessages = 8) { - std::string text; - llvm::raw_string_ostream os(text); - os << "count=" << messages.size() << " ["; - unsigned limit = std::min(maxMessages, messages.size()); - for (unsigned index = 0; index < limit; ++index) { - if (index) - os << "; "; - os << "c" << messages.channelIds[index] << ":" << messages.sourceCoreIds[index] - << "->" << messages.targetCoreIds[index]; - } - if (messages.size() > limit) - os << "; ..."; - os << "]"; - return os.str(); -} - -void annotateCommunicationMaterialization(MaterializerState& state, - MaterializedClass& materializedClass, - Operation* op, - StringRef kind, - StringRef materializer, - StringRef phase, - std::optional minChannelId, - std::optional orderKey, - Value payload = Value(), - const MessageVector* messages = nullptr) { - if (!op) - return; - - MLIRContext* context = op->getContext(); - int64_t traceId = state.nextCommunicationTraceId++; - auto indexType = IndexType::get(context); - op->setAttr(kRaptorCommTraceIdAttr, IntegerAttr::get(indexType, traceId)); - op->setAttr(kRaptorCommTraceKindAttr, StringAttr::get(context, kind)); - op->setAttr(kRaptorCommTracePhaseAttr, StringAttr::get(context, phase)); - op->setAttr(kRaptorCommTraceClassIdAttr, IntegerAttr::get(indexType, materializedClass.id)); - op->setAttr(kRaptorCommTraceClassKindAttr, - StringAttr::get(context, materializedClass.isBatch ? "batch" : "scalar")); - if (!materializer.empty()) - op->setAttr(kRaptorMaterializerAttr, StringAttr::get(context, materializer)); - if (minChannelId) - op->setAttr(kRaptorMinChannelIdAttr, IntegerAttr::get(indexType, *minChannelId)); - if (orderKey) - op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(indexType, *orderKey)); - if (auto ordinal = getOperationOrdinalInBlock(op)) - op->setAttr(kRaptorCommTraceBlockOrdinalAttr, IntegerAttr::get(indexType, *ordinal)); - op->setAttr(kRaptorCommTracePayloadAttr, - StringAttr::get(context, formatValueForTrace(payload, materializedClass.body))); - if (messages) - op->setAttr(kRaptorCommTraceMessagesAttr, StringAttr::get(context, formatMessagesForTrace(*messages))); - - Operation* prev = op->getPrevNode(); - Operation* next = op->getNextNode(); - op->setAttr(kRaptorCommTracePrevOpAttr, StringAttr::get(context, formatOperationForTrace(prev))); - op->setAttr(kRaptorCommTraceNextOpAttr, StringAttr::get(context, formatOperationForTrace(next))); - - if (!pimTraceCommunicationMaterialization) - return; - - llvm::errs() << "[raptor:comm-materializer] #" << traceId << " " << kind - << " via " << materializer << " phase=" << phase << " " - << formatClassForTrace(materializedClass); - if (minChannelId) - llvm::errs() << " min_channel=" << *minChannelId; - if (orderKey) - llvm::errs() << " order=" << *orderKey; - if (auto ordinal = getOperationOrdinalInBlock(op)) - llvm::errs() << " block_ordinal=" << *ordinal; - llvm::errs() << " payload=" << formatValueForTrace(payload, materializedClass.body); - if (messages) - llvm::errs() << " messages=" << formatMessagesForTrace(*messages); - llvm::errs() << " prev=" << formatOperationForTrace(prev) - << " next=" << formatOperationForTrace(next) << "\n"; -} - -void setInsertionPointForEarlyCommunication(MaterializerState& state, MaterializedClass& materializedClass) { - auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); - if (lateIt != state.firstLateCommunicationOps.end() && lateIt->second && lateIt->second->getBlock()) { - state.rewriter.setInsertionPoint(lateIt->second); - return; - } - - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); -} - -void setInsertionPointForLateCommunication(MaterializerState& state, MaterializedClass& materializedClass) { - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); -} - - -Operation* findLateScalarCommunicationInsertionPoint(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t minChannelId) { - Operation* terminator = materializedClass.body->getTerminator(); - auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); - Operation* firstLate = lateIt == state.firstLateCommunicationOps.end() ? nullptr : lateIt->second; - if (!firstLate || firstLate->getBlock() != materializedClass.body) - return terminator; - - bool inLateRegion = false; - for (Operation& op : *materializedClass.body) { - if (&op == terminator) - break; - - if (!inLateRegion) { - if (&op == firstLate) - inLateRegion = true; - else - continue; - } - - auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); - if (existingMinChannel && existingMinChannel.getInt() > minChannelId) - return &op; - } - - return terminator; -} - -void setInsertionPointForLateScalarCommunication(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t minChannelId) { - state.rewriter.setInsertionPoint( - findLateScalarCommunicationInsertionPoint(state, materializedClass, minChannelId)); -} - -void rememberLateCommunicationOp(MaterializerState& state, MaterializedClass& materializedClass, Operation* op) { - if (!op || op->getBlock() != materializedClass.body) - return; - - Operation*& firstLate = state.firstLateCommunicationOps[materializedClass.id]; - if (!firstLate || firstLate->getBlock() != materializedClass.body || op->isBeforeInBlock(firstLate)) - firstLate = op; -} - - - -constexpr const char kMinCommunicationChannelIdAttr[] = "raptor.min_channel_id"; - -std::optional getConstantIndexValue(Value value) { - APInt constant; - if (matchPattern(value, m_ConstantInt(&constant))) - return constant.getSExtValue(); - return std::nullopt; -} - -std::optional getCommunicationChannelId(Operation& op) { - if (auto attr = op.getAttrOfType(kMinCommunicationChannelIdAttr)) - return attr.getInt(); - - if (auto send = dyn_cast(&op)) - return getConstantIndexValue(send.getChannelId()); - if (auto receive = dyn_cast(&op)) - return getConstantIndexValue(receive.getChannelId()); - - return std::nullopt; -} - -int64_t getMinimumCommunicationChannelId(const MessageVector& messages) { - assert(!messages.empty() && "expected at least one message"); - return *std::min_element(messages.channelIds.begin(), messages.channelIds.end()); -} - -void markCommunicationChannelId(Operation* op, int64_t channelId) { - if (!op) - return; - op->setAttr(kMinCommunicationChannelIdAttr, - IntegerAttr::get(IntegerType::get(op->getContext(), 64), channelId)); -} - -Operation* getSameBlockDefiningOp(Value value, Block* block) { - Operation* definingOp = value.getDefiningOp(); - if (!definingOp || definingOp->getBlock() != block) - return nullptr; - return definingOp; -} - - -bool valueDependsOnChannelReceive(Value root) { - SmallVector worklist; - DenseSet visitedValues; - DenseSet visitedOps; - worklist.push_back(root); - - auto visitOperand = [&](Value value) { - if (value && visitedValues.insert(value).second) - worklist.push_back(value); - }; - - while (!worklist.empty()) { - Value value = worklist.pop_back_val(); - Operation* definingOp = value.getDefiningOp(); - if (!definingOp || !visitedOps.insert(definingOp).second) - continue; - - if (isa(definingOp)) - return true; - - for (Value operand : definingOp->getOperands()) - visitOperand(operand); - - for (Region& region : definingOp->getRegions()) { - for (Block& block : region) { - for (Operation& nested : block) { - for (Value operand : nested.getOperands()) - visitOperand(operand); - } - } - } - } - - return false; -} - -bool shouldDelayScalarSendUntilAfterReceives(Value payload, int32_t sourceCoreId, int32_t targetCoreId) { - if (sourceCoreId <= targetCoreId) - return false; - return valueDependsOnChannelReceive(payload); -} - -void partitionScalarMessagesByReceiveDependency(Value payload, - const MessageVector& messages, - MessageVector& earlyMessages, - MessageVector& lateMessages) { - for (size_t i = 0, e = messages.size(); i < e; ++i) { - MessageVector& bucket = shouldDelayScalarSendUntilAfterReceives( - payload, messages.sourceCoreIds[i], messages.targetCoreIds[i]) - ? lateMessages - : earlyMessages; - bucket.append(messages.channelIds[i], messages.sourceCoreIds[i], messages.targetCoreIds[i]); - } -} - -void setInsertionPointForScalarSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - int64_t minChannelId, - bool late) { - if (late) { - setInsertionPointForLateScalarCommunication(state, sourceClass, minChannelId); - return; - } - - setInsertionPointForScalarCommunication( - state, sourceClass, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); -} - - -void appendScalarSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc) { - assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); - - bool late = shouldDelayScalarSendUntilAfterReceives(payload, sourceCoreId, targetCoreId); - int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); - if (pimMaterializeScalarFanoutGlobalOrder) - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, orderKey, channelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - else - setInsertionPointForScalarSend(state, sourceClass, payload, channelId, late); - Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId); - Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId); - Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId); - auto send = SpatChannelSendOp::create( - state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); - markScalarCommunication(send.getOperation(), channelId, "appendScalarSend"); - markScalarCommunicationOrder(send.getOperation(), orderKey); - MessageVector traceMessages; - traceMessages.append(channelId, sourceCoreId, targetCoreId); - annotateCommunicationMaterialization(state, - sourceClass, - send.getOperation(), - "send", - "appendScalarSend", - late ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), - channelId, - orderKey, - payload, - &traceMessages); - if (late && !pimMaterializeScalarFanoutGlobalOrder) - rememberLateCommunicationOp(state, sourceClass, send.getOperation()); -} - -LogicalResult emitScalarSendLoopAtInsertionPoint(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - int64_t minChannelId, - int64_t orderKey, - Location loc) { - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - - auto sendLoop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - return success(); - }); - if (failed(sendLoop)) - return failure(); - markScalarCommunication(sendLoop->loop.getOperation(), minChannelId, "appendScalarSendLoop"); - markScalarCommunicationOrder(sendLoop->loop.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - sourceClass, - sendLoop->loop.getOperation(), - "send-loop", - "appendScalarSendLoop", - "loop", - minChannelId, - orderKey, - payload, - &messages); - return success(); -} - -LogicalResult appendScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); - assert(messages.size() > 1 && "send loop is only useful for multiple sends"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - MessageVector orderedMessages = reorderScalarSendMessagesByChannel(messages); - if (pimMaterializeScalarFanoutGlobalOrder) { - for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) - appendScalarSend(state, - sourceClass, - payload, - orderedMessages.channelIds[index], - orderedMessages.sourceCoreIds[index], - orderedMessages.targetCoreIds[index], - loc); - return success(); - } - - int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - return emitScalarSendLoopAtInsertionPoint(state, sourceClass, payload, orderedMessages, minChannelId, orderKey, loc); -} - - -FailureOr buildProjectedPackedPayload(MaterializerState& state, - MaterializedClass& targetClass, - Value fullPayload, - const ProjectedTransferDescriptor& descriptor, - Value messageIndex, - Location loc) { - if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - if (descriptor.layout.payloadFragmentCount == 1) - return targetClass.op->emitError("projected packed payload builder expects a packed payload"); - - Value init = tensor::EmptyOp::create( - state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) - .getResult(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {init}, - [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value payloadFragmentCount = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); - FailureOr localMessageIndex = rematerializeIndexValueInClass(state, targetClass, messageIndex, loc); - if (failed(localMessageIndex)) - return failure(); - Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localMessageIndex, payloadFragmentCount).getResult(); - Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - - FailureOr> fragmentOffsets = - buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc); - if (failed(fragmentOffsets)) - return failure(); - FailureOr fragment = createStaticExtractSliceInClass( - state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); - if (failed(fragment)) - return failure(); - - FailureOr packedOffset = scaleIndexByDim0SizeInClass( - state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); - if (failed(packedOffset)) - return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, *fragment, acc, *packedOffset); - if (failed(next)) - return failure(); - yielded.push_back(*next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -FailureOr buildProjectedPayloadForMessage(MaterializerState& state, - MaterializedClass& targetClass, - Value fullPayload, - const ProjectedTransferDescriptor& descriptor, - Value messageIndex, - Location loc) { - if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - - if (descriptor.layout.payloadFragmentCount == 1) { - FailureOr> fragmentOffsets = - buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, messageIndex, loc); - if (failed(fragmentOffsets)) - return failure(); - return createStaticExtractSliceInClass( - state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); - } - - return buildProjectedPackedPayload(state, targetClass, fullPayload, descriptor, messageIndex, loc); -} - -LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const ProjectedTransferDescriptor& descriptor, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - SmallVector messageOrder = getScalarSendChannelOrder(messages); - MessageVector orderedMessages = reorderMessages(messages, messageOrder); - ProjectedTransferDescriptor orderedDescriptor = reorderProjectedDescriptorByMessageOrder(descriptor, messageOrder); - if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, orderedDescriptor))) - return failure(); - if (failed(verifyProjectedSendDescriptor(sourceClass.op, orderedDescriptor, orderedMessages))) - return failure(); - - int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - - if (orderedMessages.size() == 1 || pimMaterializeScalarFanoutGlobalOrder) { - for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) { - int64_t channel = orderedMessages.channelIds[index]; - int32_t sourceCore = orderedMessages.sourceCoreIds[index]; - int32_t targetCore = orderedMessages.targetCoreIds[index]; - int64_t localOrderKey = computeBlockingCommunicationOrderKey(sourceCore, targetCore, channel); - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, localOrderKey, channel, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - - Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channel); - Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCore); - Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCore); - Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(index)); - FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, messageIndex, loc); - if (failed(sendPayload)) - return failure(); - auto send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); - markScalarCommunication(send.getOperation(), channel, "appendProjectedScalarSendLoop.single"); - markScalarCommunicationOrder(send.getOperation(), localOrderKey); - MessageVector traceMessages; - traceMessages.append(channel, sourceCore, targetCore); - annotateCommunicationMaterialization(state, - sourceClass, - send.getOperation(), - "send", - "appendProjectedScalarSendLoop.single", - "projected-single", - channel, - localOrderKey, - *sendPayload, - &traceMessages); - } - return success(); - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(orderedMessages.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - - auto projectedSendLoop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value channelId = createIndexedChannelId(state, sourceClass.op, orderedMessages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, orderedMessages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, orderedMessages, index, loc); - FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, index, loc); - if (failed(sendPayload)) - return failure(); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); - return success(); - }); - if (failed(projectedSendLoop)) - return failure(); - markScalarCommunication(projectedSendLoop->loop.getOperation(), minChannelId, "appendProjectedScalarSendLoop.loop"); - markScalarCommunicationOrder(projectedSendLoop->loop.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - sourceClass, - projectedSendLoop->loop.getOperation(), - "send-loop", - "appendProjectedScalarSendLoop.loop", - "projected-loop", - minChannelId, - orderKey, - payload, - &orderedMessages); - return success(); -} - - -LogicalResult appendSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - assert(!messages.empty() && "expected at least one send"); - - if (sourceClass.isBatch) { - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - - Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); - Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); - Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); - auto send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - int64_t minChannelId = getMinimumChannelId(messages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(messages); - markScalarCommunication(send.getOperation(), minChannelId, "appendSend.batch"); - markScalarCommunicationOrder(send.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - sourceClass, - send.getOperation(), - "send", - "appendSend.batch", - "batch-lane-indexed", - minChannelId, - orderKey, - payload, - &messages); - return success(); - } - - if (messages.size() == 1) { - appendScalarSend(state, - sourceClass, - payload, - messages.channelIds.front(), - messages.sourceCoreIds.front(), - messages.targetCoreIds.front(), - loc); - return success(); - } - - return appendScalarSendLoop(state, sourceClass, payload, messages, loc); -} - -Value appendScalarReceive(MaterializerState& state, - MaterializedClass& targetClass, - Type type, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc, - bool lateReceive = false) { - assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); - - int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); - if (lateReceive) - setInsertionPointForLateScalarCommunication(state, targetClass, channelId); - else - setInsertionPointForScalarCommunicationOrder(state, targetClass, orderKey, channelId); - Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); - Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); - Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); - auto receive = SpatChannelReceiveOp::create( - state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue); - markScalarCommunication(receive.getOperation(), channelId, - lateReceive ? "appendScalarReceive.late" : "appendScalarReceive"); - markScalarCommunicationOrder(receive.getOperation(), orderKey); - MessageVector traceMessages; - traceMessages.append(channelId, sourceCoreId, targetCoreId); - annotateCommunicationMaterialization(state, - targetClass, - receive.getOperation(), - "receive", - lateReceive ? "appendScalarReceive.late" : "appendScalarReceive", - lateReceive ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), - channelId, - orderKey, - Value(), - &traceMessages); - return receive.getOutput(); -} - - -Value appendReceive( - MaterializerState& state, - MaterializedClass& targetClass, - Type type, - const MessageVector& messages, - Location loc, - bool lateReceive = false) { - assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); - assert(!messages.empty() && "expected at least one receive"); - - if (lateReceive) - setInsertionPointForLateScalarCommunication(state, targetClass, getMinimumChannelId(messages.channelIds)); - else - setInsertionPointForEarlyCommunication(state, targetClass); - - if (targetClass.isBatch) { - Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); - Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); - Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); - auto receive = SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId); - int64_t minChannelId = getMinimumChannelId(messages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(messages); - markScalarCommunication(receive.getOperation(), minChannelId, "appendReceive.batch"); - markScalarCommunicationOrder(receive.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - targetClass, - receive.getOperation(), - "receive", - "appendReceive.batch", - lateReceive ? "late-batch" : "early-batch", - minChannelId, - orderKey, - Value(), - &messages); - return receive.getOutput(); - } - - assert(messages.size() == 1 && "scalar target class can only receive one message at a time"); - return appendScalarReceive(state, - targetClass, - type, - messages.channelIds.front(), - messages.sourceCoreIds.front(), - messages.targetCoreIds.front(), - loc, - lateReceive); -} - -Value appendScalarReceiveAtCurrentInsertionPoint(MaterializerState& state, - MaterializedClass& targetClass, - Type type, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc) { - assert(!targetClass.isBatch && "demand scalar receive expects a scalar target class"); - - int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); - Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); - Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); - Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); - auto receive = SpatChannelReceiveOp::create( - state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue); - markScalarCommunication(receive.getOperation(), channelId, "appendScalarReceive.demand"); - markScalarCommunicationOrder(receive.getOperation(), orderKey); - MessageVector traceMessages; - traceMessages.append(channelId, sourceCoreId, targetCoreId); - annotateCommunicationMaterialization(state, - targetClass, - receive.getOperation(), - "receive", - "appendScalarReceive.demand", - "demand", - channelId, - orderKey, - Value(), - &traceMessages); - return receive.getOutput(); -} - -std::optional lookupPendingScalarReceiveIndex(MaterializerState& state, - ProducerKey key, - ClassId targetClassId) { - auto keyIt = state.pendingScalarReceiveLookup.find(key); - if (keyIt == state.pendingScalarReceiveLookup.end()) - return std::nullopt; - - auto classIt = keyIt->second.find(targetClassId); - if (classIt == keyIt->second.end()) - return std::nullopt; - return classIt->second; -} - -void recordPendingScalarReceive(MaterializerState& state, - ClassId targetClassId, - ArrayRef keys, - Type receiveType, - const MessageVector& messages, - Location loc) { - if (keys.empty()) - return; - - if (lookupPendingScalarReceiveIndex(state, keys.front(), targetClassId)) - return; - - size_t recordIndex = state.pendingScalarReceives.size(); - state.pendingScalarReceives.emplace_back(keys, targetClassId, receiveType, messages, loc); - - for (ProducerKey key : keys) - state.pendingScalarReceiveLookup[key][targetClassId] = recordIndex; -} - -FailureOr materializePendingScalarReceive(MaterializerState& state, - MaterializedClass& targetClass, - size_t recordIndex, - Location loc) { - if (recordIndex >= state.pendingScalarReceives.size()) - return targetClass.op->emitError("pending scalar receive index is out of bounds"); - - PendingScalarReceiveRecord& record = state.pendingScalarReceives[recordIndex]; - if (record.targetClassId != targetClass.id) - return targetClass.op->emitError("pending scalar receive target class mismatch"); - - if (record.materialized) - return record.value; - - if (targetClass.isBatch) - return targetClass.op->emitError("pending scalar receive cannot materialize into a batch class"); - if (record.messages.size() != 1) - return targetClass.op->emitError("pending scalar receive expected exactly one scalar message"); - - Location receiveLoc = loc; - Value received = appendScalarReceiveAtCurrentInsertionPoint(state, - targetClass, - record.receiveType, - record.messages.channelIds.front(), - record.messages.sourceCoreIds.front(), - record.messages.targetCoreIds.front(), - receiveLoc); - record.materialized = true; - record.value = received; - - for (ProducerKey key : record.keys) - state.availableValues.record(key, targetClass.id, received); - - return received; -} - - -LogicalResult materializePendingScalarReceivesForWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey wholeBatchKey, - Location loc) { - if (targetClass.isBatch || !isWholeBatchProducerKey(wholeBatchKey)) - return success(); - - SmallVector pendingIndices; - for (auto [recordIndex, record] : llvm::enumerate(state.pendingScalarReceives)) { - if (record.targetClassId != targetClass.id || record.materialized) - continue; - - bool contributesToWholeBatch = llvm::any_of(record.keys, [&](ProducerKey fragmentKey) { - return containsProducerKey(wholeBatchKey, fragmentKey); - }); - if (contributesToWholeBatch) - pendingIndices.push_back(recordIndex); - } - - if (pendingIndices.empty()) - return success(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - for (size_t recordIndex : pendingIndices) { - FailureOr received = materializePendingScalarReceive(state, targetClass, recordIndex, loc); - if (failed(received)) - return failure(); - } - - return success(); -} - -LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - ArrayRef keys, - Type fragmentType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds) { - if (!sourceClass.isBatch) - return sourceClass.op->emitError("lazy packed scalar receives expect a batch source class"); - - if (targetClass.isBatch) - return targetClass.op->emitError("lazy packed scalar receives expect a scalar target class"); - - if (keys.empty()) - return sourceClass.op->emitError("lazy packed scalar receive expects at least one producer key"); - - if (keys.size() != sourceClass.cpus.size()) - return sourceClass.op->emitError("lazy packed scalar receive expects one producer key per source lane"); - - MessageVector messages; - messages.append(channelIds, sourceCoreIds, targetCoreIds); - if (failed(messages.verify(targetClass.op))) - return failure(); - - if (keys.size() != messages.size()) - return targetClass.op->emitError("lazy packed scalar receive metadata is inconsistent"); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return targetClass.op->emitError("lazy packed scalar receive expects a static ranked fragment type"); - - if (failed(verifyPackableFragmentType( - targetClass.op, fragmentType, keys.size(), "cannot create lazy packed scalar receive type"))) - return failure(); - - Operation* sourceOp = keys.front().instance.op; - size_t resultIndex = keys.front().resultIndex; - - for (ProducerKey key : keys) { - if (key.instance.op != sourceOp || key.resultIndex != resultIndex) - return sourceClass.op->emitError("lazy packed scalar receive expects one producer result"); - - if (key.instance.laneCount != 1) - return sourceClass.op->emitError("lazy packed scalar receive expects one lane per producer key"); - } - - PackedScalarRunValue packedRun; - packedRun.targetClass = targetClass.id; - packedRun.sourceOp = sourceOp; - packedRun.resultIndex = resultIndex; - packedRun.kind = PackedScalarRunKind::DeferredReceive; - packedRun.fragmentType = rankedFragmentType; - - packedRun.messages = std::move(messages); - - PackedScalarRunSlot slot; - llvm::append_range(slot.keys, keys); - packedRun.slots.push_back(std::move(slot)); - - if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) - return failure(); - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -struct ScalarSourceReceivePlan { - ClassId targetClass = 0; - MessageVector messages; - Type receiveType; - Operation* projectedExtractOp = nullptr; - ProjectedFragmentLayout projectedLayout; - std::optional projectedDescriptor; -}; - -struct ProjectedScalarSendGroup { - MessageVector messages; - ProjectedTransferDescriptor descriptor; -}; - -struct ScalarSourceFanoutPlan { - SmallVector receivePlans; - std::optional ordinaryMessages; - SmallVector projectedSendGroups; -}; - -bool hasSameProjectedSendCompatibility(const ProjectedTransferDescriptor& lhs, const ProjectedTransferDescriptor& rhs) { - return lhs.layout.fragmentType == rhs.layout.fragmentType && lhs.layout.fragmentShape == rhs.layout.fragmentShape - && lhs.layout.fragmentsPerLogicalSlot == rhs.layout.fragmentsPerLogicalSlot - && lhs.layout.payloadFragmentCount == rhs.layout.payloadFragmentCount - && lhs.layout.loopLowerBounds == rhs.layout.loopLowerBounds && lhs.layout.loopSteps == rhs.layout.loopSteps - && lhs.layout.loopTripCounts == rhs.layout.loopTripCounts && lhs.payloadType == rhs.payloadType; -} - -SmallVector collectDestinationClassesForKeys(MaterializerState& state, ArrayRef keys) { - SmallVector destinations; - - for (ProducerKey key : keys) - for (ClassId destinationClass : getDestinationClasses(state, key)) - destinations.push_back(destinationClass); - - llvm::sort(destinations); - destinations.erase(std::unique(destinations.begin(), destinations.end()), destinations.end()); - return destinations; -} - -FailureOr buildScalarSourceFanoutPlan(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - ArrayRef destinationClasses, - Value payload) { - assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); - - auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "scalar source core id"); - if (failed(sourceCpu)) - return failure(); - - ScalarSourceFanoutPlan fanoutPlan; - fanoutPlan.receivePlans.reserve(destinationClasses.size()); - - const auto getProjectedDescriptor = - [&](ClassId destinationClass) -> FailureOr> { - MaterializedClass& targetClass = state.classes[destinationClass]; - if (!targetClass.isBatch) { - bool hasAnyProjectedDescriptor = llvm::any_of(keys, [&](ProducerKey key) { - auto producerIt = state.projectedTransfers.find(key); - return producerIt != state.projectedTransfers.end() && producerIt->second.count(destinationClass) != 0; - }); - - std::optional descriptor = collectScalarTargetProjectedDescriptor( - state, targetClass, keys, /*requirePackedRunOffsetCountMatch=*/keys.size() > 1); - if (hasAnyProjectedDescriptor && !descriptor) - return targetClass.op->emitError("incomplete scalar projected transfer descriptor for local run"); - return descriptor; - } - - if (keys.size() != 1) - return std::optional {}; - - auto producerIt = state.projectedTransfers.find(keys.front()); - if (producerIt == state.projectedTransfers.end()) - return std::optional {}; - - auto descriptorIt = producerIt->second.find(destinationClass); - if (descriptorIt == producerIt->second.end()) - return std::optional {}; - - const ProjectedTransferDescriptor& descriptor = descriptorIt->second; - if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - if (descriptor.fragmentOffsets.size() - != targetClass.cpus.size() * static_cast(descriptor.layout.payloadFragmentCount)) - return targetClass.op->emitError("inconsistent batch projected transfer descriptor"); - - return std::optional {descriptor}; - }; - - for (ClassId destinationClass : destinationClasses) { - if (destinationClass == sourceClass.id) - continue; - - MaterializedClass& targetClass = state.classes[destinationClass]; - - ScalarSourceReceivePlan receivePlan; - receivePlan.targetClass = destinationClass; - receivePlan.receiveType = payload.getType(); - - auto appendMessage = [&](CpuId targetCpu) -> LogicalResult { - auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetCpu, "scalar target core id"); - if (failed(checkedTargetCpu)) - return failure(); - int64_t channelId = state.nextChannelId++; - - receivePlan.messages.append(channelId, *sourceCpu, *checkedTargetCpu); - return success(); - }; - - if (!targetClass.isBatch) { - if (failed(appendMessage(targetClass.cpus.front()))) - return failure(); - } - else { - for (CpuId targetCpu : targetClass.cpus) - if (failed(appendMessage(targetCpu))) - return failure(); - } - - FailureOr> descriptor = getProjectedDescriptor(destinationClass); - if (failed(descriptor)) - return failure(); - - if (*descriptor) { - const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; - - if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) { - if (!fanoutPlan.ordinaryMessages) - fanoutPlan.ordinaryMessages = MessageVector {}; - fanoutPlan.ordinaryMessages->append( - receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); - fanoutPlan.receivePlans.push_back(std::move(receivePlan)); - continue; - } - - receivePlan.receiveType = projectedDescriptor.payloadType; - receivePlan.projectedExtractOp = projectedDescriptor.extractOp; - receivePlan.projectedLayout = projectedDescriptor.layout; - receivePlan.projectedDescriptor = projectedDescriptor; - - auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ProjectedScalarSendGroup& group) { - return hasSameProjectedSendCompatibility(group.descriptor, projectedDescriptor); - }); - if (groupIt == fanoutPlan.projectedSendGroups.end()) { - ProjectedScalarSendGroup group; - group.descriptor.layout = projectedDescriptor.layout; - group.descriptor.payloadType = projectedDescriptor.payloadType; - fanoutPlan.projectedSendGroups.push_back(std::move(group)); - groupIt = std::prev(fanoutPlan.projectedSendGroups.end()); - } - - groupIt->messages.append( - receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); - llvm::append_range(groupIt->descriptor.fragmentOffsets, projectedDescriptor.fragmentOffsets); - } - else { - if (!fanoutPlan.ordinaryMessages) - fanoutPlan.ordinaryMessages = MessageVector {}; - fanoutPlan.ordinaryMessages->append( - receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); - } - - fanoutPlan.receivePlans.push_back(std::move(receivePlan)); - } - - for (ProjectedScalarSendGroup& group : fanoutPlan.projectedSendGroups) { - if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, group.descriptor))) - return failure(); - if (failed(verifyProjectedSendDescriptor(sourceClass.op, group.descriptor, group.messages))) - return failure(); - } - - return fanoutPlan; -} - -LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const ScalarSourceFanoutPlan& plan, - Location loc) { - if (plan.ordinaryMessages && failed(appendSend(state, sourceClass, payload, *plan.ordinaryMessages, loc))) - return failure(); - - for (const ProjectedScalarSendGroup& group : plan.projectedSendGroups) - if (failed(appendProjectedScalarSendLoop(state, sourceClass, payload, group.descriptor, group.messages, loc))) - return failure(); - - return success(); -} - - -struct GloballyOrderedScalarFanoutEvent { - size_t receivePlanIndex = 0; - int64_t minChannelId = 0; - int64_t orderKey = 0; - int32_t minSourceCoreId = 0; - int32_t minTargetCoreId = 0; -}; - -GloballyOrderedScalarFanoutEvent makeGloballyOrderedScalarFanoutEvent(size_t receivePlanIndex, - const ScalarSourceReceivePlan& plan) { - assert(!plan.messages.empty() && "expected a communication event with at least one message"); - GloballyOrderedScalarFanoutEvent event; - event.receivePlanIndex = receivePlanIndex; - event.minChannelId = plan.messages.channelIds.front(); - event.orderKey = getMinimumBlockingCommunicationOrderKey(plan.messages); - event.minSourceCoreId = plan.messages.sourceCoreIds.front(); - event.minTargetCoreId = plan.messages.targetCoreIds.front(); - - for (size_t index = 1, end = plan.messages.size(); index < end; ++index) { - event.minChannelId = std::min(event.minChannelId, plan.messages.channelIds[index]); - event.minSourceCoreId = std::min(event.minSourceCoreId, plan.messages.sourceCoreIds[index]); - event.minTargetCoreId = std::min(event.minTargetCoreId, plan.messages.targetCoreIds[index]); - } - - return event; -} - -SmallVector -collectGloballyOrderedScalarFanoutEvents(const ScalarSourceFanoutPlan& plan) { - SmallVector events; - events.reserve(plan.receivePlans.size()); - - for (auto [index, receivePlan] : llvm::enumerate(plan.receivePlans)) - if (!receivePlan.messages.empty()) - events.push_back(makeGloballyOrderedScalarFanoutEvent(index, receivePlan)); - - llvm::sort(events, [](const GloballyOrderedScalarFanoutEvent& lhs, - const GloballyOrderedScalarFanoutEvent& rhs) { - if (lhs.orderKey != rhs.orderKey) - return lhs.orderKey < rhs.orderKey; - if (lhs.minChannelId != rhs.minChannelId) - return lhs.minChannelId < rhs.minChannelId; - if (lhs.minSourceCoreId != rhs.minSourceCoreId) - return lhs.minSourceCoreId < rhs.minSourceCoreId; - if (lhs.minTargetCoreId != rhs.minTargetCoreId) - return lhs.minTargetCoreId < rhs.minTargetCoreId; - return lhs.receivePlanIndex < rhs.receivePlanIndex; - }); - - return events; -} - -LogicalResult emitGloballyOrderedScalarFanoutSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const ScalarSourceReceivePlan& plan, - Location loc) { - if (plan.projectedDescriptor) - return appendProjectedScalarSendLoop(state, sourceClass, payload, *plan.projectedDescriptor, plan.messages, loc); - - return appendSend(state, sourceClass, payload, plan.messages, loc); -} - -bool isMaterializedBlockingCommunication(Operation& op) { - return isa(&op) || op.hasAttr(kRaptorMinChannelIdAttr) - || op.hasAttr(kRaptorCommOrderAttr); -} - -bool payloadIsAvailableOnlyAfterPriorCommunication(Value payload, MaterializedClass& sourceClass) { - Operation* lowerBound = getPayloadDefiningOpInClassBlock(payload, sourceClass); - if (!lowerBound) - return false; - - bool sawPriorCommunication = false; - Operation* terminator = sourceClass.body->getTerminator(); - for (Operation& op : *sourceClass.body) { - if (&op == terminator) - break; - - if (&op == lowerBound) - return sawPriorCommunication || isMaterializedBlockingCommunication(op); - - if (isMaterializedBlockingCommunication(op)) - sawPriorCommunication = true; - } - - return sawPriorCommunication; -} - -bool shouldPlaceMatchingScalarFanoutReceiveLate(MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages) { - if (payloadIsAvailableOnlyAfterPriorCommunication(payload, sourceClass)) - return true; - - for (size_t index = 0, end = messages.size(); index < end; ++index) - if (shouldDelayScalarSendUntilAfterReceives( - payload, messages.sourceCoreIds[index], messages.targetCoreIds[index])) - return true; - return false; -} - -LogicalResult emitGloballyOrderedScalarSourceFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - const ScalarSourceFanoutPlan& plan, - Location loc) { - SmallVector events = collectGloballyOrderedScalarFanoutEvents(plan); - - for (const GloballyOrderedScalarFanoutEvent& event : events) { - const ScalarSourceReceivePlan& planEntry = plan.receivePlans[event.receivePlanIndex]; - MaterializedClass& targetClass = state.classes[planEntry.targetClass]; - - if (failed(emitGloballyOrderedScalarFanoutSend(state, sourceClass, payload, planEntry, loc))) - return failure(); - - if (!targetClass.isBatch && !planEntry.projectedExtractOp) { - recordPendingScalarReceive(state, targetClass.id, keys, planEntry.receiveType, planEntry.messages, loc); - continue; - } - - bool lateReceive = shouldPlaceMatchingScalarFanoutReceiveLate(sourceClass, payload, planEntry.messages); - Value received = appendReceive(state, targetClass, planEntry.receiveType, planEntry.messages, loc, lateReceive); - - if (planEntry.projectedExtractOp) { - state.projectedExtractReplacements[planEntry.projectedExtractOp][planEntry.targetClass] = - ProjectedExtractReplacement {received, planEntry.projectedLayout}; - continue; - } - - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, received); - } - - return success(); -} - -LogicalResult emitScalarSourceCommunication( - MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Location loc) { - assert(!sourceClass.isBatch && "scalar-source communication expects a scalar source class"); - - for (ProducerKey key : keys) - state.availableValues.record(key, sourceClass.id, payload); - - SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); - auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, payload); - if (failed(fanoutPlan)) - return failure(); - if (pimMaterializeScalarFanoutGlobalOrder) - return emitGloballyOrderedScalarSourceFanout(state, sourceClass, keys, payload, *fanoutPlan, loc); - - if (failed(emitScalarSourceFanoutSends(state, sourceClass, payload, *fanoutPlan, loc))) - return failure(); - - for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { - MaterializedClass& targetClass = state.classes[plan.targetClass]; - - Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); - - if (plan.projectedExtractOp) { - state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = - ProjectedExtractReplacement {received, plan.projectedLayout}; - continue; - } - - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, received); - } - - return success(); -} - -FailureOr emitOrderedBatchToBatchCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(sourceClass.isBatch && targetClass.isBatch && "ordered batch communication expects two batch classes"); - if (failed(messages.verify(sourceClass.op))) - return failure(); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape()) - return sourceClass.op->emitError("ordered batch communication expects a static ranked tensor payload"); - - auto makeEmpty = [&](MaterializedClass& materializedClass) -> Value { - return tensor::EmptyOp::create( - state.rewriter, loc, payloadType.getShape(), payloadType.getElementType()) - .getResult(); - }; - - setInsertionPointForEarlyCommunication(state, sourceClass); - Value sendChannelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); - Value sendSourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); - Value sendTargetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); - Value sendEarlyCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sle, - sendSourceCoreId, - sendTargetCoreId) - .getResult(); - auto earlySendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendEarlyCond, /*withElseRegion=*/false); - state.rewriter.setInsertionPoint(earlySendIf.thenBlock()->getTerminator()); - auto earlySend = SpatChannelSendOp::create( - state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); - markScalarCommunication( - earlySend.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlySend"); - - setInsertionPointForLateCommunication(state, sourceClass); - Value sendLateCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sgt, - sendSourceCoreId, - sendTargetCoreId) - .getResult(); - auto lateSendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendLateCond, /*withElseRegion=*/false); - rememberLateCommunicationOp(state, sourceClass, lateSendIf.getOperation()); - state.rewriter.setInsertionPoint(lateSendIf.thenBlock()->getTerminator()); - auto lateSend = SpatChannelSendOp::create( - state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); - markScalarCommunication( - lateSend.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.lateSend"); - - setInsertionPointForEarlyCommunication(state, targetClass); - Value recvChannelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); - Value recvSourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); - Value recvTargetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); - Value recvEarlyCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sle, - recvSourceCoreId, - recvTargetCoreId) - .getResult(); - auto earlyReceiveIf = scf::IfOp::create( - state.rewriter, loc, TypeRange {payload.getType()}, recvEarlyCond, /*withElseRegion=*/true); - Operation* earlyThenYield = earlyReceiveIf.thenBlock()->getTerminator(); - state.rewriter.setInsertionPoint(earlyThenYield); - auto earlyReceive = SpatChannelReceiveOp::create( - state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); - markScalarCommunication( - earlyReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlyReceive"); - Value earlyReceived = earlyReceive.getOutput(); - state.rewriter.modifyOpInPlace(earlyThenYield, [&] { earlyThenYield->setOperands(ValueRange {earlyReceived}); }); - Operation* earlyElseYield = earlyReceiveIf.elseBlock()->getTerminator(); - state.rewriter.setInsertionPoint(earlyElseYield); - Value empty = makeEmpty(targetClass); - state.rewriter.modifyOpInPlace(earlyElseYield, [&] { earlyElseYield->setOperands(ValueRange {empty}); }); - - setInsertionPointForLateCommunication(state, targetClass); - Value recvLateCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sgt, - recvSourceCoreId, - recvTargetCoreId) - .getResult(); - auto lateReceiveIf = scf::IfOp::create( - state.rewriter, loc, TypeRange {payload.getType()}, recvLateCond, /*withElseRegion=*/true); - rememberLateCommunicationOp(state, targetClass, lateReceiveIf.getOperation()); - Operation* lateThenYield = lateReceiveIf.thenBlock()->getTerminator(); - state.rewriter.setInsertionPoint(lateThenYield); - auto lateReceive = SpatChannelReceiveOp::create( - state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); - markScalarCommunication( - lateReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.lateReceive"); - Value lateReceived = lateReceive.getOutput(); - state.rewriter.modifyOpInPlace(lateThenYield, [&] { lateThenYield->setOperands(ValueRange {lateReceived}); }); - Operation* lateElseYield = lateReceiveIf.elseBlock()->getTerminator(); - state.rewriter.modifyOpInPlace( - lateElseYield, [&] { lateElseYield->setOperands(ValueRange {earlyReceiveIf.getResult(0)}); }); - - return lateReceiveIf.getResult(0); -} - -LogicalResult emitClassToClassCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - ArrayRef keys, - Value payload, - Location loc) { - if (sourceClass.id == targetClass.id) { - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, payload); - return success(); - } - - if (!sourceClass.isBatch) - return sourceClass.op->emitError("scalar-source communication must be emitted through the scalar fanout planner"); - - if (!targetClass.isBatch) { - MessageVector messages; - messages.channelIds.reserve(sourceClass.cpus.size()); - messages.sourceCoreIds.reserve(sourceClass.cpus.size()); - messages.targetCoreIds.reserve(sourceClass.cpus.size()); - - auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id"); - if (failed(targetCpu)) - return failure(); - for (CpuId sourceCpu : sourceClass.cpus) { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); - if (failed(checkedSourceCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *targetCpu); - } - - if (failed(appendSend(state, sourceClass, payload, messages, loc))) - return failure(); - return registerLazyPackedScalarReceives(state, - sourceClass, - targetClass, - keys, - payload.getType(), - messages.channelIds, - messages.sourceCoreIds, - messages.targetCoreIds); - } - - if (sourceClass.cpus.size() != targetClass.cpus.size()) - return sourceClass.op->emitError( - "cannot materialize batch communication between equivalence classes of different sizes"); - - MessageVector messages; - messages.channelIds.reserve(sourceClass.cpus.size()); - messages.sourceCoreIds.reserve(sourceClass.cpus.size()); - messages.targetCoreIds.reserve(targetClass.cpus.size()); - - for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch source core id"); - if (failed(checkedSourceCpu)) - return failure(); - auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus[lane], "batch target core id"); - if (failed(checkedTargetCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - } - - FailureOr received = - emitOrderedBatchToBatchCommunication(state, sourceClass, targetClass, payload, messages, loc); - if (failed(received)) - return failure(); - - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, *received); - - return success(); -} - -LogicalResult -setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { - auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); - if (resultIt == sourceClass.hostOutputToResultIndex.end()) - return sourceClass.op->emitError("missing host result slot for materialized output") - << " ownerKind=" << (sourceClass.isBatch ? "batch" : "scalar") - << " hostOutputs=" << sourceClass.hostOutputs.size() - << " originalDef=" << (originalOutput.getDefiningOp() ? originalOutput.getDefiningOp()->getName().getStringRef() - : StringRef("")); - - unsigned resultIndex = resultIt->second; - if (payload.getType() != originalOutput.getType()) - return sourceClass.op->emitError("cannot set host output from fragment payload without projection") - << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); - - if (!sourceClass.isBatch) { - auto yieldOp = dyn_cast(sourceClass.body->getTerminator()); - if (!yieldOp) - return sourceClass.op->emitError("expected spat.yield terminator in materialized compute"); - if (resultIndex >= yieldOp.getNumOperands()) - return sourceClass.op->emitError("host result index out of range for materialized compute"); - - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); - state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); - return success(); - } - - auto batch = cast(sourceClass.op); - auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); - if (!inParallelOp) - return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape()) - return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor"); - - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("expected compute_batch lane block argument while materializing batch output"); - - auto outputArg = batch.getOutputArgument(resultIndex); - if (!outputArg) - return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); - - state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - - createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); - state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); - return success(); -} - -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); - -LogicalResult emitProjectedBatchHostOutput(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value originalOutput, - Value payload, - Location loc) { - if (!sourceClass.isBatch) - return sourceClass.op->emitError("projected batch host publication expects a batch owner class"); - auto batch = cast(sourceClass.op); - - auto ownerIt = sourceClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == sourceClass.hostOutputToResultIndex.end()) - return sourceClass.op->emitError("missing host result slot for projected batch output"); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return sourceClass.op->emitError("projected batch host publication expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for publication"); - - auto sourceLaneArg = sourceBatch.getLaneArgument(); - if (!sourceLaneArg) - return sourceBatch.emitOpError("missing source compute_batch lane argument for host projection"); - - // The projection coordinates are part of the source batch publication. - // Build any affine/index helper ops in the source batch body, not at the - // caller's current insertion point. Otherwise a scalar host-owner body may - // accidentally capture the source scheduled_compute_batch lane argument. - OpBuilder::InsertionGuard projectionGuard(state.rewriter); - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - - FailureOr projectionLaneValue = createProjectionLaneValueForKeys(state, sourceClass, keys, loc); - if (failed(projectionLaneValue)) - return failure(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(projection->getMixedOffsets().size()); - sizes.reserve(projection->getMixedSizes().size()); - strides.reserve(projection->getMixedStrides().size()); - - for (OpFoldResult offset : projection->getMixedOffsets()) { - FailureOr remapped = - remapProjectionIndexLike(state, sourceClass.op, offset, *sourceLaneArg, *projectionLaneValue, loc); - if (failed(remapped)) - return sourceClass.op->emitError("failed to remap projected batch host offsets"); - offsets.push_back(*remapped); - } - for (OpFoldResult size : projection->getMixedSizes()) { - FailureOr remapped = - remapProjectionIndexLike(state, sourceClass.op, size, *sourceLaneArg, *projectionLaneValue, loc); - if (failed(remapped)) - return sourceClass.op->emitError("failed to remap projected batch host sizes"); - sizes.push_back(*remapped); - } - for (OpFoldResult stride : projection->getMixedStrides()) { - FailureOr remapped = - remapProjectionIndexLike(state, sourceClass.op, stride, *sourceLaneArg, *projectionLaneValue, loc); - if (failed(remapped)) - return sourceClass.op->emitError("failed to remap projected batch host strides"); - strides.push_back(*remapped); - } - - auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); - if (!inParallelOp) - return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); - - auto outputArg = batch.getOutputArgument(ownerIt->second); - if (!outputArg) - return batch.emitOpError("missing host output block argument for projected batch publication"); - - state.hostReplacements[originalOutput] = sourceClass.op->getResult(ownerIt->second); - state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - tensor::ParallelInsertSliceOp::create(state.rewriter, loc, payload, *outputArg, offsets, sizes, strides); - return success(); -} - -FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane); - -FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) { - if (value == laneArg) - return static_cast(lane); - - if (std::optional constant = matchConstantIndexValue(value)) - return *constant; - - auto affineApply = value.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return failure(); - - SmallVector operands; - operands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - FailureOr evaluated = evaluateProjectionIndexLike(operand, laneArg, lane); - if (failed(evaluated)) - return failure(); - operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *evaluated)); - } - - SmallVector results; - if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) - return failure(); - - auto intAttr = dyn_cast(results.front()); - if (!intAttr) - return failure(); - return intAttr.getInt(); -} - -FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane) { - if (auto attr = llvm::dyn_cast(value)) { - auto intAttr = dyn_cast(attr); - if (!intAttr) - return failure(); - return intAttr.getInt(); - } - return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane); -} - -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { - auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); - if (!inParallel) - return failure(); - - auto firstOutputArg = batch.getOutputArgument(0); - if (!firstOutputArg) - return failure(); - - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert) - continue; - - auto outputArg = dyn_cast(insert.getDest()); - if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) - continue; - - unsigned candidateIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); - if (candidateIndex == resultIndex) - return insert; - } - - return failure(); -} - -FailureOr> -evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, uint32_t lane) { - SmallVector evaluated; - evaluated.reserve(values.size()); - for (OpFoldResult value : values) { - FailureOr index = evaluateProjectionIndexLike(value, laneArg, lane); - if (failed(index)) - return failure(); - evaluated.push_back(*index); - } - return evaluated; -} - - -bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, - const AffineProjectedInputSliceMatch& match, - ProducerKey producer, - uint32_t consumerLane) { - auto producerBatch = dyn_cast_or_null(producer.instance.op); - if (!producerBatch) - return true; - - FailureOr producerProjection = - getBatchResultProjectionInsert(producerBatch, producer.resultIndex); - if (failed(producerProjection)) - return true; - - std::optional producerLaneArg = producerBatch.getLaneArgument(); - std::optional consumerLaneArg = consumerBatch.getLaneArgument(); - if (!producerLaneArg || !consumerLaneArg) - return false; - - SmallVector consumerSizes(match.fragmentShape.begin(), match.fragmentShape.end()); - SmallVector loopIterationIndices(match.loops.size(), 0); - - const auto consumerSliceFitsOneProducerFragment = [&]() -> bool { - SmallVector consumerOffsets; - consumerOffsets.reserve(match.offsets.size()); - for (OpFoldResult offset : match.offsets) { - FailureOr evaluated = - evaluateProjectedOffsetValue(offset, *consumerLaneArg, consumerLane, match.loops, loopIterationIndices); - if (failed(evaluated)) - return false; - consumerOffsets.push_back(*evaluated); - } - - uint32_t producerLaneEnd = producer.instance.laneStart + producer.instance.laneCount; - for (uint32_t producerLane = producer.instance.laneStart; producerLane < producerLaneEnd; ++producerLane) { - FailureOr> producerOffsets = - evaluateStaticProjectionIndices(producerProjection->getMixedOffsets(), *producerLaneArg, producerLane); - FailureOr> producerSizes = - evaluateStaticProjectionIndices(producerProjection->getMixedSizes(), *producerLaneArg, producerLane); - FailureOr> producerStrides = - evaluateStaticProjectionIndices(producerProjection->getMixedStrides(), *producerLaneArg, producerLane); - if (failed(producerOffsets) || failed(producerSizes) || failed(producerStrides)) - return false; - if (!areAllUnitStrides(*producerStrides)) - return false; - if (isStaticSliceContainedIn(consumerOffsets, consumerSizes, *producerOffsets, *producerSizes)) - return true; - } - - return false; - }; - - if (match.loops.empty()) - return consumerSliceFitsOneProducerFragment(); - - const auto recurse = [&](auto&& self, size_t loopIndex) -> bool { - if (loopIndex == match.loops.size()) - return consumerSliceFitsOneProducerFragment(); - - for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { - loopIterationIndices[loopIndex] = iteration; - if (!self(self, loopIndex + 1)) - return false; - } - return true; - }; - - return recurse(recurse, 0); -} - -LogicalResult insertProjectedBatchHostFragment(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - uint32_t lane, - Value payload) { - if (ownerClass.isBatch) - return ownerClass.op->emitError("projected batch host fallback expects a scalar owner class"); - - auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == ownerClass.hostOutputToResultIndex.end()) - return ownerClass.op->emitError("missing host result slot for projected batch host fragment"); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return ownerClass.op->emitError("projected batch host fallback expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for materialization"); - - auto laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) - return sourceBatch.emitOpError("missing compute_batch lane argument for host projection"); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); - if (failed(offsets) || failed(sizes) || failed(strides)) - return ownerClass.op->emitError("failed to evaluate batch host projection coordinates"); - - auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); - if (!yieldOp) - return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); - - unsigned hostResultIndex = ownerIt->second; - if (hostResultIndex >= yieldOp.getNumOperands()) - return ownerClass.op->emitError("host result index out of range for projected batch host fragment"); - if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) - return ownerClass.op->emitError("projected batch host fragment expected a full host accumulator tensor") - << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() - << " outputType=" << originalOutput.getType(); - - state.rewriter.setInsertionPoint(yieldOp); - Value updated = tensor::InsertSliceOp::create(state.rewriter, - payload.getLoc(), - payload, - yieldOp.getOperand(hostResultIndex), - ValueRange {}, - ValueRange {}, - ValueRange {}, - *offsets, - *sizes, - *strides) - .getResult(); - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, updated); }); - state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); - return success(); -} - - -LogicalResult emitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - ArrayRef keys, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - if (ownerClass.isBatch) - return ownerClass.op->emitError("projected batch host receive loop expects a scalar owner class"); - if (keys.empty()) - return success(); - if (keys.size() != messages.size()) - return ownerClass.op->emitError("projected batch host receive loop message metadata is inconsistent"); - - auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == ownerClass.hostOutputToResultIndex.end()) - return ownerClass.op->emitError("missing host result slot for projected batch host receive loop"); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return ownerClass.op->emitError("projected batch host receive loop expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for receive loop"); - - auto laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) - return sourceBatch.emitOpError("missing compute_batch lane argument for projected host receive loop"); - - auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); - if (!yieldOp) - return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); - - unsigned hostResultIndex = ownerIt->second; - if (hostResultIndex >= yieldOp.getNumOperands()) - return ownerClass.op->emitError("host result index out of range for projected batch host receive loop"); - if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) - return ownerClass.op->emitError("projected batch host receive loop expected a full host accumulator tensor") - << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() - << " outputType=" << originalOutput.getType(); - - unsigned rank = projection->getMixedOffsets().size(); - SmallVector, 4> offsetsByDim(rank); - SmallVector, 4> sizesByDim(rank); - SmallVector, 4> stridesByDim(rank); - for (ProducerKey key : keys) { - if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() - || key.instance.laneCount != 1) - return ownerClass.op->emitError("projected batch host receive loop expects one-lane fragments from one output"); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); - if (failed(offsets) || failed(sizes) || failed(strides)) - return ownerClass.op->emitError("failed to evaluate projected batch host receive loop coordinates"); - if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) - return ownerClass.op->emitError("projected batch host receive loop coordinate rank mismatch"); - - for (unsigned dim = 0; dim < rank; ++dim) { - offsetsByDim[dim].push_back((*offsets)[dim]); - sizesByDim[dim].push_back((*sizes)[dim]); - stridesByDim[dim].push_back((*strides)[dim]); - } - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 1); - - state.rewriter.setInsertionPoint(yieldOp); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {yieldOp.getOperand(hostResultIndex)}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value channelId = createIndexedChannelId(state, ownerClass.op, messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, ownerClass.op, messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, ownerClass.op, messages, flatIndex, loc); - Value fragment = SpatChannelReceiveOp::create( - state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (unsigned dim = 0; dim < rank; ++dim) { - offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); - sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); - strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); - } - - Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) - .getResult(); - yielded.push_back(updated); - return success(); - }); - if (failed(loop)) - return failure(); - markScalarCommunication( - loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitProjectedBatchHostReceiveInsertLoop"); - - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); - state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); - return success(); -} - -std::optional tryEmitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - ArrayRef keys, - Location loc) { - if (keys.empty()) - return success(); - - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(keys.front(), ownerClass.id); - ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); - for (size_t runIndex : runIndices) { - PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); - if (run.kind != PackedScalarRunKind::DeferredReceive) - continue; - SmallVector runKeys = flattenPackedScalarRunKeys(run); - if (!llvm::equal(runKeys, keys)) - continue; - return emitProjectedBatchHostReceiveInsertLoop( - state, ownerClass, originalOutput, runKeys, run.fragmentType, run.messages, loc); - } - - return std::nullopt; -} - -FailureOr getLeadingPackedFragmentType(Operation* anchor, Value payload, size_t fragmentCount) { - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) - return failure(); - if (payloadType.getDimSize(0) != static_cast(fragmentCount)) - return failure(); - - SmallVector fragmentShape(payloadType.getShape().begin(), payloadType.getShape().end()); - fragmentShape[0] = 1; - return RankedTensorType::get(fragmentShape, payloadType.getElementType(), payloadType.getEncoding()); -} - -LogicalResult emitScalarPackedProjectedHostSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "packed projected host send loop expects a scalar source"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) - return sourceClass.op->emitError("packed projected host send loop expects a static ranked payload"); - - setInsertionPointForScalarCommunication(state, sourceClass, getMinimumChannelId(messages.channelIds)); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(payloadType.getRank()); - sizes.reserve(payloadType.getRank()); - strides.reserve(payloadType.getRank()); - offsets.push_back(index); - sizes.push_back(state.rewriter.getIndexAttr(1)); - strides.push_back(state.rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { - offsets.push_back(state.rewriter.getIndexAttr(0)); - sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); - strides.push_back(state.rewriter.getIndexAttr(1)); - } - - Value fragment = tensor::ExtractSliceOp::create( - state.rewriter, loc, fragmentType, payload, offsets, sizes, strides) - .getResult(); - Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, fragment); - return success(); - }); - if (failed(loop)) - return failure(); - markScalarCommunication( - loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitScalarPackedProjectedHostSendLoop"); - return success(); -} - -LogicalResult emitScalarPackedProjectedHostLocalInsertLoop(MaterializerState& state, - MaterializedClass& ownerClass, - ArrayRef keys, - Value payload, - Value originalOutput, - RankedTensorType fragmentType, - Location loc) { - if (ownerClass.isBatch) - return ownerClass.op->emitError("packed projected host local insert loop expects a scalar owner class"); - if (keys.empty()) - return success(); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) - return ownerClass.op->emitError("packed projected host local insert loop expects a static ranked payload"); - if (payloadType.getDimSize(0) != static_cast(keys.size())) - return ownerClass.op->emitError("packed projected host local insert loop payload/key count mismatch"); - - auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == ownerClass.hostOutputToResultIndex.end()) - return ownerClass.op->emitError("missing host result slot for packed projected host local insert loop"); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return ownerClass.op->emitError("packed projected host local insert loop expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for local insert loop"); - - auto laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) - return sourceBatch.emitOpError("missing compute_batch lane argument for packed projected host local insert loop"); - - auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); - if (!yieldOp) - return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); - - unsigned hostResultIndex = ownerIt->second; - if (hostResultIndex >= yieldOp.getNumOperands()) - return ownerClass.op->emitError("host result index out of range for packed projected host local insert loop"); - if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) - return ownerClass.op->emitError("packed projected host local insert loop expected a full host accumulator tensor") - << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() - << " outputType=" << originalOutput.getType(); - - unsigned rank = projection->getMixedOffsets().size(); - SmallVector, 4> offsetsByDim(rank); - SmallVector, 4> sizesByDim(rank); - SmallVector, 4> stridesByDim(rank); - for (ProducerKey key : keys) { - if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() - || key.instance.laneCount != 1) - return ownerClass.op->emitError("packed projected host local insert loop expects one-lane fragments from one output"); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); - if (failed(offsets) || failed(sizes) || failed(strides)) - return ownerClass.op->emitError("failed to evaluate packed projected host local insert loop coordinates"); - if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) - return ownerClass.op->emitError("packed projected host local insert loop coordinate rank mismatch"); - - for (unsigned dim = 0; dim < rank; ++dim) { - offsetsByDim[dim].push_back((*offsets)[dim]); - sizesByDim[dim].push_back((*sizes)[dim]); - stridesByDim[dim].push_back((*strides)[dim]); - } - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 1); - - state.rewriter.setInsertionPoint(yieldOp); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {yieldOp.getOperand(hostResultIndex)}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - SmallVector extractOffsets; - SmallVector extractSizes; - SmallVector extractStrides; - extractOffsets.reserve(payloadType.getRank()); - extractSizes.reserve(payloadType.getRank()); - extractStrides.reserve(payloadType.getRank()); - extractOffsets.push_back(flatIndex); - extractSizes.push_back(state.rewriter.getIndexAttr(1)); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { - extractOffsets.push_back(state.rewriter.getIndexAttr(0)); - extractSizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - } - - Value fragment = tensor::ExtractSliceOp::create( - state.rewriter, loc, fragmentType, payload, extractOffsets, extractSizes, extractStrides) - .getResult(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (unsigned dim = 0; dim < rank; ++dim) { - offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); - sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); - strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); - } - - Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) - .getResult(); - yielded.push_back(updated); - return success(); - }); - if (failed(loop)) - return failure(); - - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); - state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); - return success(); -} - -std::optional tryEmitScalarPackedProjectedHostPublication(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { - if (sourceClass.isBatch || keys.size() <= 1) - return std::nullopt; - - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for projected batch output"); - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (ownerClass.isBatch) - return ownerClass.op->emitError( - "projected batch host output reached a batch owner without an explicit batch publication path"); - FailureOr fragmentType = getLeadingPackedFragmentType(sourceClass.op, payload, keys.size()); - if (failed(fragmentType)) - return std::nullopt; - - if (ownerClass.id == sourceClass.id) - return emitScalarPackedProjectedHostLocalInsertLoop( - state, ownerClass, keys, payload, originalOutput, *fragmentType, loc); - - auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); - auto targetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); - if (failed(sourceCpu) || failed(targetCpu)) - return failure(); - - MessageVector messages; - for ([[maybe_unused]] ProducerKey key : keys) - messages.append(state.nextChannelId++, *sourceCpu, *targetCpu); - - if (failed(messages.verify(sourceClass.op))) - return failure(); - - if (failed(emitScalarPackedProjectedHostSendLoop(state, sourceClass, payload, *fragmentType, messages, loc))) - return failure(); - - return emitProjectedBatchHostReceiveInsertLoop( - state, ownerClass, originalOutput, keys, *fragmentType, messages, loc); -} - -void appendPendingProjectedHostReceive(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - ProducerKey key, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - assert(messages.size() == 1 && "pending projected host receive records one message at a time"); - for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { - if (group.originalOutput != originalOutput || group.ownerClassId != ownerClass.id || group.fragmentType != fragmentType) - continue; - group.keys.push_back(key); - group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); - return; - } - - PendingProjectedHostReceiveGroup group { - originalOutput, - ownerClass.id, - fragmentType, - SmallVector{key}, - MessageVector{}, - loc - }; - group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); - state.pendingProjectedHostReceives.push_back(std::move(group)); -} - -LogicalResult flushPendingProjectedHostReceives(MaterializerState& state) { - for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { - if (group.ownerClassId >= state.classes.size()) - return state.func.emitError("pending projected host receive has invalid owner class"); - MaterializedClass& ownerClass = state.classes[group.ownerClassId]; - if (failed(group.messages.verify(ownerClass.op))) - return failure(); - if (group.keys.empty()) - continue; - if (failed(emitProjectedBatchHostReceiveInsertLoop( - state, ownerClass, group.originalOutput, group.keys, group.fragmentType, group.messages, group.loc))) - return failure(); - } - state.pendingProjectedHostReceives.clear(); - return success(); -} - -LogicalResult emitProjectedBatchHostFragment(MaterializerState& state, - MaterializedClass& sourceClass, - ProducerKey key, - Value payload, - Value originalOutput, - Location loc) { - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for projected batch output"); - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - Value ownerPayload = payload; - if (sourceClass.id != ownerClass.id) { - if (ownerClass.isBatch) { - return ownerClass.op->emitError( - "projected batch host fragment reached a batch owner without an explicit batch publication path"); - } - - MessageVector messages; - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); - auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); - if (failed(checkedTargetCpu)) - return failure(); - if (!sourceClass.isBatch) { - if (failed(checkedSourceCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - - if (failed(appendSend(state, sourceClass, payload, messages, loc))) - return failure(); - - auto fragmentType = dyn_cast(payload.getType()); - if (!fragmentType) - return sourceClass.op->emitError("projected terminal batch host fragment expects ranked tensor payload"); - appendPendingProjectedHostReceive(state, ownerClass, originalOutput, key, fragmentType, messages, loc); - return success(); - } - else { - ComputeInstance scheduledInstance = getScheduledChunkForLogicalInstance(state, key.instance); - auto sourceCpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); - if (sourceCpuIt == state.schedule.computeToCpuMap.end()) - return sourceClass.op->emitError("missing CPU assignment for projected batch host source"); - - auto localLaneIt = sourceClass.cpuToLane.find(sourceCpuIt->second); - if (localLaneIt == sourceClass.cpuToLane.end()) - return sourceClass.op->emitError("missing local batch lane for projected batch host source"); - - if (failed(checkedSourceCpu = getCheckedCoreId(sourceClass.op, - sourceCpuIt->second, - "projected host source core id"))) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - - auto batch = cast(sourceClass.op); - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("missing lane argument for projected batch host source"); - - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value localLane = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, localLaneIt->second); - Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); - Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); - Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); - Value isSourceLane = arith::CmpIOp::create(state.rewriter, loc, arith::CmpIPredicate::eq, *laneArg, localLane); - auto ifOp = scf::IfOp::create(state.rewriter, loc, TypeRange {}, isSourceLane, /*withElseRegion=*/false); - state.rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator()); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, loc); - } - } - - return insertProjectedBatchHostFragment(state, ownerClass, originalOutput, key.instance.laneStart, ownerPayload); -} - -LogicalResult -emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) { - if (!hasLiveExternalUseCached(state, originalOutput)) - return success(); - - if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) - return sourceClass.op->emitError("cannot set projected terminal batch host output through the generic host path"); - - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for live external output"); - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (sourceClass.id == ownerClass.id) - return setHostOutputValue(state, ownerClass, originalOutput, payload); - - if (sourceClass.isBatch) - return sourceClass.op->emitError("batch host publication must be routed through a projection-aware or owning path"); - if (ownerClass.isBatch) - return ownerClass.op->emitError("generic host publication does not support batch host owners"); - - MessageVector messages; - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "host source core id"); - auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "host target core id"); - if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - - if (failed(appendSend(state, sourceClass, payload, messages, payload.getLoc()))) - return failure(); - Value ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, payload.getLoc()); - return setHostOutputValue(state, ownerClass, originalOutput, ownerPayload); -} - -LogicalResult emitOutputFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { - if (keys.empty()) - return success(); - - if (!sourceClass.isBatch) { - if (failed(emitScalarSourceCommunication(state, sourceClass, keys, payload, loc))) - return failure(); - - if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { - std::optional loopedHostPublication = - tryEmitScalarPackedProjectedHostPublication(state, sourceClass, keys, payload, originalOutput, loc); - if (loopedHostPublication) - return *loopedHostPublication; - - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return sourceClass.op->emitError("projected terminal batch host output expects one logical lane per fragment"); - if (failed(emitProjectedBatchHostFragment(state, sourceClass, key, payload, originalOutput, loc))) - return failure(); - } - return success(); - } - - return emitHostCommunication(state, sourceClass, payload, originalOutput); - } - - if (!haveSameDestinationClasses(state, keys)) - return sourceClass.op->emitError( - "cannot materialize batched output whose lanes have different destination equivalence classes"); - - if (auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp())) { - if (sourceBatch.getNumResults() != 0 && isTerminalHostBatchOutput(originalOutput, state.oldComputeOps)) { - for (ClassId destinationClass : getDestinationClasses(state, keys.front())) - if (!state.classes[destinationClass].isBatch) - return emitBatchToScalarDestinationDiagnostic(state, sourceClass, keys, originalOutput); - } - } - - for (ClassId destinationClass : getDestinationClasses(state, keys.front())) - if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) - return failure(); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - if (sourceBatch && sourceBatch.getNumResults() != 0 && hasLiveExternalUseCached(state, originalOutput)) { - if (sourceClass.hostOutputToResultIndex.contains(originalOutput)) { - if (failed(emitProjectedBatchHostOutput(state, sourceClass, keys, originalOutput, payload, loc))) - return failure(); - } - else { - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for projected batch output"); - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (ownerClass.isBatch) - return ownerClass.op->emitError( - "projected batch host output reached a batch owner without an explicit batch publication path"); - - if (sourceClass.id != ownerClass.id - && failed(emitClassToClassCommunication(state, sourceClass, ownerClass, keys, payload, loc))) - return failure(); - - std::optional loopedHostPublication = - tryEmitProjectedBatchHostReceiveInsertLoop(state, ownerClass, originalOutput, keys, loc); - if (loopedHostPublication) { - if (failed(*loopedHostPublication)) - return failure(); - } - else { - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return sourceClass.op->emitError("projected batch host output expects one logical lane per fragment"); - - std::optional ownerPayload = state.availableValues.lookup(state, key, ownerClass.id); - if (!ownerPayload) - return ownerClass.op->emitError("failed to recover projected batch host fragment after communication"); - - if (failed(insertProjectedBatchHostFragment( - state, ownerClass, originalOutput, key.instance.laneStart, *ownerPayload))) - return failure(); - } - } - } - } else if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) { - return failure(); - } - - for (ProducerKey key : keys) - state.availableValues.record(key, sourceClass.id, payload); - - return success(); -} - -struct DirectWholeBatchFragment { - ProducerKey key; - Value fragment; -}; - -enum class WholeBatchFragmentSourceKind { - DeferredReceive, - DeferredLocalCompute, - PackedValue, - DirectValue -}; - -struct WholeBatchFragmentGroup { - WholeBatchFragmentSourceKind kind = WholeBatchFragmentSourceKind::DirectValue; - RankedTensorType fragmentType; - SmallVector outputOffsets; - MessageVector messages; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - SmallVector sourceLanes; - Value packed; - RankedTensorType slotPackedType; - SmallVector slotIndices; - SmallVector, 16> directFragments; - SmallVector redundantReceives; -}; - -enum class ProjectedWholeBatchFragmentSourceKind { - DeferredReceive, - PackedValue, - DirectValue -}; - -struct ProjectedWholeBatchDirectFragment { - Value fragment; - SmallVector offsets; - SmallVector sizes; - SmallVector strides; -}; - -struct ProjectedWholeBatchFragmentGroup { - ProjectedWholeBatchFragmentSourceKind kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; - RankedTensorType fragmentType; - SmallVector, 4> offsetsByDim; - SmallVector, 4> sizesByDim; - SmallVector, 4> stridesByDim; - MessageVector messages; - SmallVector redundantOps; - Value packed; - RankedTensorType packedSourceType; - SmallVector packedIndices; - SmallVector directFragments; -}; - -struct WholeBatchAssemblyPlan { - RankedTensorType resultType; - int64_t rowsPerLane = 0; - uint32_t batchLaneCount = 0; - uint32_t coveredLaneCount = 0; - - SmallVector coveredLanes; - SmallVector packedRuns; - SmallVector directFragments; -}; - -bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) { - return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0; -} - -bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { - if (laneCount == 0) - return false; - if (laneStart >= plan.coveredLanes.size()) - return false; - - uint32_t laneEnd = std::min(laneStart + laneCount, plan.coveredLanes.size()); - for (uint32_t lane = laneStart; lane < laneEnd; ++lane) - if (plan.coveredLanes[lane] != 0) - return true; - return false; -} - -void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { - assert(laneCount != 0 && "cannot cover an empty whole-batch range"); - assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds"); - - for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) { - if (plan.coveredLanes[lane] != 0) - continue; - plan.coveredLanes[lane] = 1; - ++plan.coveredLaneCount; - } -} - -bool localLaneRangeOverlaps(ArrayRef covered, uint32_t laneStart, uint32_t laneCount) { - if (laneCount == 0) - return false; - if (laneStart >= covered.size()) - return false; - - uint32_t laneEnd = std::min(laneStart + laneCount, covered.size()); - for (uint32_t lane = laneStart; lane < laneEnd; ++lane) - if (covered[lane] != 0) - return true; - return false; -} - -void markLocalLaneRangeCovered(MutableArrayRef covered, uint32_t laneStart, uint32_t laneCount) { - assert(laneStart + laneCount <= covered.size() && "local coverage out of bounds"); - for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) - covered[lane] = 1; -} - -LogicalResult -validateWholeBatchFragmentType(RankedTensorType resultType, RankedTensorType fragmentType, int64_t expectedRows) { - if (!fragmentType.hasStaticShape()) - return failure(); - if (fragmentType.getRank() != resultType.getRank()) - return failure(); - if (fragmentType.getDimSize(0) != expectedRows) - return failure(); - - for (int64_t dim = 1; dim < resultType.getRank(); ++dim) - if (fragmentType.getDimSize(dim) != resultType.getDimSize(dim)) - return failure(); - - return success(); -} - -// ----------------------------------------------------------------------------- -// Packed run tensor assembly helpers. -// ----------------------------------------------------------------------------- - -FailureOr insertFragmentIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value fragment, - Value destination, - OpFoldResult firstOffset, - Location loc) { - return createDim0InsertSliceInClass(state, targetClass, loc, fragment, destination, firstOffset); -} - -FailureOr extractPackedSlotForIndex(MaterializerState& state, - MaterializedClass& targetClass, - Value packed, - RankedTensorType slotPackedType, - Value slotIndex, - Location loc) { - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0)); -} - -SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run) { - SmallVector keys; - for (const PackedScalarRunSlot& slot : run.slots) - llvm::append_range(keys, slot.keys); - return keys; -} - -bool packedScalarRunSlotsMatch(const PackedScalarRunValue& lhs, const PackedScalarRunValue& rhs) { - if (lhs.slots.size() != rhs.slots.size()) - return false; - - for (auto [lhsSlot, rhsSlot] : llvm::zip(lhs.slots, rhs.slots)) { - if (lhsSlot.keys.size() != rhsSlot.keys.size()) - return false; - if (!llvm::equal(lhsSlot.keys, rhsSlot.keys)) - return false; - } - - return true; -} - - -bool appendConstantChannelReceiveMessage(MessageVector& messages, SpatChannelReceiveOp receive) { - std::optional channelId = getConstantIndexValue(receive.getChannelId()); - std::optional sourceCoreId = getConstantIndexValue(receive.getSourceCoreId()); - std::optional targetCoreId = getConstantIndexValue(receive.getTargetCoreId()); - if (!channelId || !sourceCoreId || !targetCoreId) - return false; - messages.append(*channelId, static_cast(*sourceCoreId), static_cast(*targetCoreId)); - return true; -} - -PackedScalarRunValue* findDeferredReceiveAlternativeForPackedRun(MaterializerState& state, - const MaterializedClass& targetClass, - const PackedScalarRunValue& run) { - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(run.sourceOp, run.resultIndex, targetClass.id); - ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); - - for (size_t runIndex : runIndices) { - PackedScalarRunValue& candidate = state.availableValues.getPackedRun(runIndex); - if (&candidate == &run || candidate.kind != PackedScalarRunKind::DeferredReceive) - continue; - if (candidate.fragmentType != run.fragmentType) - continue; - if (!packedScalarRunSlotsMatch(candidate, run)) - continue; - return &candidate; - } - - return nullptr; -} - -FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - int64_t itemCount, - IndexedFragmentBuilder buildFragment, - IndexedInsertOffsetBuilder buildOffset, - Location loc) { - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, itemCount); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - Operation* insertionPoint = targetClass.body->getTerminator(); - - state.rewriter.setInsertionPoint(insertionPoint); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {destination}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - FailureOr fragment = buildFragment(flatIndex); - if (failed(fragment)) - return failure(); - FailureOr offset = buildOffset(flatIndex); - if (failed(offset)) - return failure(); - FailureOr next = - insertFragmentIntoWholeBatch(state, targetClass, *fragment, iterArgs.front(), *offset, loc); - if (failed(next)) - return failure(); - yielded.push_back(*next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -FailureOr> cloneBatchBodyForLane(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - Value laneValue, - ArrayRef resultIndices, - CloneIndexingContext indexing = {}); - -Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc); -FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, - MaterializedClass& targetClass, - IndexedBatchRunValue& run, - Value runSlotIndex, - Location loc); - -FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc) { - assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); - - SmallVector keys = flattenPackedScalarRunKeys(run); - if (keys.empty()) - return failure(); - FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); - if (failed(packedType)) - return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); - - SmallVector sourceLanes; - sourceLanes.reserve(keys.size()); - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return failure(); - sourceLanes.push_back(key.instance.laneStart); - } - - SmallVector resultIndices {run.resultIndex}; - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value init = - tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {init}, - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - keys.front().instance, - sourceLane, - resultIndices, - CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); - if (failed(produced) || produced->size() != 1) - return failure(); - - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); - if (failed(next)) - return failure(); - yielded.push_back(*next); - return success(); - }); - if (failed(loop)) - return failure(); - run.packed = loop->results.front(); - return run.packed; -} - -LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - WholeBatchAssemblyPlan& plan) { - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); - ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); - - for (size_t runIndex : runIndices) { - PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); - - SmallVector runKeys; - SmallVector runCoveredLanes(plan.batchLaneCount, 0); - - for (const PackedScalarRunSlot& slot : run.slots) { - for (ProducerKey fragmentKey : slot.keys) { - if (fragmentKey.instance.op != key.instance.op || fragmentKey.resultIndex != key.resultIndex) - return failure(); - - if (fragmentKey.instance.laneCount == 0) - return failure(); - - if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) - return failure(); - - if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) - return failure(); - - markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount); - runKeys.push_back(fragmentKey); - } - } - - if (runKeys.empty()) - continue; - - plan.packedRuns.push_back(&run); - - for (ProducerKey runKey : runKeys) - recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount); - } - - return success(); -} - -LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - SpatComputeBatch batch, - ProducerKey key, - WholeBatchAssemblyPlan& plan) { - struct CandidateFragment { - ProducerKey key; - Value value; - }; - - uint32_t batchLaneCount = static_cast(batch.getLaneCount()); - if (plan.coveredLaneCount == plan.batchLaneCount) { - return success(); - } - - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); - ArrayRef indexedFragments = - state.availableValues.getExactFragmentsForWholeBatch(lookupKey); - - SmallVector candidates; - candidates.reserve(indexedFragments.size()); - for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) { - ProducerKey candidateKey = record.key; - if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex - || candidateKey.instance.laneCount == 0) - continue; - if (!isTensorValueLocalToMaterializedClass(record.value, targetClass)) - continue; - if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount)) - continue; - - auto fragmentType = dyn_cast(record.value.getType()); - if (!fragmentType) - continue; - - int64_t expectedRows = plan.rowsPerLane * static_cast(candidateKey.instance.laneCount); - if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows))) - continue; - - candidates.push_back({candidateKey, record.value}); - } - - llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) { - if (lhs.key.instance.laneStart != rhs.key.instance.laneStart) - return lhs.key.instance.laneStart < rhs.key.instance.laneStart; - return lhs.key.instance.laneCount > rhs.key.instance.laneCount; - }); - - size_t candidateCursor = 0; - uint32_t lane = 0; - while (lane < batchLaneCount) { - while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) { - ++lane; - } - - if (lane >= batchLaneCount) - break; - - while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane) - ++candidateCursor; - - size_t candidateIndex = candidateCursor; - const CandidateFragment* best = nullptr; - while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) { - const CandidateFragment& candidate = candidates[candidateIndex]; - if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) { - best = &candidate; - break; - } - ++candidateIndex; - } - - if (!best) - return failure(); - - plan.directFragments.push_back({best->key, best->value}); - recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount); - lane += best->key.instance.laneCount; - } - - return success(); -} - -LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, - MaterializedClass& targetClass, - const WholeBatchAssemblyPlan& plan, - SmallVectorImpl& groups) { - for (PackedScalarRunValue* run : plan.packedRuns) { - if (!run || run->slots.empty()) - continue; - if (run->fragmentType.getDimSize(0) != plan.rowsPerLane) - return failure(); - - if (run->kind == PackedScalarRunKind::Materialized && run->packed - && !isTensorValueLocalToMaterializedClass(run->packed, targetClass)) { - if (PackedScalarRunValue* deferredRun = findDeferredReceiveAlternativeForPackedRun(state, targetClass, *run)) - run = deferredRun; - else { - SmallVector keys = flattenPackedScalarRunKeys(*run); - std::optional packedKey = getContiguousProducerRangeForKeys(keys); - emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, - targetClass, - "whole-batch assembly tried to reuse non-local PackedValue", - run->packed, - packedKey); - return failure(); - } - } - - if (run->kind == PackedScalarRunKind::DeferredReceive) { - if (failed(validatePackedScalarRunMetadata(targetClass.op, *run))) - return failure(); - - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == run->fragmentType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DeferredReceive; - group.fragmentType = run->fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - - groupIt->messages.append(run->messages.channelIds, run->messages.sourceCoreIds, run->messages.targetCoreIds); - for (const PackedScalarRunSlot& slot : run->slots) - for (ProducerKey fragmentKey : slot.keys) - groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); - continue; - } - - if (run->kind == PackedScalarRunKind::DeferredLocalCompute) { - SmallVector keys = flattenPackedScalarRunKeys(*run); - if (keys.empty()) - return failure(); - - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DeferredLocalCompute - && group.fragmentType == run->fragmentType && group.sourceOp == run->sourceOp - && group.resultIndex == run->resultIndex; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DeferredLocalCompute; - group.fragmentType = run->fragmentType; - group.sourceOp = run->sourceOp; - group.resultIndex = run->resultIndex; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - - for (ProducerKey fragmentKey : keys) { - if (fragmentKey.instance.laneCount != 1) - return failure(); - groupIt->sourceLanes.push_back(fragmentKey.instance.laneStart); - groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); - } - continue; - } - - auto sourceBatch = dyn_cast_or_null(run->sourceOp); - if (!sourceBatch || !run->packed) - return failure(); - - auto getOrCreatePackedValueGroup = [&](RankedTensorType slotPackedType) -> WholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::PackedValue && group.fragmentType == run->fragmentType - && group.packed == run->packed && group.slotPackedType == slotPackedType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::PackedValue; - group.fragmentType = run->fragmentType; - group.packed = run->packed; - group.slotPackedType = slotPackedType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - size_t flattenedIndexBase = 0; - for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { - std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); - if (contiguousKey) { - FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slot.keys.size()); - if (failed(slotPackedType)) - return failure(); - WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType); - group.slotIndices.push_back(slotIndex); - group.outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); - flattenedIndexBase += slot.keys.size(); - continue; - } - - WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(run->fragmentType); - for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) { - group.slotIndices.push_back(flattenedIndexBase + keyIndex); - group.outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); - } - flattenedIndexBase += slot.keys.size(); - } - } - - auto getOrCreateDeferredReceiveGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DeferredReceive; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - auto getOrCreateDirectValueGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DirectValue; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - for (const DirectWholeBatchFragment& fragment : plan.directFragments) { - if (!isTensorValueLocalToMaterializedClass(fragment.fragment, targetClass)) { - emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, - targetClass, - "whole-batch assembly tried to reuse non-local DirectValue", - fragment.fragment, - fragment.key); - return failure(); - } - - auto fragmentType = dyn_cast(fragment.fragment.getType()); - if (!fragmentType) - return failure(); - - int64_t outputOffset = static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane; - - if (auto receive = fragment.fragment.getDefiningOp()) { - if (fragment.fragment.use_empty()) { - WholeBatchFragmentGroup& group = getOrCreateDeferredReceiveGroup(fragmentType); - if (appendConstantChannelReceiveMessage(group.messages, receive)) { - group.outputOffsets.push_back(outputOffset); - group.redundantReceives.push_back(receive.getOperation()); - continue; - } - } - } - - WholeBatchFragmentGroup& group = getOrCreateDirectValueGroup(fragmentType); - group.directFragments.push_back({fragment.fragment, outputOffset}); - } - - return success(); -} - -FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchFragmentGroup& group, - Location loc) { - switch (group.kind) { - case WholeBatchFragmentSourceKind::DeferredReceive: { - FailureOr updated = emitIndexedFragmentInsertLoop( - state, - targetClass, - destination, - static_cast(group.outputOffsets.size()), - [&](Value flatIndex) -> FailureOr { - Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); - return SpatChannelReceiveOp::create( - state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - }, - [&](Value flatIndex) -> FailureOr { - return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); - }, - loc); - if (failed(updated)) - return failure(); - - for (Operation* receive : group.redundantReceives) - if (receive && receive->use_empty()) - receive->erase(); - - return *updated; - } - case WholeBatchFragmentSourceKind::DeferredLocalCompute: { - SmallVector resultIndices {group.resultIndex}; - return emitIndexedFragmentInsertLoop( - state, - targetClass, - destination, - static_cast(group.outputOffsets.size()), - [&](Value flatIndex) -> FailureOr { - Value sourceLane = createIndexedIndexValue(state, targetClass.op, group.sourceLanes, flatIndex, loc); - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - ComputeInstance {group.sourceOp, 0, 1}, - sourceLane, - resultIndices, - CloneIndexingContext {.runSlotIndex = flatIndex, .projectionSlotIndex = flatIndex}); - if (failed(produced) || produced->size() != 1) - return failure(); - return produced->front(); - }, - [&](Value flatIndex) -> FailureOr { - return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); - }, - loc); - } - case WholeBatchFragmentSourceKind::PackedValue: - return emitIndexedFragmentInsertLoop( - state, - targetClass, - destination, - static_cast(group.slotIndices.size()), - [&](Value flatIndex) -> FailureOr { - Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc); - FailureOr packed = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - group.packed, - targetClass.op, - "whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); - if (failed(packed)) - return failure(); - return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc); - }, - [&](Value flatIndex) -> FailureOr { - return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); - }, - loc); - case WholeBatchFragmentSourceKind::DirectValue: - for (const auto& [fragment, offset] : group.directFragments) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - FailureOr localFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fragment, - targetClass.op, - "whole-batch direct fragment assembly tried to reuse a tensor from another materialized class"); - if (failed(localFragment)) - return failure(); - FailureOr updated = createDim0InsertSliceInClass(state, - targetClass, - loc, - *localFragment, - destination, - getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset)); - if (failed(updated)) - return failure(); - destination = *updated; - } - return destination; - } - - return failure(); -} - -FailureOr emitProjectedWholeBatchFragmentInsertLoop( - MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const ProjectedWholeBatchFragmentGroup& group, - llvm::function_ref(Value)> buildFragment, - Location loc) { - assert(group.fragmentType && "expected projected fragment type"); - assert(!group.offsetsByDim.empty() && "expected projected insert coordinates"); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, group.offsetsByDim.front().size()); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {destination}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - FailureOr fragment = buildFragment(flatIndex); - if (failed(fragment)) - return failure(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - unsigned rank = group.offsetsByDim.size(); - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (unsigned dim = 0; dim < rank; ++dim) { - offsets.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.offsetsByDim[dim], flatIndex, loc)); - sizes.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.sizesByDim[dim], flatIndex, loc)); - strides.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.stridesByDim[dim], flatIndex, loc)); - } - - Value updated = - tensor::InsertSliceOp::create(state.rewriter, loc, *fragment, iterArgs.front(), offsets, sizes, strides) - .getResult(); - yielded.push_back(updated); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -std::optional getStaticProjectedPackedFragmentIndex(tensor::ExtractSliceOp extract) { - auto sourceType = dyn_cast(extract.getSource().getType()); - auto resultType = dyn_cast(extract.getResult().getType()); - if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() - || sourceType.getRank() == 0 || sourceType.getRank() != resultType.getRank()) - return std::nullopt; - - std::optional firstOffset = getConstantIndex(extract.getMixedOffsets().front()); - if (!firstOffset) - return std::nullopt; - - for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { - std::optional offset = getConstantIndex(extract.getMixedOffsets()[dim]); - std::optional size = getConstantIndex(extract.getMixedSizes()[dim]); - std::optional stride = getConstantIndex(extract.getMixedStrides()[dim]); - if (!offset || !size || !stride || *stride != 1 || *size != resultType.getDimSize(dim)) - return std::nullopt; - if (dim != 0 && *offset != 0) - return std::nullopt; - } - - return *firstOffset; -} - -void appendProjectedInsertCoordinates(ProjectedWholeBatchFragmentGroup& group, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - if (group.offsetsByDim.empty()) { - size_t rank = offsets.size(); - group.offsetsByDim.resize(rank); - group.sizesByDim.resize(rank); - group.stridesByDim.resize(rank); - } - - for (size_t dim = 0; dim < offsets.size(); ++dim) { - group.offsetsByDim[dim].push_back(offsets[dim]); - group.sizesByDim[dim].push_back(sizes[dim]); - group.stridesByDim[dim].push_back(strides[dim]); - } -} - -FailureOr buildWholeBatchAssemblyPlan(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - Type resultType) { - auto batch = dyn_cast_or_null(key.instance.op); - auto resultTensorType = dyn_cast(resultType); - if (!batch || !resultTensorType || !resultTensorType.hasStaticShape() || resultTensorType.getRank() == 0) - return failure(); - - uint32_t batchLaneCount = static_cast(batch.getLaneCount()); - if (batchLaneCount == 0 || resultTensorType.getDimSize(0) % static_cast(batchLaneCount) != 0) - return failure(); - - WholeBatchAssemblyPlan plan; - plan.resultType = resultTensorType; - plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast(batchLaneCount); - plan.batchLaneCount = batchLaneCount; - plan.coveredLanes.assign(batchLaneCount, 0); - - if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan))) - return failure(); - - if (plan.coveredLaneCount == plan.batchLaneCount) - return plan; - - if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan))) - return failure(); - - return plan; -} - -FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - WholeBatchAssemblyPlan& plan, - Location loc) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value result = - tensor::EmptyOp::create(state.rewriter, loc, plan.resultType.getShape(), plan.resultType.getElementType()) - .getResult(); - - SmallVector groups; - if (failed(collectWholeBatchFragmentGroups(state, targetClass, plan, groups))) - return failure(); - - for (const WholeBatchFragmentGroup& group : groups) { - FailureOr updated = emitWholeBatchFragmentGroup(state, targetClass, result, group, loc); - if (failed(updated)) - return failure(); - result = *updated; - } - - state.availableValues.record(key, targetClass.id, result); - return result; -} - -// ----------------------------------------------------------------------------- -// Run materialization helpers. -// ----------------------------------------------------------------------------- - -FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - Type resultType, - Location loc) { - auto batch = dyn_cast_or_null(key.instance.op); - auto resultTensorType = dyn_cast(resultType); - if (!batch || !resultTensorType || !resultTensorType.hasStaticShape()) - return failure(); - - FailureOr projection = getBatchResultProjectionInsert(batch, key.resultIndex); - if (failed(projection)) - return failure(); - - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("missing compute_batch lane argument while materializing projected whole-batch input"); - - uint32_t laneEnd = key.instance.laneStart + key.instance.laneCount; - if (laneEnd > static_cast(batch.getLaneCount())) - return failure(); - - if (targetClass.isBatch) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value result = - tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) - .getResult(); - - for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { - ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); - std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); - if (!fragment) - return failure(); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); - if (failed(offsets) || failed(sizes) || failed(strides)) - return failure(); - - SmallVector offsetAttrs; - SmallVector sizeAttrs; - SmallVector strideAttrs; - offsetAttrs.reserve(offsets->size()); - sizeAttrs.reserve(sizes->size()); - strideAttrs.reserve(strides->size()); - for (auto [offset, size, stride] : llvm::zip(*offsets, *sizes, *strides)) { - offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); - sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); - strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); - } - - FailureOr localFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - *fragment, - targetClass.op, - "projected whole-batch assembly tried to reuse a tensor from another materialized class", - laneKey); - if (failed(localFragment)) - return failure(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - result = tensor::InsertSliceOp::create( - state.rewriter, loc, *localFragment, result, offsetAttrs, sizeAttrs, strideAttrs) - .getResult(); - } - - state.availableValues.record(key, targetClass.id, result); - return result; - } - - SmallVector groups; - auto getOrCreateReceiveGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { - return group.kind == ProjectedWholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - ProjectedWholeBatchFragmentGroup group; - group.kind = ProjectedWholeBatchFragmentSourceKind::DeferredReceive; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - auto getOrCreatePackedGroup = [&](Value packed, - RankedTensorType packedSourceType, - RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { - return group.kind == ProjectedWholeBatchFragmentSourceKind::PackedValue && group.fragmentType == fragmentType - && group.packed == packed && group.packedSourceType == packedSourceType; - }); - if (groupIt == groups.end()) { - ProjectedWholeBatchFragmentGroup group; - group.kind = ProjectedWholeBatchFragmentSourceKind::PackedValue; - group.fragmentType = fragmentType; - group.packed = packed; - group.packedSourceType = packedSourceType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - auto getOrCreateDirectGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { - return group.kind == ProjectedWholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - ProjectedWholeBatchFragmentGroup group; - group.kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { - ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); - if (failed(offsets) || failed(sizes) || failed(strides)) - return failure(); - - bool grouped = false; - if (std::optional exact = state.availableValues.lookupExact(laneKey, targetClass.id)) { - if (auto receive = exact->getDefiningOp()) { - auto fragmentType = dyn_cast(receive.getOutput().getType()); - if (fragmentType && receive.getOutput().use_empty()) { - ProjectedWholeBatchFragmentGroup& group = getOrCreateReceiveGroup(fragmentType); - if (appendConstantChannelReceiveMessage(group.messages, receive)) { - appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); - group.redundantOps.push_back(receive.getOperation()); - grouped = true; - } - } - } - } - - if (grouped) - continue; - - std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); - if (!fragment) - return failure(); - - auto fragmentType = dyn_cast(fragment->getType()); - if (!fragmentType) - return failure(); - - if (auto extract = fragment->getDefiningOp()) { - if (std::optional packedIndex = getStaticProjectedPackedFragmentIndex(extract)) { - auto packedSourceType = dyn_cast(extract.getSource().getType()); - if (packedSourceType) { - ProjectedWholeBatchFragmentGroup& group = - getOrCreatePackedGroup(extract.getSource(), packedSourceType, fragmentType); - group.packedIndices.push_back(*packedIndex); - appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); - group.redundantOps.push_back(extract.getOperation()); - continue; - } - } - } - - ProjectedWholeBatchFragmentGroup& group = getOrCreateDirectGroup(fragmentType); - group.directFragments.push_back({*fragment, *offsets, *sizes, *strides}); - } - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value result = - tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) - .getResult(); - - for (const ProjectedWholeBatchFragmentGroup& group : groups) { - FailureOr updated = failure(); - switch (group.kind) { - case ProjectedWholeBatchFragmentSourceKind::DeferredReceive: - updated = emitProjectedWholeBatchFragmentInsertLoop( - state, - targetClass, - result, - group, - [&](Value flatIndex) -> FailureOr { - Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); - return SpatChannelReceiveOp::create( - state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - }, - loc); - break; - case ProjectedWholeBatchFragmentSourceKind::PackedValue: - updated = emitProjectedWholeBatchFragmentInsertLoop( - state, - targetClass, - result, - group, - [&](Value flatIndex) -> FailureOr { - SmallVector extractOffsets; - SmallVector extractSizes; - SmallVector extractStrides; - extractOffsets.reserve(group.packedSourceType.getRank()); - extractSizes.reserve(group.packedSourceType.getRank()); - extractStrides.reserve(group.packedSourceType.getRank()); - extractOffsets.push_back(createIndexedOrStaticIndex( - state, targetClass.op, group.packedIndices, flatIndex, loc)); - extractSizes.push_back(state.rewriter.getIndexAttr(1)); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < group.packedSourceType.getRank(); ++dim) { - extractOffsets.push_back(state.rewriter.getIndexAttr(0)); - extractSizes.push_back(state.rewriter.getIndexAttr(group.packedSourceType.getDimSize(dim))); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - } - - FailureOr packed = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - group.packed, - targetClass.op, - "projected whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); - if (failed(packed)) - return failure(); - - return tensor::ExtractSliceOp::create( - state.rewriter, - loc, - group.fragmentType, - *packed, - extractOffsets, - extractSizes, - extractStrides) - .getResult(); - }, - loc); - break; - case ProjectedWholeBatchFragmentSourceKind::DirectValue: { - updated = result; - for (const ProjectedWholeBatchDirectFragment& fragment : group.directFragments) { - FailureOr localFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fragment.fragment, - targetClass.op, - "projected whole-batch assembly tried to reuse a tensor from another materialized class"); - if (failed(localFragment)) - return failure(); - - SmallVector offsetAttrs; - SmallVector sizeAttrs; - SmallVector strideAttrs; - for (auto [offset, size, stride] : llvm::zip(fragment.offsets, fragment.sizes, fragment.strides)) { - offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); - sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); - strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); - } - updated = tensor::InsertSliceOp::create( - state.rewriter, loc, *localFragment, *updated, offsetAttrs, sizeAttrs, strideAttrs) - .getResult(); - } - break; - } - } - if (failed(updated)) - return failure(); - result = *updated; - } - - for (const ProjectedWholeBatchFragmentGroup& group : groups) - for (Operation* redundantOp : group.redundantOps) - if (redundantOp && redundantOp->use_empty()) - redundantOp->erase(); - - state.availableValues.record(key, targetClass.id, result); - return result; -} - -FailureOr materializeWholeBatchInput( - MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { - if (failed(materializePendingScalarReceivesForWholeBatchInput(state, targetClass, key, loc))) - return failure(); - - FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType); - if (succeeded(plan)) - return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc); - - return materializeProjectedWholeBatchInputFromFragments(state, targetClass, key, resultType, loc); -} - -FailureOr resolveInputValue(MaterializerState& state, - MaterializedClass& targetClass, - Value input, - const ComputeInstance& consumerInstance, - CloneIndexingContext indexing) { - auto rejectNonLocalResolvedValue = [&](Value resolved) -> FailureOr { - if (!isTensorValueDefinedInDifferentMaterializedClass(resolved, targetClass)) - return resolved; - - std::optional producer = getInputRequestProducerKey(input, consumerInstance); - emitNonLocalMaterializedClassValueDiagnostic(consumerInstance.op, - targetClass, - "input resolution tried to reuse a tensor from another materialized class", - resolved, - producer); - return failure(); - }; - - if (isConstantLike(input)) - return input; - - if (std::optional producer = getInputRequestProducerKey(input, consumerInstance)) { - if (indexing.runSlotIndex) { - if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { - FailureOr received = materializeIndexedBatchRunReceive( - state, targetClass, *indexedRun, *indexing.runSlotIndex, consumerInstance.op->getLoc()); - if (failed(received)) - return failure(); - return rejectNonLocalResolvedValue(*received); - } - } - - if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) - return rejectNonLocalResolvedValue(*value); - - if (auto pendingReceive = lookupPendingScalarReceiveIndex(state, *producer, targetClass.id)) { - FailureOr received = - materializePendingScalarReceive(state, targetClass, *pendingReceive, consumerInstance.op->getLoc()); - if (failed(received)) - return failure(); - return rejectNonLocalResolvedValue(*received); - } - - if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { - size_t laneCount = targetClass.cpus.size(); - for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) { - if (!llvm::is_contained(slot.keys, *producer)) - continue; - - MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); - Value received = - appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); - for (ProducerKey slotKey : slot.keys) - state.availableValues.record(slotKey, targetClass.id, received); - return rejectNonLocalResolvedValue(received); - } - } - - if (isWholeBatchProducerKey(*producer)) { - FailureOr wholeBatch = - materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); - if (failed(wholeBatch)) - consumerInstance.op->emitError("failed to materialize whole-batch input") - << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart - << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; - if (failed(wholeBatch)) - return failure(); - return rejectNonLocalResolvedValue(*wholeBatch); - } - - consumerInstance.op->emitError("failed to resolve producer value") - << " from op '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart - << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; - return failure(); - } - - if (isTensorValueDefinedInDifferentMaterializedClass(input, targetClass)) { - emitNonLocalMaterializedClassValueDiagnostic( - consumerInstance.op, - targetClass, - "input resolution tried to append a tensor from another materialized class as a normal input", - input); - return failure(); - } - - return appendInput(state, targetClass, input); -} - -bool hasProjectedInputReplacement(MaterializerState& state, - SpatComputeBatch batch, - unsigned inputIndex, - ClassId classId) { - std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex); - if (!match) - return false; - - auto replacementIt = state.projectedExtractReplacements.find(match->extract.getOperation()); - if (replacementIt == state.projectedExtractReplacements.end()) - return false; - - return replacementIt->second.find(classId) != replacementIt->second.end(); -} - -void mapWeights(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper) { - Operation* op = instance.op; - if (auto compute = dyn_cast(op)) { - for (auto [index, weight] : llvm::enumerate(compute.getWeights())) { - auto weightArg = compute.getWeightArgument(index); - assert(weightArg && "expected compute weight block argument"); - mapper.map(*weightArg, appendWeight(state, targetClass, weight)); - } - return; - } - - auto batch = cast(op); - for (auto [index, weight] : llvm::enumerate(batch.getWeights())) { - auto weightArg = batch.getWeightArgument(index); - assert(weightArg && "expected compute_batch weight block argument"); - mapper.map(*weightArg, appendWeight(state, targetClass, weight)); - } -} - -LogicalResult mapInputs(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper, - CloneIndexingContext indexing) { - auto mapResolvedInput = [&](Value resolved) -> FailureOr { - return materializeTensorValueForMaterializedClassUse( - state, - targetClass, - resolved, - targetClass.op, - "input mapping tried to reuse a tensor from another materialized class"); - }; - - Operation* op = instance.op; - if (auto compute = dyn_cast(op)) { - for (auto [index, input] : llvm::enumerate(compute.getInputs())) { - FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); - if (failed(mapped)) { - std::optional producer = getInputRequestProducerKey(input, instance); - auto diagnostic = compute.emitOpError("failed to resolve materialized compute input") << " #" << index; - if (producer) { - diagnostic << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart - << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; - } - return failure(); - } - auto inputArg = compute.getInputArgument(index); - if (!inputArg) - return compute.emitOpError("expected compute input block argument while materializing inputs"); - FailureOr remapped = mapResolvedInput(*mapped); - if (failed(remapped)) { - emitNonLocalMaterializedClassValueDiagnostic(compute, - targetClass, - "mapInputs tried to append a tensor from another materialized class", - *mapped, - getInputRequestProducerKey(input, instance)); - return failure(); - } - mapper.map(*inputArg, *remapped); - } - return success(); - } - - auto batch = cast(op); - for (auto [index, input] : llvm::enumerate(batch.getInputs())) { - if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) - continue; - - FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); - if (failed(mapped)) - return batch.emitOpError("failed to resolve materialized compute_batch input"); - auto inputArg = batch.getInputArgument(index); - if (!inputArg) - return batch.emitOpError("expected compute_batch input block argument while materializing inputs"); - FailureOr remapped = mapResolvedInput(*mapped); - if (failed(remapped)) { - emitNonLocalMaterializedClassValueDiagnostic(batch, - targetClass, - "mapInputs tried to append a tensor from another materialized class", - *mapped, - getInputRequestProducerKey(input, instance)); - return failure(); - } - mapper.map(*inputArg, *remapped); - } - return success(); -} - -SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMapping& mapper) { - SmallVector outputs(batch.getNumResults(), Value {}); - auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); - if (!inParallel) - return outputs; - - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert) - continue; - - auto outputArg = dyn_cast(insert.getDest()); - if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) - continue; - - auto firstOutputArg = batch.getOutputArgument(0); - if (!firstOutputArg) - return outputs; - unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); - if (resultIndex >= outputs.size()) - continue; - outputs[resultIndex] = mapper.lookupOrDefault(insert.getSource()); - } - - return outputs; -} - -SmallVector collectBatchOutputFragmentTypes(SpatComputeBatch batch) { - SmallVector types(batch.getNumResults(), Type {}); - auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); - if (!inParallel) - return types; - - auto firstOutputArg = batch.getOutputArgument(0); - if (!firstOutputArg) - return types; - - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert) - continue; - - auto outputArg = dyn_cast(insert.getDest()); - if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) - continue; - - unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); - if (resultIndex >= types.size()) - continue; - - types[resultIndex] = insert.getSource().getType(); - } - - return types; -} - -SmallVector& getBatchOutputFragmentTypesCached(MaterializerState& state, SpatComputeBatch batch) { - auto [it, inserted] = state.batchOutputFragmentTypesCache.try_emplace(batch.getOperation(), SmallVector {}); - if (inserted) - it->second = collectBatchOutputFragmentTypes(batch); - return it->second; -} - -ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance) { - auto [it, inserted] = state.computeInstanceOutputsCache.try_emplace(instance, SmallVector {}); - if (inserted) - it->second = getComputeInstanceOutputValues(instance); - return it->second; -} - -std::optional lookupProjectedExtractReplacement(MaterializerState& state, - MaterializedClass& targetClass, - tensor::ExtractSliceOp extract) { - auto replacementIt = state.projectedExtractReplacements.find(extract.getOperation()); - if (replacementIt == state.projectedExtractReplacements.end()) - return std::nullopt; - - auto classIt = replacementIt->second.find(targetClass.id); - if (classIt == replacementIt->second.end()) - return std::nullopt; - - return classIt->second; -} - -bool requiresConstantProjectionSlotIndex(MaterializerState& state, - MaterializedClass& targetClass, - Operation* sourceOp) { - bool requiresConstantIndex = false; - sourceOp->walk([&](tensor::ExtractSliceOp extract) { - if (requiresConstantIndex) - return WalkResult::interrupt(); - - std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, extract); - if (!replacement) - return WalkResult::advance(); - - if (replacement->layout.payloadFragmentCount != replacement->layout.fragmentsPerLogicalSlot) { - requiresConstantIndex = true; - return WalkResult::interrupt(); - } - - return WalkResult::advance(); - }); - return requiresConstantIndex; -} - -LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state, - MaterializedClass& targetClass, - Operation& originalOp, - Operation& clonedOp, - CloneIndexingContext indexing, - IRMapping& mapper) { - if (auto originalExtract = dyn_cast(&originalOp)) { - if (std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, originalExtract)) { - auto clonedExtract = dyn_cast(&clonedOp); - if (!clonedExtract) - return targetClass.op->emitError("projected replacement lost extract structure during cloning"); - - state.rewriter.setInsertionPoint(clonedExtract); - FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, clonedExtract, *replacement, indexing.projectionSlotIndex, &mapper); - if (failed(projected)) - return failure(); - - clonedExtract.getResult().replaceAllUsesWith(*projected); - state.rewriter.eraseOp(clonedExtract); - return success(); - } - } - - if (originalOp.getNumRegions() != clonedOp.getNumRegions()) - return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned regions"); - - for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { - if (std::distance(originalRegion.begin(), originalRegion.end()) - != std::distance(clonedRegion.begin(), clonedRegion.end())) - return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned blocks"); - - for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { - auto originalIt = originalBlock.begin(); - auto clonedIt = clonedBlock.begin(); - while (originalIt != originalBlock.end() && clonedIt != clonedBlock.end()) { - Operation& originalNestedOp = *originalIt++; - Operation* currentClonedOp = &*clonedIt++; - if (failed(applyProjectedExtractReplacementsInClonedOp( - state, targetClass, originalNestedOp, *currentClonedOp, indexing, mapper))) - return failure(); - } - if (originalIt != originalBlock.end() || clonedIt != clonedBlock.end()) - return targetClass.op->emitError("projected replacement traversal found mismatched cloned operations"); - } - } - - return success(); -} - -LogicalResult mapClonedRegionBlockArguments(Operation& originalOp, Operation& clonedOp, IRMapping& mapper) { - if (originalOp.getNumRegions() != clonedOp.getNumRegions()) - return clonedOp.emitError("cloned operation has a different number of regions than the source operation"); - - for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { - if (std::distance(originalRegion.begin(), originalRegion.end()) - != std::distance(clonedRegion.begin(), clonedRegion.end())) - return clonedOp.emitError("cloned operation has a different number of blocks than the source operation"); - - for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { - if (originalBlock.getNumArguments() != clonedBlock.getNumArguments()) - return clonedOp.emitError("cloned operation block has a different number of arguments than the source block"); - - for (auto [originalArg, clonedArg] : llvm::zip(originalBlock.getArguments(), clonedBlock.getArguments())) - if (!mapper.contains(originalArg)) - mapper.map(originalArg, clonedArg); - - if (std::distance(originalBlock.begin(), originalBlock.end()) != std::distance(clonedBlock.begin(), clonedBlock.end())) - return clonedOp.emitError("cloned operation block has a different number of operations than the source block"); - - auto originalIt = originalBlock.begin(); - auto clonedIt = clonedBlock.begin(); - while (originalIt != originalBlock.end()) { - Operation& originalNestedOp = *originalIt++; - Operation& clonedNestedOp = *clonedIt++; - if (failed(mapClonedRegionBlockArguments(originalNestedOp, clonedNestedOp, mapper))) - return failure(); - } - } - } - - return success(); -} - -LogicalResult cloneComputeTemplateBody(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper, - CloneIndexingContext indexing) { - Block& sourceBlock = getComputeInstanceTemplateBlock(instance); - for (Operation& op : sourceBlock.without_terminator()) { - if (auto extract = dyn_cast(&op)) { - if (std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, extract)) { - FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, extract, *replacement, indexing.projectionSlotIndex, &mapper); - if (failed(projected)) - return failure(); - - mapper.map(extract.getResult(), *projected); - continue; - } - } - - for (Value operand : op.getOperands()) { - if (mapper.contains(operand)) - continue; - - FailureOr localized = localizeMaterializedClassOperand( - state, - targetClass, - operand, - &op, - "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", - "cloneComputeTemplateBody produced an unsupported external non-tensor operand", - &mapper); - if (failed(localized)) - return failure(); - if (*localized != operand) - mapper.map(operand, *localized); - } - - Operation* cloned = state.rewriter.clone(op, mapper); - if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper))) - return failure(); - if (failed(localizeCapturesInClonedOp(state, targetClass, *cloned, &mapper))) - return failure(); - if (op.getNumRegions() != 0 - && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing, mapper))) - return failure(); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(oldResult, newResult); - } - - return success(); -} - -FailureOr materializeProjectedExtractReplacement(MaterializerState& state, - MaterializedClass& targetClass, - tensor::ExtractSliceOp extract, - const ProjectedExtractReplacement& replacement, - std::optional projectionSlotIndex, - IRMapping* mapper) { - if (failed(verifyProjectedFragmentLayout(targetClass.op, replacement.layout))) - return failure(); - - FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - replacement.payload, - targetClass.op, - "projected extract replacement tried to reuse a tensor from another materialized class", - std::nullopt, - mapper); - if (failed(localizedPayload)) - return failure(); - Value payload = *localizedPayload; - - if (replacement.layout.payloadFragmentCount == 1) - return payload; - - if (replacement.layout.payloadFragmentCount < replacement.layout.fragmentsPerLogicalSlot) - return targetClass.op->emitError("projected replacement payload is smaller than one logical slot"); - - Value intraSlotFragmentIndex = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - const auto linearizeProjectedLoopIndices = [&]() -> FailureOr { - if (replacement.layout.loopTripCounts.empty()) - return intraSlotFragmentIndex; - - SmallVector surroundingLoops; - for (Operation* current = extract->getParentOp(); current; current = current->getParentOp()) { - if (auto loop = dyn_cast(current)) - surroundingLoops.push_back(loop); - if (current == targetClass.op) - break; - } - std::reverse(surroundingLoops.begin(), surroundingLoops.end()); - - if (surroundingLoops.size() != replacement.layout.loopTripCounts.size()) - return targetClass.op->emitError("projected replacement loop structure does not match the collected descriptor"); - - Value linearizedIndex = intraSlotFragmentIndex; - for (auto [index, loop] : llvm::enumerate(surroundingLoops)) { - FailureOr localizedIv = - rematerializeIndexValueInClass(state, targetClass, loop.getInductionVar(), extract.getLoc(), mapper); - if (failed(localizedIv)) - return failure(); - Value iv = *localizedIv; - Value lowerBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]); - Value tripCount = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]); - - Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult(); - if (replacement.layout.loopSteps[index] != 1) - normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult(); - linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult(); - linearizedIndex = - arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult(); - } - return linearizedIndex; - }; - - FailureOr linearizedIndex = linearizeProjectedLoopIndices(); - if (failed(linearizedIndex)) - return failure(); - intraSlotFragmentIndex = *linearizedIndex; - - const auto computeProjectedPayloadFragmentIndex = [&]() -> FailureOr { - if (replacement.layout.payloadFragmentCount == replacement.layout.fragmentsPerLogicalSlot) { - if (replacement.layout.loopTripCounts.empty() && replacement.layout.fragmentsPerLogicalSlot != 1) - return targetClass.op->emitError("projected replacement is missing loop metadata for packed logical slot"); - return intraSlotFragmentIndex; - } - - if (!projectionSlotIndex) - return targetClass.op->emitError("packed projected extract replacement requires a fragment slot index"); - - FailureOr localProjectionSlotIndex = - rematerializeIndexValueInClass(state, targetClass, *projectionSlotIndex, extract.getLoc(), mapper); - if (failed(localProjectionSlotIndex)) - return failure(); - - Value fragmentsPerLogicalSlot = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot); - Value base = - arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot) - .getResult(); - return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult(); - }; - - FailureOr packedFragmentIndex = computeProjectedPayloadFragmentIndex(); - if (failed(packedFragmentIndex)) - return failure(); - - FailureOr packedOffset = scaleIndexByDim0SizeInClass( - state, targetClass, *packedFragmentIndex, replacement.layout.fragmentType.getDimSize(0), extract.getLoc()); - if (failed(packedOffset)) - return failure(); - return createDim0ExtractSliceInClass( - state, targetClass, extract.getLoc(), payload, *packedOffset, replacement.layout.fragmentType.getDimSize(0)); -} - -FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, - MaterializedClass& targetClass, - IndexedBatchRunValue& run, - Value runSlotIndex, - Location loc) { - if (!targetClass.isBatch) - return targetClass.op->emitError("indexed batch run receive requires a batch target class"); - if (failed(run.messages.verify(targetClass.op))) - return failure(); - - Value flatIndex = createBatchRunFlatIndex(state, targetClass, runSlotIndex, loc); - std::optional preferredPeriod = static_cast(targetClass.cpus.size()); - Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); -} - -LogicalResult localizeCapturesInOperationTree(MaterializerState& state, - MaterializedClass& targetClass, - Operation& root, - StringRef tensorContext, - StringRef genericContext, - IRMapping* mapper = nullptr) { - WalkResult walkResult = root.walk([&](Operation* nestedOp) -> WalkResult { - for (OpOperand& operand : nestedOp->getOpOperands()) { - Value current = operand.get(); - if (isValueLegalInMaterializedClassBody(current, targetClass)) - continue; - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(nestedOp); - FailureOr localized = - localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper); - if (failed(localized)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand"); - diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() - << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() - << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) - << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) - << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; - diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; - attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR"); - attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands"); - attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain"); - attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); - attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return WalkResult::interrupt(); - } - operand.set(*localized); - } - return WalkResult::advance(); - }); - - return walkResult.wasInterrupted() ? failure() : success(); -} - -LogicalResult localizeCapturesInClonedOp(MaterializerState& state, - MaterializedClass& targetClass, - Operation& clonedOp, - IRMapping* mapper) { - return localizeCapturesInOperationTree( - state, - targetClass, - clonedOp, - "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", - "cloneComputeTemplateBody produced an unsupported external non-tensor operand", - mapper); -} - -LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass) { - SmallVector bodyOps; - for (Operation& op : *targetClass.body) - op.walk([&](Operation* nestedOp) { bodyOps.push_back(nestedOp); }); - - for (Operation* nestedOp : bodyOps) { - if (nestedOp->getBlock() == nullptr) - continue; - for (OpOperand& operand : nestedOp->getOpOperands()) { - Value current = operand.get(); - if (isValueLegalInMaterializedClassBody(current, targetClass)) - continue; - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(nestedOp); - FailureOr localized = localizeMaterializedClassOperand( - state, - targetClass, - current, - nestedOp, - "final scheduled body capture localization tried to reuse a tensor from another materialized class", - "final scheduled body capture localization found an unsupported external non-tensor operand"); - if (failed(localized)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG failed to localize final scheduled-body operand"); - diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() - << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() - << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) - << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) - << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; - diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; - attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return failure(); - } - operand.set(*localized); - } - } - - return success(); -} - -FailureOr> cloneInstanceBody(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef peers, - CloneIndexingContext indexing) { - assert(!peers.empty() && "expected at least one peer instance"); - const ComputeInstance& instance = peers.front(); - Operation* sourceOp = instance.op; - Location loc = sourceOp->getLoc(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - IRMapping mapper; - if (auto batch = dyn_cast(sourceOp)) { - for (const ComputeInstance& peer : peers) { - if (peer.op != sourceOp) { - sourceOp->emitError("equivalence class slot contains different source compute_batch operations"); - return failure(); - } - } - auto laneArg = batch.getLaneArgument(); - if (!laneArg) { - sourceOp->emitError("expected source compute_batch lane block argument"); - return failure(); - } - mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc)); - } - - OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); - - mapWeights(state, targetClass, instance, mapper); - if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) - return failure(); - - state.rewriter.restoreInsertionPoint(cloneInsertionPoint); - if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) - return failure(); - - if (auto compute = dyn_cast(sourceOp)) { - Block& sourceBlock = getComputeInstanceTemplateBlock(instance); - auto yield = dyn_cast_or_null(sourceBlock.getTerminator()); - if (!yield) { - compute.emitOpError("expected spat.yield terminator while materializing compute"); - return failure(); - } - - SmallVector outputs; - outputs.reserve(yield.getNumOperands()); - for (Value yielded : yield.getOutputs()) - outputs.push_back(mapper.lookupOrDefault(yielded)); - return outputs; - } - - auto batch = cast(sourceOp); - if (batch.getNumResults() == 0) - return SmallVector {}; - - SmallVector outputs = collectMappedBatchOutputs(batch, mapper); - for (Value output : outputs) - if (!output) { - batch.emitOpError("failed to recover yielded per-lane value for compute_batch result"); - return failure(); - } - return outputs; -} - -bool sameDestinationClasses(ArrayRef lhs, ArrayRef rhs) { - if (lhs.size() != rhs.size()) - return false; - for (auto [lhsClass, rhsClass] : llvm::zip(lhs, rhs)) - if (lhsClass != rhsClass) - return false; - return true; -} - -SmallVector -collectDestinationClassesForRun(MaterializerState& state, ArrayRef run, size_t resultIndex) { - SmallVector destinations; - - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) { - ProducerKey key {peer, resultIndex}; - for (ClassId destinationClass : getDestinationClasses(state, key)) - if (!llvm::is_contained(destinations, destinationClass)) - destinations.push_back(destinationClass); - } - } - - llvm::sort(destinations); - return destinations; -} - -SmallVector groupBatchRunOutputsByDestination(MaterializerState& state, - ArrayRef run) { - assert(!run.empty() && "expected non-empty materialization run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - - SmallVector groups; - ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); - - for (auto [resultIndex, output] : llvm::enumerate(outputs)) { - SmallVector destinations = collectDestinationClassesForRun(state, run, resultIndex); - - auto existingGroup = llvm::find_if(groups, [&](const OutputDestinationGroup& group) { - return sameDestinationClasses(group.destinationClasses, destinations); - }); - - if (existingGroup != groups.end()) { - existingGroup->resultIndices.push_back(resultIndex); - continue; - } - - OutputDestinationGroup group; - group.resultIndices.push_back(resultIndex); - group.destinationClasses = std::move(destinations); - groups.push_back(std::move(group)); - } - - return groups; -} - -FailureOr getPackedRunTensorType(Type elementType, size_t runSize) { - auto tensorType = dyn_cast(elementType); - if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) - return failure(); - - SmallVector shape(tensorType.getShape()); - shape[0] *= static_cast(runSize); - return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); -} - -LogicalResult registerDeferredLocalPackedRunValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef keys, - Type fragmentType, - Location loc) { - if (keys.empty()) - return success(); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return materializedClass.op->emitError("deferred local packed run expects static ranked fragment type"); - - Operation* sourceOp = keys.front().instance.op; - size_t resultIndex = keys.front().resultIndex; - - for (ProducerKey key : keys) { - if (key.instance.op != sourceOp || key.resultIndex != resultIndex) - return materializedClass.op->emitError("deferred local packed run expects one producer result"); - - if (key.instance.laneCount != 1) - return materializedClass.op->emitError("deferred local packed run expects one lane per fragment"); - } - - PackedScalarRunValue packedRun; - packedRun.targetClass = materializedClass.id; - packedRun.sourceOp = sourceOp; - packedRun.resultIndex = resultIndex; - packedRun.kind = PackedScalarRunKind::DeferredLocalCompute; - packedRun.fragmentType = rankedFragmentType; - - packedRun.slots.reserve(keys.size()); - for (ProducerKey key : keys) { - PackedScalarRunSlot slot; - slot.keys.push_back(key); - packedRun.slots.push_back(std::move(slot)); - } - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -LogicalResult registerPackedRunValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef keys, - Value packed, - Type fragmentType, - Location loc) { - if (keys.empty()) - return success(); - - FailureOr expectedPackedType = getPackedRunTensorType(fragmentType, keys.size()); - if (failed(expectedPackedType)) - return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); - - if (packed.getType() != *expectedPackedType) - return materializedClass.op->emitError("packed run value has unexpected tensor type"); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); - - Operation* sourceOp = keys.front().instance.op; - size_t resultIndex = keys.front().resultIndex; - - for (ProducerKey key : keys) { - if (key.instance.op != sourceOp || key.resultIndex != resultIndex) - return materializedClass.op->emitError("packed run registration expects one producer result"); - if (key.instance.laneCount != 1) - return materializedClass.op->emitError("packed run registration expects one lane per packed fragment"); - } - - if (std::optional contiguousKey = getContiguousProducerRangeForKeys(keys)) { - state.availableValues.record(*contiguousKey, materializedClass.id, packed); - return success(); - } - - PackedScalarRunValue packedRun; - packedRun.targetClass = materializedClass.id; - packedRun.sourceOp = sourceOp; - packedRun.resultIndex = resultIndex; - packedRun.packed = packed; - packedRun.kind = PackedScalarRunKind::Materialized; - packedRun.fragmentType = rankedFragmentType; - - packedRun.slots.reserve(keys.size()); - for (ProducerKey key : keys) { - PackedScalarRunSlot slot; - slot.keys.push_back(key); - packedRun.slots.push_back(std::move(slot)); - } - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -LogicalResult emitPackedRunFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef destinationClasses, - ArrayRef keys, - Value packed, - Type fragmentType, - Location loc) { - assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); - - auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, packed); - if (failed(fanoutPlan)) - return failure(); - if (failed(emitScalarSourceFanoutSends(state, sourceClass, packed, *fanoutPlan, loc))) - return failure(); - - for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { - MaterializedClass& targetClass = state.classes[plan.targetClass]; - - Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); - - if (plan.projectedExtractOp) { - state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = - ProjectedExtractReplacement {received, plan.projectedLayout}; - continue; - } - - if (failed(registerPackedRunValue(state, targetClass, keys, received, fragmentType, loc))) - return failure(); - } - - return success(); -} - -FailureOr> cloneBatchBodyForLane(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - Value laneValue, - ArrayRef resultIndices, - CloneIndexingContext indexing) { - auto batch = dyn_cast(instance.op); - if (!batch) - return failure(); - - IRMapping mapper; - auto sourceLaneArg = batch.getLaneArgument(); - if (!sourceLaneArg) - return batch.emitOpError("expected source compute_batch lane block argument"); - - mapper.map(*sourceLaneArg, laneValue); - - OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); - - mapWeights(state, targetClass, instance, mapper); - if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) - return failure(); - - state.rewriter.restoreInsertionPoint(cloneInsertionPoint); - if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) - return failure(); - - SmallVector allOutputs = collectMappedBatchOutputs(batch, mapper); - if (allOutputs.empty() && !resultIndices.empty()) - return batch.emitOpError("failed to recover source compute_batch outputs"); - - SmallVector selectedOutputs; - selectedOutputs.reserve(resultIndices.size()); - for (size_t resultIndex : resultIndices) { - if (resultIndex >= allOutputs.size() || !allOutputs[resultIndex]) - return batch.emitOpError("failed to recover selected compute_batch output"); - selectedOutputs.push_back(allOutputs[resultIndex]); - } - - return selectedOutputs; -} - -FailureOr> materializeBatchOutputGroupLoop(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run, - const OutputDestinationGroup& group) { - assert(!run.empty() && "expected non-empty batch run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - - Operation* sourceOp = run.front().peers.front().op; - Location loc = sourceOp->getLoc(); - - if (run.size() == 1) { - if (run.front().peers.size() != 1) - return sourceOp->emitError("scalar batch output loop expects exactly one peer in singleton slot"); - - const ComputeInstance& item = run.front().peers.front(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart); - return cloneBatchBodyForLane(state, targetClass, item, laneValue, group.resultIndices, {}); - } - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - auto sourceBatch = cast(sourceOp); - SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); - SmallVector initValues; - for (size_t resultIndex : group.resultIndices) { - if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) - return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); - - Type fragmentType = fragmentTypes[resultIndex]; - FailureOr packedType = getPackedRunTensorType(fragmentType, run.size()); - if (failed(packedType)) - return sourceBatch.emitOpError("cannot materialize packed batch run for non-static ranked output"); - - initValues.push_back( - tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); - } - - SmallVector logicalLanes; - logicalLanes.reserve(run.size()); - for (const MaterializationRunSlot& slot : run) { - if (slot.peers.size() != 1) - return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); - - const ComputeInstance& item = slot.peers.front(); - if (item.op != sourceOp) - return sourceOp->emitError("materialization run contains different source operations"); - - logicalLanes.push_back(item.laneStart); - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange(initValues), - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - run.front().peers.front(), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); - if (failed(produced)) - return failure(); - - yielded.reserve(produced->size()); - for (auto [outputIndex, output] : llvm::enumerate(*produced)) { - auto fragmentType = cast(output.getType()); - Value acc = iterArgs[outputIndex]; - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, fragmentType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, output, acc, *firstOffset); - if (failed(next)) - return failure(); - yielded.push_back(*next); - } - return success(); - }); - if (failed(loop)) - return failure(); - - SmallVector results; - results.reserve(loop->results.size()); - for (Value result : loop->results) - results.push_back(result); - return results; -} - -SmallVector getMaterializationRunSlotOutputKeys(const MaterializationRunSlot& slot, - size_t resultIndex) { - SmallVector keys; - keys.reserve(slot.peers.size()); - for (const ComputeInstance& peer : slot.peers) - keys.push_back({peer, resultIndex}); - return keys; -} - -FailureOr> -getMaterializationRunSlotPeers(MaterializerState& state, MaterializedClass& targetClass, SlotId logicalSlot) { - if (targetClass.isBatch) - return getPeerLogicalInstances(state, targetClass, logicalSlot); - - auto streamIt = state.logicalInstancesByCpu.find(targetClass.cpus.front()); - if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) - return failure(); - - return SmallVector {streamIt->second[logicalSlot]}; -} - -FailureOr collectBatchMaterializationRun(MaterializerState& state, - MaterializedClass& targetClass, - SlotId startSlot, - Operation* sourceOp) { - MaterializationRun run; - - for (SlotId slot = startSlot;; ++slot) { - ClassSlotKey classSlot {targetClass.id, slot}; - if (state.materializedLogicalSlots.contains(classSlot)) - break; - - FailureOr> peers = getMaterializationRunSlotPeers(state, targetClass, slot); - if (failed(peers) || peers->empty()) - break; - - bool validSlot = true; - for (const ComputeInstance& peer : *peers) { - if (peer.op != sourceOp || !isa(peer.op)) { - validSlot = false; - break; - } - } - - if (!validSlot) - break; - - MaterializationRunSlot runSlot; - runSlot.peers = std::move(*peers); - run.push_back(std::move(runSlot)); - } - - if (run.empty()) - return failure(); - - return run; -} - -SmallVector getMaterializationRunOutputKeys(ArrayRef run, size_t resultIndex) { - SmallVector keys; - for (const MaterializationRunSlot& slot : run) - llvm::append_range(keys, getMaterializationRunSlotOutputKeys(slot, resultIndex)); - return keys; -} - -ArrayRef getFirstMaterializationRunOriginalOutputs(MaterializerState& state, - ArrayRef run) { - assert(!run.empty() && "expected non-empty materialization run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - return getComputeInstanceOutputValuesCached(state, run.front().peers.front()); -} - -Operation* getMaterializationRunSourceOp(ArrayRef run) { - assert(!run.empty() && "expected non-empty materialization run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - return run.front().peers.front().op; -} - -Location getMaterializationRunLoc(ArrayRef run) { - return getMaterializationRunSourceOp(run)->getLoc(); -} - -bool hasMaterializationRunResultLiveExternalUse(MaterializerState& state, - ArrayRef run, - size_t resultIndex) { - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) { - ArrayRef outputs = getComputeInstanceOutputValuesCached(state, peer); - if (resultIndex >= outputs.size()) - return true; - - if (hasLiveExternalUseCached(state, outputs[resultIndex])) - return true; - } - } - - return false; -} - -bool hasMaterializationRunGroupLiveExternalUse(MaterializerState& state, - ArrayRef run, - const OutputDestinationGroup& group) { - for (size_t resultIndex : group.resultIndices) - if (hasMaterializationRunResultLiveExternalUse(state, run, resultIndex)) - return true; - - return false; -} - -bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId); - -bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, - ClassId classId, - ArrayRef run, - const OutputDestinationGroup& group) { - for (size_t resultIndex : group.resultIndices) { - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) - if (hasSameClassConsumer(state, {peer, resultIndex}, classId)) - return true; - } - } - - return false; -} - -bool canRegisterDeferredLocalPackedRun(MaterializerState& state, ArrayRef run) { - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) { - for (Value input : getComputeInstanceInputs(peer)) { - std::optional producer = getInputRequestProducerKey(input, peer); - if (producer && isWholeBatchProducerKey(*producer)) - return false; - } - } - } - - return true; -} - -void markMaterializationRunSlots(MaterializerState& state, - ClassId classId, - SlotId startSlot, - ArrayRef run) { - for (auto slotIndex : llvm::seq(0, run.size())) - state.materializedLogicalSlots.insert({classId, startSlot + static_cast(slotIndex)}); -} - -LogicalResult materializeScalarBatchRun(MaterializerState& state, - MaterializedClass& targetClass, - SlotId startSlot, - ArrayRef run) { - assert(!targetClass.isBatch && "scalar batch run materialization expects scalar target class"); - assert(!run.empty() && "expected non-empty batch run"); - - markMaterializationRunSlots(state, targetClass.id, startSlot, run); - - SmallVector groups = groupBatchRunOutputsByDestination(state, run); - ArrayRef firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); - - auto sourceBatch = cast(getMaterializationRunSourceOp(run)); - SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); - Location loc = getMaterializationRunLoc(run); - bool canDeferLocalPackedRun = canRegisterDeferredLocalPackedRun(state, run); - - for (const OutputDestinationGroup& group : groups) { - bool canUseLocalOnlyPackedRun = run.size() > 1 && group.destinationClasses.empty() - && !hasMaterializationRunGroupLiveExternalUse(state, run, group) - && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group); - if (canUseLocalOnlyPackedRun && canDeferLocalPackedRun) { - for (size_t resultIndex : group.resultIndices) { - if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) - return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); - - SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); - if (failed(registerDeferredLocalPackedRunValue(state, targetClass, keys, fragmentTypes[resultIndex], loc))) - return failure(); - } - - continue; - } - - FailureOr> packedOutputs = materializeBatchOutputGroupLoop(state, targetClass, run, group); - if (failed(packedOutputs)) - return failure(); - - for (auto [groupOutputIndex, resultIndex] : llvm::enumerate(group.resultIndices)) { - Value packed = (*packedOutputs)[groupOutputIndex]; - if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) - return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); - - Type fragmentType = fragmentTypes[resultIndex]; - SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); - - if (run.size() == 1) { - if (failed(emitOutputFanout(state, targetClass, keys, packed, firstOriginalOutputs[resultIndex], loc))) - return failure(); - continue; - } - - if (canUseLocalOnlyPackedRun) { - if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) - return failure(); - continue; - } - - if (failed(emitPackedRunFanout(state, targetClass, group.destinationClasses, keys, packed, fragmentType, loc))) - return failure(); - - if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) - return failure(); - - Value representativeOutput = firstOriginalOutputs[resultIndex]; - if (hasLiveExternalUseCached(state, representativeOutput) - && isProjectedTerminalBatchHostOutput(representativeOutput, state.oldComputeOps)) { - std::optional groupedHostPublication = - tryEmitScalarPackedProjectedHostPublication(state, targetClass, keys, packed, representativeOutput, loc); - if (groupedHostPublication) { - if (failed(*groupedHostPublication)) - return failure(); - continue; - } - } - - auto rankedFragmentType = cast(fragmentType); - for (auto [runIndex, slot] : llvm::enumerate(run)) { - assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); - - ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, slot.peers.front()); - Value originalOutput = originalOutputs[resultIndex]; - - if (!hasLiveExternalUseCached(state, originalOutput)) - continue; - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - FailureOr fragment = - getPackedSliceForRunIndex(state, targetClass, packed, rankedFragmentType, runIndex, loc); - if (failed(fragment)) - return failure(); - - if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { - ProducerKey key {slot.peers.front(), resultIndex}; - if (failed(emitProjectedBatchHostFragment(state, targetClass, key, *fragment, originalOutput, loc))) - return failure(); - continue; - } - - if (failed(emitHostCommunication(state, targetClass, *fragment, originalOutput))) - return failure(); - } - } - } - - return success(); -} - -bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { - SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId}; - auto it = state.sameClassConsumerIndex.find(lookupKey); - if (it == state.sameClassConsumerIndex.end()) - return false; - - for (ProducerKey existing : it->second) - if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing)) - return true; - return false; -} - -bool canCompactBatchClassRun(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run) { - if (run.size() < 2) - return false; - if (run.front().peers.empty()) - return false; - - ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); - - for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { - (void) ignored; - for (const MaterializationRunSlot& slot : run) { - if (slot.peers.empty()) - return false; - - for (const ComputeInstance& peer : slot.peers) { - ArrayRef peerOutputs = getComputeInstanceOutputValuesCached(state, peer); - if (resultIndex >= peerOutputs.size()) - return false; - - Value originalOutput = peerOutputs[resultIndex]; - if (hasLiveExternalUseCached(state, originalOutput)) - return false; - - ProducerKey key {peer, resultIndex}; - if (hasSameClassConsumer(state, key, targetClass.id)) - return false; - } - } - } - - return true; -} - -LogicalResult registerMaterializedBatchRunHostOutputs(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run, - const OutputDestinationGroup& group) { - ArrayRef originalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); - for (size_t resultIndex : group.resultIndices) { - if (resultIndex >= originalOutputs.size()) - return targetClass.op->emitError("batch materialization host output index out of range"); - - Value originalOutput = originalOutputs[resultIndex]; - if (!hasLiveExternalUseCached(state, originalOutput)) - continue; - - auto resultIt = targetClass.hostOutputToResultIndex.find(originalOutput); - if (resultIt == targetClass.hostOutputToResultIndex.end()) - return targetClass.op->emitError("missing host result slot for materialized batch output"); - - state.hostReplacements[originalOutput] = targetClass.op->getResult(resultIt->second); - } - - return success(); -} - -LogicalResult verifyMaterializedHostOutputs(MaterializerState& state) { - for (SpatCompute compute : state.func.getOps()) { - auto yieldOp = dyn_cast_or_null(compute.getBody().front().getTerminator()); - if (!yieldOp) - return compute.emitOpError("expected spat.yield terminator in materialized compute"); - if (compute.getNumResults() != yieldOp.getNumOperands()) - return compute.emitOpError("materialized compute result count does not match spat.yield operand count"); - for (auto [result, yielded] : llvm::zip(compute.getResults(), yieldOp.getOperands())) - if (result.getType() != yielded.getType()) - return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand"); - } - - for (SpatChannelReceiveOp receive : state.func.getOps()) { - if (!receive.getOutput().use_empty()) - continue; - return receive.emitOpError("materialized channel_receive result must have at least one use"); - } - - for (const MaterializedClass& materializedClass : state.classes) { - if (!materializedClass.isBatch || materializedClass.hostOutputs.empty()) - continue; - - auto batch = dyn_cast(materializedClass.op); - auto inParallel = dyn_cast_or_null(materializedClass.body->getTerminator()); - if (!batch || !inParallel) - return materializedClass.op->emitError("expected resultful materialized compute_batch host owner"); - - for (Value hostOutput : materializedClass.hostOutputs) { - auto ownerIt = materializedClass.hostOutputToResultIndex.find(hostOutput); - if (ownerIt == materializedClass.hostOutputToResultIndex.end()) - return materializedClass.op->emitError("missing host result slot for materialized compute_batch host output"); - - auto outputArg = batch.getOutputArgument(ownerIt->second); - if (!outputArg) - return batch.emitOpError("missing output block argument for materialized compute_batch host output"); - - bool foundProjection = false; - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert || insert.getDest() != *outputArg) - continue; - foundProjection = true; - break; - } - - if (!foundProjection) - return batch.emitOpError( - "materialized terminal compute_batch host output is missing tensor.parallel_insert_slice publication"); - } - } - - for (const auto& [originalOutput, replacement] : state.hostReplacements) - if (originalOutput.getType() != replacement.getType()) - return replacement.getDefiningOp()->emitOpError("host output replacement type does not match original output type") - << " replacementType=" << replacement.getType() << " outputType=" << originalOutput.getType(); - - return success(); -} - -Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { - auto batch = cast(targetClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected materialized compute_batch lane argument"); - - MLIRContext* context = state.func.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - AffineExpr d1 = getAffineDimExpr(1, context); - - int64_t laneCount = static_cast(targetClass.cpus.size()); - AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); - return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); -} - -Value createBatchClassRunSourceLane(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run, - Value slotIndex, - Location loc) { - SmallVector sourceLanes; - sourceLanes.reserve(run.size() * targetClass.cpus.size()); - - for (auto [runSlotIndex, slot] : llvm::enumerate(run)) { - (void) runSlotIndex; - assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane"); - for (const ComputeInstance& peer : slot.peers) - sourceLanes.push_back(peer.laneStart); - } - - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - return createIndexedIndexValue(state, - targetClass.op, - sourceLanes, - flatIndex, - loc, - static_cast(targetClass.cpus.size()), - /*allowExhaustiveTiledSearch=*/false); -} - -LogicalResult buildBatchRunSendPlans(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef run, - const OutputDestinationGroup& group, - SmallVectorImpl& plans) { - assert(sourceClass.isBatch && "batch run send planning expects a materialized batch source"); - - for (size_t resultIndex : group.resultIndices) { - for (ClassId destinationClass : group.destinationClasses) { - if (destinationClass == sourceClass.id) - return sourceClass.op->emitError("batch-target run compaction cannot handle same-class consumers"); - - MaterializedClass& targetClass = state.classes[destinationClass]; - - if (targetClass.isBatch && targetClass.cpus.size() != sourceClass.cpus.size()) - return sourceClass.op->emitError( - "cannot compact batch run communication between batch classes of different sizes"); - - BatchRunSendPlan plan; - plan.resultIndex = resultIndex; - plan.destinationClass = destinationClass; - - size_t messageCount = run.size() * sourceClass.cpus.size(); - plan.messages.channelIds.reserve(messageCount); - plan.messages.sourceCoreIds.reserve(messageCount); - plan.messages.targetCoreIds.reserve(messageCount); - - for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) { - for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); - if (failed(checkedSourceCpu)) - return failure(); - auto checkedTargetCpu = - getCheckedCoreId(targetClass.op, - targetClass.isBatch ? targetClass.cpus[lane] : targetClass.cpus.front(), - "batch run target core id"); - if (failed(checkedTargetCpu)) - return failure(); - plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - } - (void) slotIndex; - } - - plans.push_back(std::move(plan)); - } - } - - return success(); -} - -void appendBatchRunSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const BatchRunSendPlan& plan, - Value flatIndex, - Location loc) { - assert(sourceClass.isBatch && "batch run send expects a materialized batch source"); - - std::optional preferredPeriod = static_cast(sourceClass.cpus.size()); - Value channelId = createIndexedChannelId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); - - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); -} - -LogicalResult appendPackedScalarRunReceives(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef run, - const BatchRunSendPlan& plan, - Type fragmentType, - Location loc) { - MaterializedClass& targetClass = state.classes[plan.destinationClass]; - assert(!targetClass.isBatch && "packed scalar run receives expect a scalar target class"); - - size_t laneCount = sourceClass.cpus.size(); - size_t receiveCount = run.size() * laneCount; - - if (failed(plan.messages.verify(targetClass.op))) - return failure(); - - if (receiveCount != plan.messages.size()) - return targetClass.op->emitError("inconsistent flattened batch run receive plan"); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return targetClass.op->emitError("packed scalar run receive expects static ranked fragment type"); - - PackedScalarRunValue packedRun; - packedRun.targetClass = targetClass.id; - packedRun.sourceOp = run.front().peers.front().op; - packedRun.resultIndex = plan.resultIndex; - packedRun.kind = PackedScalarRunKind::DeferredReceive; - packedRun.fragmentType = rankedFragmentType; - - packedRun.messages = plan.messages; - - packedRun.slots.reserve(run.size()); - for (const MaterializationRunSlot& slot : run) { - PackedScalarRunSlot packedSlot; - packedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); - packedRun.slots.push_back(std::move(packedSlot)); - } - - if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) - return failure(); - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -LogicalResult recordIndexedBatchRunReceives(MaterializerState& state, - ArrayRef run, - const BatchRunSendPlan& plan, - Type fragmentType) { - MaterializedClass& targetClass = state.classes[plan.destinationClass]; - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return targetClass.op->emitError("indexed batch run receive expects static ranked fragment type"); - - IndexedBatchRunValue indexedRun; - indexedRun.targetClass = targetClass.id; - indexedRun.sourceOp = run.front().peers.front().op; - indexedRun.resultIndex = plan.resultIndex; - indexedRun.fragmentType = rankedFragmentType; - indexedRun.messages = plan.messages; - indexedRun.slots.reserve(run.size()); - for (const MaterializationRunSlot& slot : run) { - PackedScalarRunSlot indexedSlot; - indexedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); - indexedRun.slots.push_back(std::move(indexedSlot)); - } - - state.availableValues.recordIndexedBatchRun(std::move(indexedRun)); - return success(); -} - -LogicalResult appendBatchRunReceives(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef run, - const BatchRunSendPlan& plan, - Type fragmentType, - Location loc) { - MaterializedClass& targetClass = state.classes[plan.destinationClass]; - - if (!targetClass.isBatch) - return appendPackedScalarRunReceives(state, sourceClass, run, plan, fragmentType, loc); - return recordIndexedBatchRunReceives(state, run, plan, fragmentType); -} - -LogicalResult materializeBatchClassRun(MaterializerState& state, - MaterializedClass& targetClass, - SlotId startSlot, - ArrayRef run) { - assert(targetClass.isBatch && "batch-target run materialization expects a materialized batch class"); - assert(!run.empty() && "expected non-empty batch-target run"); - - if (!canCompactBatchClassRun(state, targetClass, run)) - return failure(); - - markMaterializationRunSlots(state, targetClass.id, startSlot, run); - - SmallVector groups = groupBatchRunOutputsByDestination(state, run); - - auto sourceBatch = cast(run.front().peers.front().op); - SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); - Location loc = sourceBatch.getLoc(); - bool constantProjectionSlotIndex = requiresConstantProjectionSlotIndex(state, targetClass, sourceBatch); - - for (const OutputDestinationGroup& group : groups) { - SmallVector sendPlans; - if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) - return failure(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - if (constantProjectionSlotIndex) { - for (auto [slotIndex, slot] : llvm::enumerate(run)) { - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - Value slotIndexValue = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(slotIndex)); - Value sourceLane = getOrCreateIndexConstant(state.constantFolder, targetClass.op, slot.peers.front().laneStart); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndexValue, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - getScheduledChunkForLogicalInstance(state, run.front().peers.front()), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = slotIndexValue, - .projectionSlotIndex = slotIndexValue}); - if (failed(produced)) - return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); - if (resultIt == group.resultIndices.end()) - return failure(); - - size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); - appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); - } - } - } else { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { - Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - getScheduledChunkForLogicalInstance(state, run.front().peers.front()), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); - if (failed(produced)) - return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); - if (resultIt == group.resultIndices.end()) - return failure(); - - size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); - appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); - } - return success(); - }); - if (failed(loop)) - return failure(); - } - - for (const BatchRunSendPlan& plan : sendPlans) { - if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) - return failure(); - - if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) - return failure(); - } - - if (failed(registerMaterializedBatchRunHostOutputs(state, targetClass, run, group))) - return failure(); - } - - return success(); -} - -LogicalResult materializeInstanceSlot(MaterializerState& state, - const ComputeInstance& instance) { - auto cpuIt = state.schedule.computeToCpuMap.find(instance); - if (cpuIt == state.schedule.computeToCpuMap.end()) - return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); - auto logicalRangeIt = state.scheduledInstanceToLogicalSlots.find(instance); - if (logicalRangeIt == state.scheduledInstanceToLogicalSlots.end()) - return instance.op->emitError("schedule materialization expected logical slots for every compute instance"); - - ClassId classId = state.cpuToClass.lookup(cpuIt->second); - MaterializedClass& targetClass = state.classes[classId]; - - LogicalSlotRange logicalRange = logicalRangeIt->second; - SlotId startLogicalSlot = logicalRange.start; - while (startLogicalSlot < logicalRange.start + logicalRange.count - && state.materializedLogicalSlots.contains({classId, startLogicalSlot})) { - ++startLogicalSlot; - } - if (startLogicalSlot == logicalRange.start + logicalRange.count) - return success(); - - if (isa(instance.op)) { - FailureOr run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); - - if (succeeded(run)) { - if (!targetClass.isBatch) - return materializeScalarBatchRun(state, targetClass, startLogicalSlot, *run); - - if (succeeded(materializeBatchClassRun(state, targetClass, startLogicalSlot, *run))) - return success(); - } - } - - if (!state.materializedLogicalSlots.insert({classId, startLogicalSlot}).second) - return success(); - - FailureOr> peers = - getMaterializationRunSlotPeers(state, targetClass, startLogicalSlot); - if (failed(peers)) - return instance.op->emitError("failed to collect peer compute instances for equivalence class logical slot"); - - Value projectionSlotIndex = getOrCreateIndexConstant( - state.constantFolder, targetClass.op, static_cast(startLogicalSlot - logicalRange.start)); - FailureOr> materializedOutputs = - cloneInstanceBody(state, - targetClass, - *peers, - CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = projectionSlotIndex}); - if (failed(materializedOutputs)) - return failure(); - - ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, instance); - if (materializedOutputs->size() != originalOutputs.size()) - return instance.op->emitError("materialized output count does not match original compute instance output count"); - - for (auto [resultIndex, zipped] : llvm::enumerate(llvm::zip(*materializedOutputs, originalOutputs))) { - Value materializedOutput = std::get<0>(zipped); - Value originalOutput = std::get<1>(zipped); - MaterializationRunSlot slot; - slot.peers = *peers; - SmallVector keys = getMaterializationRunSlotOutputKeys(slot, resultIndex); - if (failed(emitOutputFanout(state, targetClass, keys, materializedOutput, originalOutput, instance.op->getLoc()))) - return failure(); - } - - return success(); -} - -FailureOr createReceiveConcatLoop(MaterializerState& state, - MaterializedClass& targetClass, - RankedTensorType concatType, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); - assert(!messages.empty() && "expected at least one receive"); - - Operation* insertionPoint = targetClass.body->getTerminator(); - state.rewriter.setInsertionPoint(insertionPoint); - Value init = - tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); - return emitIndexedFragmentInsertLoop( - state, - targetClass, - init, - static_cast(messages.size()), - [&](Value index) -> FailureOr { - Value channelId = createIndexedChannelId(state, targetClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, messages, index, loc); - return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - }, - [&](Value index) -> FailureOr { - return scaleIndexByDim0SizeInClass(state, targetClass, index, fragmentType.getDimSize(0), loc); - }, - loc); -} - - -std::optional getDirectCommunicationOrderKey(Operation* op) { - if (!op) - return std::nullopt; - - Value channelId; - Value sourceCoreId; - Value targetCoreId; - if (auto send = dyn_cast(op)) { - channelId = send.getChannelId(); - sourceCoreId = send.getSourceCoreId(); - targetCoreId = send.getTargetCoreId(); - } - else if (auto receive = dyn_cast(op)) { - channelId = receive.getChannelId(); - sourceCoreId = receive.getSourceCoreId(); - targetCoreId = receive.getTargetCoreId(); - } - else { - return std::nullopt; - } - - auto channel = getConstantIndexValue(channelId); - auto source = getConstantIndexValue(sourceCoreId); - auto target = getConstantIndexValue(targetCoreId); - if (!channel || !source || !target) - return std::nullopt; - - return computeBlockingCommunicationOrderKey( - static_cast(*source), static_cast(*target), *channel); -} - -std::optional getScalarCommunicationOrderKey(Operation* op) { - if (!op) - return std::nullopt; - if (auto order = op->getAttrOfType(kRaptorCommOrderAttr)) - return order.getInt(); - if (auto directOrder = getDirectCommunicationOrderKey(op)) - return directOrder; - if (auto channel = op->getAttrOfType(kRaptorMinChannelIdAttr)) - return channel.getInt(); - return std::nullopt; -} - -bool isReorderableScalarCommunication(Operation* op) { - if (!getScalarCommunicationOrderKey(op).has_value()) - return false; - - // The global-order repair is intentionally conservative: it may reorder - // send-side projections, but it must not move receives or any other - // communication op that defines SSA values. Moving a receive after one of - // its users breaks MLIR dominance; moving it before the source can produce - // the payload can also create a receive/receive deadlock. Receives therefore - // have to be placed correctly by the materializer when they are created. - // Direct spat.channel_send operations are included even when they were not - // produced by appendScalarSendLoop and therefore do not carry raptor.* - // attributes yet. This is needed for large scalar-to-scalar payload transfers - // that must be hoisted before reciprocal receives. - return isa(op) || (op->getNumResults() == 0 && op->hasAttr(kRaptorMinChannelIdAttr)); -} - -Operation* getLaterOperationInBlock(Operation* lhs, Operation* rhs) { - if (!lhs) - return rhs; - if (!rhs) - return lhs; - return lhs->isBeforeInBlock(rhs) ? rhs : lhs; -} - -Operation* getNextInsertionPointAfter(Operation* op, Block& block) { - if (!op) - return &block.front(); - Operation* next = op->getNextNode(); - return next ? next : block.getTerminator(); -} - -bool hasConstantRoutingOperands(SpatChannelSendOp send) { - return getConstantIndexValue(send.getChannelId()).has_value() - && getConstantIndexValue(send.getSourceCoreId()).has_value() - && getConstantIndexValue(send.getTargetCoreId()).has_value(); -} - -Operation* getLatestSameBlockOperandDefinition(Operation* root, Block& block) { - Operation* latest = nullptr; - - auto consider = [&](Value value) { - Operation* definingOp = value.getDefiningOp(); - if (!definingOp || definingOp->getBlock() != &block || definingOp == root) - return; - latest = getLaterOperationInBlock(latest, definingOp); - }; - - // For direct sends with constant routing operands, only the payload is a real - // scheduling dependency. The channel/source/target constants can be - // rematerialized at the new insertion point. Treating those constants as hard - // dependencies prevents the repair from hoisting a ready send above an early - // receive, which is exactly the receive/receive deadlock pattern reported by - // the static communication checker. - if (auto send = dyn_cast(root)) { - if (hasConstantRoutingOperands(send)) { - consider(send.getInput()); - return latest; - } - } - - for (Value operand : root->getOperands()) - consider(operand); - - for (Region& region : root->getRegions()) { - region.walk([&](Operation* nested) { - if (nested == root) - return; - for (Value operand : nested->getOperands()) - consider(operand); - }); - } - - return latest; -} - -void rematerializeDirectSendRoutingConstantsAt(MaterializerState& state, - SpatChannelSendOp send, - Operation* insertionPoint) { - if (!send || !insertionPoint || !hasConstantRoutingOperands(send)) - return; - - auto channel = getConstantIndexValue(send.getChannelId()); - auto source = getConstantIndexValue(send.getSourceCoreId()); - auto target = getConstantIndexValue(send.getTargetCoreId()); - if (!channel || !source || !target) - return; - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(insertionPoint); - Location loc = send.getLoc(); - Value newChannel = arith::ConstantIndexOp::create(state.rewriter, loc, *channel); - Value newSource = arith::ConstantIndexOp::create(state.rewriter, loc, *source); - Value newTarget = arith::ConstantIndexOp::create(state.rewriter, loc, *target); - send->setOperand(0, newChannel); - send->setOperand(1, newSource); - send->setOperand(2, newTarget); -} - -LogicalResult reorderScalarClassCommunicationByGlobalOrder(MaterializerState& state, - MaterializedClass& materializedClass) { - if (materializedClass.isBatch) - return success(); - - Block& block = *materializedClass.body; - Operation* terminator = block.getTerminator(); - SmallVector communicationOps; - for (Operation& op : block) { - if (&op == terminator) - break; - if (isReorderableScalarCommunication(&op)) - communicationOps.push_back(&op); - } - - if (communicationOps.size() < 2) - return success(); - - llvm::stable_sort(communicationOps, [](Operation* lhs, Operation* rhs) { - std::optional lhsOrder = getScalarCommunicationOrderKey(lhs); - std::optional rhsOrder = getScalarCommunicationOrderKey(rhs); - if (lhsOrder != rhsOrder) - return lhsOrder.value_or(std::numeric_limits::max()) - < rhsOrder.value_or(std::numeric_limits::max()); - return lhs->isBeforeInBlock(rhs); - }); - - Operation* lastPlacedCommunication = nullptr; - for (Operation* communication : communicationOps) { - if (communication->getBlock() != &block) - return materializedClass.op->emitError("scalar communication global-order repair saw a moved operation"); - - Operation* dependency = getLatestSameBlockOperandDefinition(communication, block); - Operation* anchor = getLaterOperationInBlock(lastPlacedCommunication, dependency); - Operation* insertionPoint = getNextInsertionPointAfter(anchor, block); - - if (insertionPoint != communication && communication->getNextNode() != insertionPoint) { - if (auto send = dyn_cast(communication)) - rematerializeDirectSendRoutingConstantsAt(state, send, insertionPoint); - communication->moveBefore(insertionPoint); - } - - lastPlacedCommunication = communication; - } - - return success(); -} - -LogicalResult reorderScalarCommunicationsByGlobalOrder(MaterializerState& state) { - for (MaterializedClass& materializedClass : state.classes) - if (failed(reorderScalarClassCommunicationByGlobalOrder(state, materializedClass))) - return failure(); - return success(); -} - - -Operation* getEarliestOperationInBlock(Operation* lhs, Operation* rhs) { - if (!lhs) - return rhs; - if (!rhs) - return lhs; - return lhs->isBeforeInBlock(rhs) ? lhs : rhs; -} - -Operation* getTopLevelOperationInBlock(Operation* op, Block& block) { - for (Operation* current = op; current; current = current->getParentOp()) { - if (current->getBlock() == &block) - return current; - } - return nullptr; -} - -Operation* findEarliestTopLevelUse(Operation* producer, Block& block) { - Operation* earliest = nullptr; - for (Value result : producer->getResults()) { - for (Operation* user : result.getUsers()) { - Operation* topLevelUser = getTopLevelOperationInBlock(user, block); - if (!topLevelUser || topLevelUser == producer) - continue; - earliest = getEarliestOperationInBlock(earliest, topLevelUser); - } - } - return earliest; -} - -LogicalResult sinkScalarReceivesToFirstUse(MaterializerState& state) { - for (MaterializedClass& materializedClass : state.classes) { - if (materializedClass.isBatch) - continue; - - Block& block = *materializedClass.body; - Operation* terminator = block.getTerminator(); - SmallVector receives; - for (Operation& op : block) { - if (&op == terminator) - break; - if (isa(&op)) - receives.push_back(&op); - } - - for (Operation* receive : receives) { - if (receive->getBlock() != &block) - continue; - - Operation* firstUse = findEarliestTopLevelUse(receive, block); - if (!firstUse || firstUse == receive || firstUse->getBlock() != &block) - continue; - - if (!receive->isBeforeInBlock(firstUse)) - continue; - - if (receive->getNextNode() == firstUse) - continue; - - receive->setAttr("raptor.receive_sunk_to_first_use", UnitAttr::get(receive->getContext())); - receive->moveBefore(firstUse); - } - } - return success(); -} - -void replaceHostUses(MaterializerState& state) { - for (const auto& [oldValue, replacement] : state.hostReplacements) - replaceLiveExternalUses(oldValue, replacement, state.oldComputeOps); -} - -LogicalResult eraseOldComputeOps(MaterializerState& state) { - DenseSet seen; - for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { - if (!seen.insert(instance.op).second) - continue; - instance.op->dropAllUses(); - instance.op->erase(); - } - return success(); -} - -} // namespace - -LogicalResult -MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) { - if (schedule.dominanceOrderCompute.empty()) - return success(); - - MaterializerState state(func, schedule, nextChannelId); - if (failed(buildMaterializationWorkStreams(state))) - return failure(); - if (failed(buildMaterializationClassesFromScheduleEquivalence(state))) - return failure(); - if (failed(verifyScheduleEquivalenceMatchesLogicalStreams(state))) - return failure(); - if (state.classes.empty()) - return success(); - - if (failed(collectHostOutputs(state))) - return failure(); - if (failed(createEmptyMaterializedOps(state))) - return failure(); - if (failed(collectProducerDestinations(state))) - return failure(); - if (failed(collectProjectedTransfers(state))) - return failure(); - - for (const ComputeInstance& instance : schedule.dominanceOrderCompute) - if (failed(materializeInstanceSlot(state, instance))) - return failure(); - - for (MaterializedClass& materializedClass : state.classes) - if (failed(localizeAllScheduledBodyCaptures(state, materializedClass))) - return failure(); - - if (failed(flushPendingProjectedHostReceives(state))) - return failure(); - - if (pimMaterializeScalarFanoutGlobalOrder) { - if (failed(sinkScalarReceivesToFirstUse(state))) - return failure(); - if (failed(reorderScalarCommunicationsByGlobalOrder(state))) - return failure(); - } - - if (failed(verifyMaterializedHostOutputs(state))) - return failure(); - - replaceHostUses(state); - if (failed(eraseOldComputeOps(state))) - return failure(); - - LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody()); - (void) _; - - return success(); -} - -} // namespace spatial -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.orig b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.orig deleted file mode 100644 index 1244543..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.orig +++ /dev/null @@ -1,7548 +0,0 @@ -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/RegionUtils.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include -#include -#include -#include -#include - -#include "MaterializeMergeSchedule.hpp" -#include "Scheduling/ComputeInstanceUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { -namespace spatial { -namespace { - -using CpuId = size_t; -using ClassId = size_t; -using SlotId = size_t; - -static FailureOr getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) { - return pim::checkedI32(static_cast(cpu), anchor, fieldName); -} - -static FailureOr> -getCheckedCoreIds(Operation* anchor, ArrayRef cpus, StringRef fieldName) { - SmallVector coreIds; - coreIds.reserve(cpus.size()); - for (CpuId cpu : cpus) { - auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName); - if (failed(checkedCoreId)) - return failure(); - coreIds.push_back(*checkedCoreId); - } - return coreIds; -} - -struct MessageVector { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - size_t size() const { return channelIds.size(); } - bool empty() const { return channelIds.empty(); } - - LogicalResult verify(Operation* anchor) const { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) - return anchor->emitError("message metadata is inconsistent"); - return success(); - } - - void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) { - channelIds.push_back(channelId); - sourceCoreIds.push_back(sourceCoreId); - targetCoreIds.push_back(targetCoreId); - } - - void append(ArrayRef channels, ArrayRef sources, ArrayRef targets) { - assert(channels.size() == sources.size() && "channel/source count mismatch"); - assert(channels.size() == targets.size() && "channel/target count mismatch"); - llvm::append_range(channelIds, channels); - llvm::append_range(sourceCoreIds, sources); - llvm::append_range(targetCoreIds, targets); - } - - MessageVector slice(size_t offset, size_t count) const { - MessageVector result; - result.append(ArrayRef(channelIds).slice(offset, count), - ArrayRef(sourceCoreIds).slice(offset, count), - ArrayRef(targetCoreIds).slice(offset, count)); - return result; - } -}; - -struct ProducerKey { - ComputeInstance instance; - size_t resultIndex = 0; - - bool operator==(const ProducerKey& other) const { - return instance == other.instance && resultIndex == other.resultIndex; - } -}; - -struct ProducerKeyInfo { - static ProducerKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; - } - - static ProducerKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; - } - - static unsigned getHashValue(const ProducerKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.instance), key.resultIndex); - } - - static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } -}; - -struct SameClassConsumerLookupKey { - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - ClassId classId = 0; - - bool operator==(const SameClassConsumerLookupKey& other) const { - return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; - } -}; - -struct SameClassConsumerLookupKeyInfo { - static SameClassConsumerLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static SameClassConsumerLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static unsigned getHashValue(const SameClassConsumerLookupKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); - } - - static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) { - return lhs == rhs; - } -}; - -struct WholeBatchAssemblyLookupKey { - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - ClassId classId = 0; - - bool operator==(const WholeBatchAssemblyLookupKey& other) const { - return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; - } -}; - -struct WholeBatchAssemblyLookupKeyInfo { - static WholeBatchAssemblyLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static WholeBatchAssemblyLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), - std::numeric_limits::max()}; - } - - static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) { - return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); - } - - static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) { - return lhs == rhs; - } -}; - -using ClassSlotKey = std::pair; - -struct MaterializedClass { - ClassId id = 0; - SmallVector cpus; - Operation* op = nullptr; - Block* body = nullptr; - bool isBatch = false; - - DenseMap cpuToLane; - SmallVector weights; - SmallVector inputs; - SmallVector hostOutputs; - DenseMap weightArgs; - DenseMap inputArgs; - DenseMap hostOutputToResultIndex; -}; - -struct PackedScalarRunSlot { - SmallVector keys; -}; - -enum class PackedScalarRunKind { - Materialized, - DeferredReceive, - DeferredLocalCompute -}; - -struct PackedScalarRunValue { - ClassId targetClass = 0; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - PackedScalarRunKind kind = PackedScalarRunKind::Materialized; - - Value packed; - - RankedTensorType fragmentType; - SmallVector slots; - MessageVector messages; -}; - -struct IndexedBatchRunValue { - ClassId targetClass = 0; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - RankedTensorType fragmentType; - SmallVector slots; - MessageVector messages; -}; - -struct LogicalSlotRange { - SlotId start = 0; - SlotId count = 0; -}; - -struct MaterializationRunSlot { - SmallVector peers; -}; - -using MaterializationRun = SmallVector; - -struct OutputDestinationGroup { - SmallVector resultIndices; - SmallVector destinationClasses; -}; - -struct BatchRunSendPlan { - size_t resultIndex = 0; - ClassId destinationClass = 0; - MessageVector messages; -}; - -struct ProjectedBatchInputKey { - Operation* consumerOp = nullptr; - unsigned inputIndex = 0; - - bool operator==(const ProjectedBatchInputKey& other) const { - return consumerOp == other.consumerOp && inputIndex == other.inputIndex; - } -}; - -struct ProjectedBatchInputKeyInfo { - static ProjectedBatchInputKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; - } - - static ProjectedBatchInputKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; - } - - static unsigned getHashValue(const ProjectedBatchInputKey& key) { - return llvm::hash_combine(key.consumerOp, key.inputIndex); - } - - static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } -}; - -struct ProjectedFragmentLayout { - RankedTensorType fragmentType; - SmallVector fragmentShape; - unsigned fragmentsPerLogicalSlot = 1; - unsigned payloadFragmentCount = 1; - SmallVector loopLowerBounds; - SmallVector loopSteps; - SmallVector loopTripCounts; -}; - -struct ProjectedTransferDescriptor { - ProjectedBatchInputKey inputKey; - Operation* extractOp = nullptr; - - ProjectedFragmentLayout layout; - RankedTensorType payloadType; - SmallVector, 16> fragmentOffsets; - SmallVector, 4> fragmentOffsetsByDim; -}; - -struct ProjectedExtractReplacement { - Value payload; - ProjectedFragmentLayout layout; -}; - -struct PendingProjectedHostOutputFragment { - Value originalOutput; - ClassId sourceClass = 0; - Value fragment; - RankedTensorType fragmentType; - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - uint32_t sourceLane = 0; - Location loc; - - // When a materialized batch class is the source, the send must be emitted - // once with lane-indexed channel metadata. Finalization then only emits the - // matching scalar receive for each recorded fragment. Scalar sources keep - // the old behavior and emit their send during finalization. - bool sendAlreadyEmitted = false; - MessageVector messages; -}; - -struct CloneIndexingContext { - std::optional runSlotIndex; - std::optional projectionSlotIndex; -}; - -struct StaticProjectedLoopInfo { - BlockArgument iv; - int64_t lowerBound = 0; - int64_t step = 1; - int64_t tripCount = 1; -}; - -struct AffineProjectedInputSliceMatch { - tensor::ExtractSliceOp extract; - RankedTensorType sourceType; - RankedTensorType fragmentType; - SmallVector fragmentShape; - SmallVector offsets; - SmallVector loops; -}; - -struct MaterializerState; - -FailureOr materializeProjectedExtractReplacement(MaterializerState& state, - MaterializedClass& targetClass, - tensor::ExtractSliceOp extract, - const ProjectedExtractReplacement& replacement, - std::optional projectionSlotIndex, - IRMapping* mapper = nullptr); -FailureOr rematerializeTensorValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - IRMapping* mapper = nullptr); -FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - std::optional producer = std::nullopt, - IRMapping* mapper = nullptr); -FailureOr localizeMaterializedClassOperand(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef tensorContext, - StringRef genericContext, - IRMapping* mapper = nullptr); -LogicalResult localizeCapturesInClonedOp(MaterializerState& state, - MaterializedClass& targetClass, - Operation& clonedOp, - IRMapping* mapper = nullptr); -LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass); -bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, - const AffineProjectedInputSliceMatch& match, - ProducerKey producer, - uint32_t consumerLane); -std::optional getProjectedInputSliceMatch(MaterializerState& state, - SpatComputeBatch batch, - unsigned inputIndex); - -class AvailableValueStore { -public: - struct ExactBatchFragmentRecord { - ProducerKey key; - Value value; - }; - - void record(ProducerKey key, ClassId classId, Value value) { - exactValues[key][classId] = value; - - auto batch = dyn_cast_or_null(key.instance.op); - if (!batch || key.instance.laneCount == 0) - return; - - WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId}; - SmallVector& bucket = exactBatchFragmentsByProducerResultClass[lookupKey]; - for (ExactBatchFragmentRecord& record : bucket) { - if (!(record.key == key)) - continue; - record.value = value; - return; - } - bucket.push_back({key, value}); - } - - void recordPackedRun(PackedScalarRunValue run) { - size_t runIndex = packedScalarRuns.size(); - packedScalarRuns.push_back(std::move(run)); - const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex]; - WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass}; - packedRunsByProducerResultClass[lookupKey].push_back(runIndex); - } - void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } - - std::optional lookupExact(ProducerKey key, ClassId classId) const; - - std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); - IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); - - ArrayRef getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const { - auto it = packedRunsByProducerResultClass.find(key); - if (it == packedRunsByProducerResultClass.end()) - return {}; - return it->second; - } - - ArrayRef getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const { - auto it = exactBatchFragmentsByProducerResultClass.find(key); - if (it == exactBatchFragmentsByProducerResultClass.end()) - return {}; - return it->second; - } - - PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; } - -private: - std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); - - DenseMap, ProducerKeyInfo> exactValues; - SmallVector packedScalarRuns; - SmallVector indexedBatchRuns; - DenseMap, WholeBatchAssemblyLookupKeyInfo> - exactBatchFragmentsByProducerResultClass; - DenseMap, WholeBatchAssemblyLookupKeyInfo> - packedRunsByProducerResultClass; -}; - -struct MaterializerState { - func::FuncOp func; - const MergeScheduleResult& schedule; - IRRewriter rewriter; - OperationFolder constantFolder; - int64_t& nextChannelId; - SmallVector classes; - DenseMap cpuToClass; - DenseMap> logicalInstancesByCpu; - DenseMap scheduledInstanceToLogicalSlots; - DenseMap logicalInstanceToScheduledChunk; - DenseSet materializedLogicalSlots; - - DenseMap, ProducerKeyInfo> producerDestClasses; - DenseMap, SameClassConsumerLookupKeyInfo> - sameClassConsumerIndex; - DenseMap projectedInputMatches; - DenseSet nonProjectedInputs; - DenseMap liveExternalUseCache; - DenseMap> batchOutputFragmentTypesCache; - DenseMap, llvm::DenseMapInfo> computeInstanceOutputsCache; - DenseMap, ProducerKeyInfo> projectedTransfers; - DenseMap> projectedExtractReplacements; - AvailableValueStore availableValues; - DenseMap hostReplacements; - DenseMap hostOutputOwners; - SmallVector pendingProjectedHostOutputFragments; - DenseSet oldComputeOps; - - MaterializerState(func::FuncOp func, - const MergeScheduleResult& schedule, - int64_t& nextChannelId) - : func(func), - schedule(schedule), - rewriter(func.getContext()), - constantFolder(func.getContext()), - nextChannelId(nextChannelId) {} -}; - -bool isConstantLike(Value value) { - Operation* definingOp = value.getDefiningOp(); - return definingOp && definingOp->hasTrait(); -} - -bool isInsideOldCompute(Operation* op, const DenseSet& oldComputeOps) { - for (Operation* current = op; current; current = current->getParentOp()) - if (oldComputeOps.contains(current)) - return true; - return false; -} - -bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps); -ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance); - -bool hasLiveExternalUseCached(MaterializerState& state, Value value) { - auto cached = state.liveExternalUseCache.find(value); - if (cached != state.liveExternalUseCache.end()) - return cached->second; - bool live = hasLiveExternalUse(value, state.oldComputeOps); - state.liveExternalUseCache[value] = live; - return live; -} - -std::optional getConstantFirstSliceOffset(tensor::ExtractSliceOp extract) { - if (extract.getMixedOffsets().empty()) - return std::nullopt; - - OpFoldResult offset = extract.getMixedOffsets().front(); - if (auto attr = dyn_cast(offset)) { - auto intAttr = dyn_cast(attr); - if (!intAttr || intAttr.getInt() < 0) - return std::nullopt; - return static_cast(intAttr.getInt()); - } - - auto value = cast(offset); - if (auto constantIndex = value.getDefiningOp()) { - if (constantIndex.value() < 0) - return std::nullopt; - return static_cast(constantIndex.value()); - } - - APInt constantValue; - if (matchPattern(value, m_ConstantInt(&constantValue))) { - if (constantValue.isNegative()) - return std::nullopt; - return static_cast(constantValue.getZExtValue()); - } - - return std::nullopt; -} - -ProducerKey -getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) { - return { - {batch.getOperation(), laneStart, laneCount}, - resultIndex - }; -} - -ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) { - return getBatchLaneProducerKey(batch, 0, static_cast(batch.getLaneCount()), resultIndex); -} - -bool isWholeBatchProducerKey(ProducerKey key) { - auto batch = dyn_cast_or_null(key.instance.op); - return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 - && key.instance.laneCount == static_cast(batch.getLaneCount()); -} - -std::optional getContiguousProducerRangeForKeys(ArrayRef keys) { - if (keys.empty()) - return std::nullopt; - - ProducerKey first = keys.front(); - auto batch = dyn_cast_or_null(first.instance.op); - if (!batch) - return std::nullopt; - - SmallVector sorted(keys.begin(), keys.end()); - llvm::sort(sorted, [](ProducerKey lhs, ProducerKey rhs) { - return std::tie(lhs.instance.laneStart, lhs.instance.laneCount, lhs.resultIndex) - < std::tie(rhs.instance.laneStart, rhs.instance.laneCount, rhs.resultIndex); - }); - - uint32_t laneStart = sorted.front().instance.laneStart; - uint32_t nextLane = laneStart; - for (ProducerKey key : sorted) { - if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) - return std::nullopt; - if (key.instance.laneStart != nextLane) - return std::nullopt; - nextLane += key.instance.laneCount; - } - - uint32_t laneCount = nextLane - laneStart; - if (laneStart + laneCount > static_cast(batch.getLaneCount())) - return std::nullopt; - - return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); -} - -WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) { - return {sourceOp, resultIndex, classId}; -} - -WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) { - return makeWholeBatchAssemblyLookupKey(key.instance.op, key.resultIndex, classId); -} - -FailureOr getPackedBatchTensorType(Type laneType, size_t laneCount) { - auto tensorType = dyn_cast(laneType); - if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) - return failure(); - - SmallVector shape(tensorType.getShape()); - shape[0] *= static_cast(laneCount); - return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); -} - -LogicalResult verifyPackableFragmentType(Operation* anchor, Type fragmentType, size_t count, StringRef message) { - if (failed(getPackedBatchTensorType(fragmentType, count))) - return anchor->emitError(message); - return success(); -} - -ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, ComputeInstance logicalInstance) { - auto it = state.logicalInstanceToScheduledChunk.find(logicalInstance); - if (it != state.logicalInstanceToScheduledChunk.end()) - return it->second; - return logicalInstance; -} - -SmallVector -collectProducerKeysForDestinations(Value value, std::optional logicalConsumer = std::nullopt) { - // Destination collection works in the materializer's logical one-lane key domain. - // Whole-batch resultful producers are expanded into per-lane producer keys here. - SmallVector keys; - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) - return keys; - - while (auto extract = dyn_cast(definingOp)) { - Value source = extract.getSource(); - auto batch = dyn_cast_or_null(source.getDefiningOp()); - if (batch && batch.getNumResults() != 0) { - auto result = dyn_cast(source); - if (!result) - return {}; - - if (std::optional lane = getConstantFirstSliceOffset(extract)) { - if (*lane >= static_cast(batch.getLaneCount())) - return {}; - keys.push_back(getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber())); - return keys; - } - - if (logicalConsumer && isa(logicalConsumer->op)) { - keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); - return keys; - } - - return {}; - } - - value = source; - definingOp = value.getDefiningOp(); - if (!definingOp) - return {}; - } - - if (auto compute = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return {}; - keys.push_back({ - {compute.getOperation(), 0, 1}, - result.getResultNumber() - }); - return keys; - } - - if (auto batch = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return {}; - - if (batch.getNumResults() != 0) { - if (logicalConsumer && isa(logicalConsumer->op)) { - keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); - return keys; - } - - for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) - keys.push_back(getBatchLaneProducerKey(batch, lane, 1, result.getResultNumber())); - return keys; - } - - ComputeInstance chunk = getBatchChunkForLane(batch, result.getResultNumber()); - keys.push_back({chunk, static_cast(result.getResultNumber() - chunk.laneStart)}); - return keys; - } - - return keys; -} - -std::optional getInputRequestProducerKey(Value value, - std::optional logicalConsumer = std::nullopt) { - // Input resolution may request a whole-batch key for scalar consumers that read - // a complete resultful compute_batch value. - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) - return std::nullopt; - - while (auto extract = dyn_cast(definingOp)) { - Value source = extract.getSource(); - auto batch = dyn_cast_or_null(source.getDefiningOp()); - if (batch && batch.getNumResults() != 0) { - auto result = dyn_cast(source); - if (!result) - return std::nullopt; - - if (std::optional lane = getConstantFirstSliceOffset(extract)) - return getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber()); - - if (logicalConsumer && isa(logicalConsumer->op)) - return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); - - return std::nullopt; - } - - value = source; - definingOp = value.getDefiningOp(); - if (!definingOp) - return std::nullopt; - } - - if (auto compute = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return std::nullopt; - return ProducerKey { - {compute.getOperation(), 0, 1}, - result.getResultNumber() - }; - } - - if (auto batch = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return std::nullopt; - - if (batch.getNumResults() != 0) { - if (logicalConsumer && isa(logicalConsumer->op)) - return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); - return getWholeBatchProducerKey(batch, result.getResultNumber()); - } - - return ProducerKey {getBatchChunkForLane(batch, result.getResultNumber()), 0}; - } - - return std::nullopt; -} - -std::optional getWholeBatchProducerKeyForDirectBatchResult(Value value) { - auto result = dyn_cast(value); - if (!result) - return std::nullopt; - - auto batch = dyn_cast_or_null(result.getOwner()); - if (!batch || batch.getNumResults() == 0) - return std::nullopt; - - return getWholeBatchProducerKey(batch, result.getResultNumber()); -} - -bool canUseProjectedLaneInput(MaterializerState& state, - SpatComputeBatch consumerBatch, - unsigned inputIndex, - Value input, - ComputeInstance logicalConsumer) { - auto producerResult = dyn_cast(input); - if (!producerResult) - return false; - - auto producerBatch = dyn_cast_or_null(producerResult.getOwner()); - if (!producerBatch || producerBatch.getNumResults() == 0) - return false; - - std::optional match = - getProjectedInputSliceMatch(state, consumerBatch, inputIndex); - if (!match) - return false; - - ProducerKey laneProducer = - getBatchLaneProducerKey(producerBatch, logicalConsumer.laneStart, 1, producerResult.getResultNumber()); - return isProjectedInputSliceCompatibleWithProducerFragments( - consumerBatch, *match, laneProducer, logicalConsumer.laneStart); -} - -SmallVector collectProducerKeysForBatchInputDestinations(MaterializerState& state, - SpatComputeBatch consumerBatch, - unsigned inputIndex, - Value input, - ComputeInstance logicalConsumer) { - if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input)) { - if (!canUseProjectedLaneInput(state, consumerBatch, inputIndex, input, logicalConsumer)) { - auto producerBatch = cast(wholeBatchProducer->instance.op); - SmallVector keys; - for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) - keys.push_back(getBatchLaneProducerKey(producerBatch, lane, 1, wholeBatchProducer->resultIndex)); - return keys; - } - } - - return collectProducerKeysForDestinations(input, logicalConsumer); -} - -class CpuUnionFind { -public: - void insert(CpuId cpu) { parent.try_emplace(cpu, cpu); } - - CpuId find(CpuId cpu) { - insert(cpu); - CpuId p = parent.lookup(cpu); - if (p == cpu) - return cpu; - CpuId root = find(p); - parent[cpu] = root; - return root; - } - - void unite(CpuId lhs, CpuId rhs) { - CpuId lhsRoot = find(lhs); - CpuId rhsRoot = find(rhs); - if (lhsRoot == rhsRoot) - return; - if (rhsRoot < lhsRoot) - std::swap(lhsRoot, rhsRoot); - parent[rhsRoot] = lhsRoot; - } - -private: - DenseMap parent; -}; - -LogicalResult buildMaterializationWorkStreams(MaterializerState& state) { - DenseMap> scheduledInstancesByCpu; - for (const auto& [instance, cpu] : state.schedule.computeToCpuMap) { - state.oldComputeOps.insert(instance.op); - scheduledInstancesByCpu[cpu].push_back(instance); - state.logicalInstancesByCpu.try_emplace(cpu); - } - - for (auto& [cpu, scheduledInstances] : scheduledInstancesByCpu) { - llvm::sort(scheduledInstances, [&](const ComputeInstance& lhs, const ComputeInstance& rhs) { - auto lhsIt = state.schedule.computeToCpuSlotMap.find(lhs); - auto rhsIt = state.schedule.computeToCpuSlotMap.find(rhs); - assert(lhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); - assert(rhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); - return lhsIt->second < rhsIt->second; - }); - - SmallVector& logicalInstances = state.logicalInstancesByCpu[cpu]; - SlotId logicalSlot = 0; - for (const ComputeInstance& instance : scheduledInstances) { - LogicalSlotRange range {logicalSlot, 1}; - if (isa(instance.op)) - range.count = instance.laneCount; - - state.scheduledInstanceToLogicalSlots[instance] = range; - - if (isa(instance.op)) { - for (uint32_t localLane = 0; localLane < instance.laneCount; ++localLane, ++logicalSlot) { - uint32_t logicalLane = instance.laneStart + localLane; - ComputeInstance logicalInstance {instance.op, logicalLane, 1}; - logicalInstances.push_back(logicalInstance); - state.logicalInstanceToScheduledChunk[logicalInstance] = instance; - } - continue; - } - - logicalInstances.push_back(instance); - ++logicalSlot; - } - } - - return success(); -} - -LogicalResult buildMaterializationClassesFromScheduleEquivalence(MaterializerState& state) { - DenseSet usedCpus; - for (const auto& entry : state.schedule.cpuToLastComputeMap) - usedCpus.insert(entry.first); - for (const auto& entry : state.schedule.computeToCpuMap) - usedCpus.insert(entry.second); - - CpuUnionFind unionFind; - for (CpuId cpu : usedCpus) - unionFind.insert(cpu); - - for (const auto& [cpu, equivalentCpus] : state.schedule.equivalentClass) { - if (!usedCpus.contains(cpu)) - continue; - for (CpuId equivalentCpu : equivalentCpus) - if (usedCpus.contains(equivalentCpu)) - unionFind.unite(cpu, equivalentCpu); - } - - DenseMap> groupsByRoot; - for (CpuId cpu : usedCpus) - groupsByRoot[unionFind.find(cpu)].push_back(cpu); - - SmallVector roots; - roots.reserve(groupsByRoot.size()); - for (const auto& entry : groupsByRoot) - roots.push_back(entry.first); - llvm::sort(roots); - - state.classes.reserve(roots.size()); - for (CpuId root : roots) { - MaterializedClass materializedClass; - materializedClass.id = state.classes.size(); - materializedClass.cpus = groupsByRoot.lookup(root); - llvm::sort(materializedClass.cpus); - materializedClass.isBatch = materializedClass.cpus.size() > 1; - for (auto [lane, cpu] : llvm::enumerate(materializedClass.cpus)) { - materializedClass.cpuToLane[cpu] = static_cast(lane); - state.cpuToClass[cpu] = materializedClass.id; - } - state.classes.push_back(std::move(materializedClass)); - } - - return success(); -} - -LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState& state) { - for (const MaterializedClass& materializedClass : state.classes) { - if (materializedClass.cpus.empty()) - continue; - - auto referenceIt = state.logicalInstancesByCpu.find(materializedClass.cpus.front()); - if (referenceIt == state.logicalInstancesByCpu.end()) - return state.func.emitError("missing logical stream for materialized class reference CPU"); - - ArrayRef referenceStream(referenceIt->second); - for (CpuId cpu : materializedClass.cpus) { - auto streamIt = state.logicalInstancesByCpu.find(cpu); - if (streamIt == state.logicalInstancesByCpu.end()) - return state.func.emitError("missing logical stream for materialized class CPU"); - - ArrayRef stream(streamIt->second); - if (stream.size() != referenceStream.size()) - return state.func.emitError("materialized class CPUs have mismatched logical stream lengths"); - - for (auto [slot, zipped] : llvm::enumerate(llvm::zip(referenceStream, stream))) { - const ComputeInstance& referenceInstance = std::get<0>(zipped); - const ComputeInstance& currentInstance = std::get<1>(zipped); - if (referenceInstance.op != currentInstance.op) - return state.func.emitError("materialized class logical slot source op mismatch"); - if (isa(referenceInstance.op) != isa(currentInstance.op)) - return state.func.emitError("materialized class logical slot batch/scalar mismatch"); - (void) slot; - } - } - } - - return success(); -} - -LogicalResult forEachLogicalConsumerInMaterializationOrder( - MaterializerState& state, - llvm::function_ref - callback) { - for (const ComputeInstance& scheduledInstance : state.schedule.dominanceOrderCompute) { - auto cpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); - if (cpuIt == state.schedule.computeToCpuMap.end()) - return scheduledInstance.op->emitError("missing CPU assignment for scheduled logical-slot iteration"); - - auto rangeIt = state.scheduledInstanceToLogicalSlots.find(scheduledInstance); - if (rangeIt == state.scheduledInstanceToLogicalSlots.end()) - return scheduledInstance.op->emitError("missing logical slot range for scheduled logical-slot iteration"); - - CpuId cpu = cpuIt->second; - ClassId classId = state.cpuToClass.lookup(cpu); - LogicalSlotRange range = rangeIt->second; - auto streamIt = state.logicalInstancesByCpu.find(cpu); - if (streamIt == state.logicalInstancesByCpu.end()) - return scheduledInstance.op->emitError("missing logical stream for CPU"); - for (SlotId logicalSlot = range.start; logicalSlot < range.start + range.count; ++logicalSlot) { - if (logicalSlot >= streamIt->second.size()) - return scheduledInstance.op->emitError("missing logical slot materialization instance"); - if (failed(callback(cpu, classId, scheduledInstance, streamIt->second[logicalSlot], logicalSlot))) - return failure(); - } - } - - return success(); -} - -bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps); - -LogicalResult collectHostOutputs(MaterializerState& state) { - DenseSet seenOutputs; - SmallVector orderedOutputs; - DenseMap preferredOwners; - - for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { - auto cpuIt = state.schedule.computeToCpuMap.find(instance); - if (cpuIt == state.schedule.computeToCpuMap.end()) - return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); - - ClassId classId = state.cpuToClass.lookup(cpuIt->second); - MaterializedClass& materializedClass = state.classes[classId]; - for (Value output : getComputeInstanceOutputValuesCached(state, instance)) { - if (!hasLiveExternalUseCached(state, output)) - continue; - - if (seenOutputs.insert(output).second) { - orderedOutputs.push_back(output); - preferredOwners[output] = classId; - continue; - } - - auto batch = dyn_cast_or_null(output.getDefiningOp()); - if (!batch || batch.getNumResults() == 0) - continue; - - ClassId currentOwner = preferredOwners.lookup(output); - bool terminalHost = isTerminalHostBatchOutput(output, state.oldComputeOps); - if (terminalHost) { - // Terminal resultful batch outputs are still published through scalar - // host-output slots unless the materialized batch class owns the output - // directly. Selecting an arbitrary batch class as the host owner would - // require a projection-aware batch publication path, which the - // materializer does not currently implement. - if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) - preferredOwners[output] = classId; - continue; - } - - if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) - preferredOwners[output] = classId; - } - } - - for (MaterializedClass& materializedClass : state.classes) { - materializedClass.hostOutputs.clear(); - materializedClass.hostOutputToResultIndex.clear(); - } - state.hostOutputOwners.clear(); - - for (Value output : orderedOutputs) { - ClassId ownerClassId = preferredOwners.lookup(output); - MaterializedClass& ownerClass = state.classes[ownerClassId]; - ownerClass.hostOutputToResultIndex[output] = ownerClass.hostOutputs.size(); - ownerClass.hostOutputs.push_back(output); - state.hostOutputOwners[output] = ownerClassId; - } - - return success(); -} - -LogicalResult createEmptyMaterializedOps(MaterializerState& state) { - Location loc = state.func.getLoc(); - Block& funcBlock = state.func.getBody().front(); - - Operation* firstOldCompute = nullptr; - for (Operation& op : funcBlock) { - if (state.oldComputeOps.contains(&op)) { - firstOldCompute = &op; - break; - } - } - - if (firstOldCompute) - state.rewriter.setInsertionPoint(firstOldCompute); - else - state.rewriter.setInsertionPointToStart(&funcBlock); - - for (MaterializedClass& materializedClass : state.classes) { - SmallVector resultTypes; - resultTypes.reserve(materializedClass.hostOutputs.size()); - for (Value output : materializedClass.hostOutputs) - resultTypes.push_back(output.getType()); - - if (!materializedClass.isBatch) { - auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); - compute.getProperties().setOperandSegmentSizes({0, 0}); - auto coreIdAttr = - pim::getCheckedI32Attr(state.rewriter, state.func, materializedClass.cpus.front(), "materialized core id"); - if (failed(coreIdAttr)) - return failure(); - compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr); - Block* body = state.rewriter.createBlock(&compute.getBody()); - state.rewriter.setInsertionPointToEnd(body); - SmallVector placeholderOutputs; - placeholderOutputs.reserve(resultTypes.size()); - for (Type resultType : resultTypes) { - auto tensorType = dyn_cast(resultType); - if (!tensorType || !tensorType.hasStaticShape()) { - compute.emitOpError("host-facing materialized compute results must be static ranked tensors"); - return failure(); - } - placeholderOutputs.push_back( - tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult()); - } - SpatYieldOp::create(state.rewriter, loc, ValueRange(placeholderOutputs)); - materializedClass.op = compute.getOperation(); - materializedClass.body = body; - state.rewriter.setInsertionPointAfter(compute.getOperation()); - continue; - } - - auto batchLaneCountAttr = pim::getCheckedI32Attr( - state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count"); - if (failed(batchLaneCountAttr)) - return failure(); - auto batch = SpatScheduledComputeBatch::create( - state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); - batch.getProperties().setOperandSegmentSizes({0, 0}); - auto coreIds = getCheckedCoreIds(state.func, materializedClass.cpus, "materialized batch core id"); - if (failed(coreIds)) - return failure(); - batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(*coreIds)); - - SmallVector blockArgTypes {state.rewriter.getIndexType()}; - SmallVector blockArgLocs {loc}; - llvm::append_range(blockArgTypes, resultTypes); - blockArgLocs.append(resultTypes.size(), loc); - Block* body = - state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - state.rewriter.setInsertionPointToEnd(body); - if (resultTypes.empty()) - SpatYieldOp::create(state.rewriter, loc, ValueRange {}); - else - SpatInParallelOp::create(state.rewriter, loc); - materializedClass.op = batch.getOperation(); - materializedClass.body = body; - state.rewriter.setInsertionPointAfter(batch.getOperation()); - } - - return success(); -} - -BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { - auto it = materializedClass.weightArgs.find(weight); - if (it != materializedClass.weightArgs.end()) - return it->second; - - unsigned weightIndex = materializedClass.weights.size(); - materializedClass.weights.push_back(weight); - - if (auto compute = dyn_cast(materializedClass.op)) { - auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc()); - assert(arg && "expected compute body while inserting a weight"); - materializedClass.weightArgs[weight] = std::get<1>(*arg); - return std::get<1>(*arg); - } - - auto batch = cast(materializedClass.op); - auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc()); - assert(arg && "expected compute_batch body while inserting a weight argument"); - materializedClass.weightArgs[weight] = std::get<1>(*arg); - return std::get<1>(*arg); -} - -BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) { - auto it = materializedClass.inputArgs.find(input); - if (it != materializedClass.inputArgs.end()) - return it->second; - - materializedClass.inputs.push_back(input); - if (auto compute = dyn_cast(materializedClass.op)) { - auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); - assert(arg && "expected compute body while inserting an input"); - materializedClass.inputArgs[input] = std::get<1>(*arg); - return std::get<1>(*arg); - } - if (auto compute = dyn_cast(materializedClass.op)) { - auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); - assert(arg && "expected compute_batch body while inserting an input argument"); - materializedClass.inputArgs[input] = std::get<1>(*arg); - return std::get<1>(*arg); - } - llvm_unreachable("Cannot reach here"); -} - -// ----------------------------------------------------------------------------- -// Materialized-class value localization helpers. -// ----------------------------------------------------------------------------- - -Region* getParentRegion(Value value) { - if (auto blockArg = dyn_cast(value)) - return blockArg.getOwner()->getParent(); - if (Operation* definingOp = value.getDefiningOp()) - return definingOp->getParentRegion(); - return nullptr; -} - -bool isDefinedInsideRegion(Value value, Region& region) { - Region* parentRegion = getParentRegion(value); - return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); -} - -Operation* getEnclosingSpatialComputeLikeOp(Value value) { - Block* block = nullptr; - if (auto blockArg = dyn_cast(value)) - block = blockArg.getOwner(); - else if (Operation* definingOp = value.getDefiningOp()) - block = definingOp->getBlock(); - - if (!block) - return nullptr; - - for (Operation* current = block->getParentOp(); current; current = current->getParentOp()) - if (isa(current)) - return current; - return nullptr; -} - -bool isTensorValueLocalToMaterializedClass(Value value, const MaterializedClass& targetClass) { - if (!isa(value.getType())) - return true; - if (isConstantLike(value)) - return true; - - Region& targetRegion = *targetClass.body->getParent(); - return isDefinedInsideRegion(value, targetRegion); -} - -bool isTensorValueDefinedInDifferentMaterializedClass(Value value, const MaterializedClass& targetClass) { - if (!isa(value.getType()) || isTensorValueLocalToMaterializedClass(value, targetClass)) - return false; - - Operation* owner = getEnclosingSpatialComputeLikeOp(value); - return owner && owner != targetClass.op; -} - -std::optional getRegionIndexInParentOp(Region* region) { - Operation* parent = region ? region->getParentOp() : nullptr; - if (!parent) - return std::nullopt; - - for (auto [index, candidate] : llvm::enumerate(parent->getRegions())) - if (&candidate == region) - return static_cast(index); - return std::nullopt; -} - -std::optional getBlockIndexInRegion(Block* block) { - Region* region = block ? block->getParent() : nullptr; - if (!region) - return std::nullopt; - - for (auto [index, candidate] : llvm::enumerate(region->getBlocks())) - if (&candidate == block) - return static_cast(index); - return std::nullopt; -} - -Block* getBlockByIndex(Region& region, unsigned blockIndex) { - unsigned index = 0; - for (Block& block : region) { - if (index == blockIndex) - return █ - ++index; - } - return nullptr; -} - -static bool isValueLegalInMaterializedClassBody(Value value, const MaterializedClass& targetClass) { - if (isConstantLike(value)) - return true; - - Region& targetRegion = *targetClass.body->getParent(); - return isDefinedInsideRegion(value, targetRegion); -} - -std::string stringifyOperationForMaterializerDebug(Operation* op) { - if (!op) - return std::string(""); - std::string storage; - llvm::raw_string_ostream stream(storage); - op->print(stream); - return storage; -} - -std::string stringifyValueForMaterializerDebug(Value value) { - std::string storage; - llvm::raw_string_ostream stream(storage); - value.print(stream); - return storage; -} - -std::string truncateMaterializerDebugString(std::string text, size_t limit = 1200) { - for (char& ch : text) - if (ch == '\n' || ch == '\r' || ch == '\t') - ch = ' '; - - if (text.size() <= limit) - return text; - text.resize(limit); - text += "..."; - return text; -} - -std::string formatMaterializerOperandListInline(Operation* op, const MaterializedClass& targetClass) { - if (!op) - return std::string(""); - - std::string storage; - llvm::raw_string_ostream stream(storage); - for (OpOperand& operand : op->getOpOperands()) { - if (operand.getOperandNumber() != 0) - stream << " | "; - Value value = operand.get(); - stream << "operand#" << operand.getOperandNumber() << " type=" << value.getType() - << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) - << " value=" << stringifyValueForMaterializerDebug(value); - if (auto blockArg = dyn_cast(value)) { - stream << " blockArg#" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - stream << " ownerOp='" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { - stream << " definingOp='" << definingOp->getName() << "'"; - } - } - return truncateMaterializerDebugString(stream.str()); -} - -std::string formatMaterializerParentChainInline(Operation* op) { - if (!op) - return std::string(""); - - std::string storage; - llvm::raw_string_ostream stream(storage); - unsigned depth = 0; - for (Operation* current = op; current; current = current->getParentOp()) { - if (depth != 0) - stream << " <- "; - stream << "[" << depth++ << "]" << current->getName(); - } - return truncateMaterializerDebugString(stream.str()); -} - -void attachMaterializerOperationPrintNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { - if (!op) - return; - diagnostic.attachNote(op->getLoc()) << label << ":\n" << stringifyOperationForMaterializerDebug(op); -} - -void attachMaterializerParentChainNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { - if (!op) - return; - - std::string storage; - llvm::raw_string_ostream stream(storage); - unsigned depth = 0; - for (Operation* current = op; current; current = current->getParentOp()) - stream << " [" << depth++ << "] " << current->getName() << "\n"; - - diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); -} - -void attachMaterializerOperandListNote(InFlightDiagnostic& diagnostic, - Operation* op, - const MaterializedClass& targetClass, - StringRef label) { - if (!op) - return; - - std::string storage; - llvm::raw_string_ostream stream(storage); - for (OpOperand& operand : op->getOpOperands()) { - Value value = operand.get(); - stream << " operand#" << operand.getOperandNumber() << " type=" << value.getType() - << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) - << " value=" << stringifyValueForMaterializerDebug(value); - if (auto blockArg = dyn_cast(value)) { - stream << " blockArg#" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - stream << " ownerOp='" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { - stream << " definingOp='" << definingOp->getName() << "'"; - } - stream << "\n"; - } - - diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); -} - -void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value value, StringRef label) { - if (auto blockArg = dyn_cast(value)) { - if (Operation* owner = blockArg.getOwner()->getParentOp()) - diagnostic.attachNote(owner->getLoc()) - << label << " is block argument #" << blockArg.getArgNumber() << " of '" << owner->getName() - << "' with type " << blockArg.getType(); - else - diagnostic.attachNote(UnknownLoc::get(value.getContext())) - << label << " is a top-level block argument #" << blockArg.getArgNumber() - << " with type " << blockArg.getType(); - return; - } - - if (Operation* definingOp = value.getDefiningOp()) { - diagnostic.attachNote(definingOp->getLoc()) - << label << " is defined by '" << definingOp->getName() << "' with result type " << value.getType(); - return; - } - - diagnostic.attachNote(UnknownLoc::get(value.getContext())) - << label << " has no defining operation and is not a block argument, type " << value.getType(); -} - -void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) { - Block& body = *targetClass.body; - diagnostic.attachNote(targetClass.op->getLoc()) - << "RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName() - << "' body has " << body.getNumArguments() << " block arguments and " - << std::distance(body.begin(), body.end()) << " top-level operations"; -} - -FailureOr rematerializeIndexValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Location loc, - IRMapping* mapper = nullptr); - -FailureOr rematerializeIndexOpFoldResultInClass(MaterializerState& state, - MaterializedClass& targetClass, - OpFoldResult value, - Location loc, - IRMapping* mapper = nullptr) { - if (auto attr = dyn_cast(value)) - return OpFoldResult(attr); - - FailureOr rematerialized = rematerializeIndexValueInClass(state, targetClass, cast(value), loc, mapper); - if (failed(rematerialized)) - return failure(); - return OpFoldResult(*rematerialized); -} - -FailureOr rematerializeIndexValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Location loc, - IRMapping* mapper) { - Value originalValue = value; - bool mapperHadOriginalValue = false; - Value mappedOriginalValue; - - if (mapper && mapper->contains(value)) { - mapperHadOriginalValue = true; - Value mapped = mapper->lookup(value); - mappedOriginalValue = mapped; - if (isValueLegalInMaterializedClassBody(mapped, targetClass) || isConstantLike(mapped)) - return mapped; - value = mapped; - } - - if (isValueLegalInMaterializedClassBody(value, targetClass)) - return value; - - if (!value.getType().isIndex()) - return targetClass.op->emitError("cannot rematerialize non-index external value in materialized class body") - << " type=" << value.getType(); - - if (auto constantIndex = value.getDefiningOp()) - return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantIndex.value()); - - APInt constantValue; - if (matchPattern(value, m_ConstantInt(&constantValue))) { - if (!constantValue.isSignedIntN(64)) - return targetClass.op->emitError("cannot rematerialize out-of-range index constant") - << " value=" << llvm::toString(constantValue, 10, /*Signed=*/true); - return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantValue.getSExtValue()); - } - - if (auto affineApply = value.getDefiningOp()) { - SmallVector remappedOperands; - remappedOperands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, operand, loc, mapper); - if (failed(remapped)) - return failure(); - remappedOperands.push_back(*remapped); - } - return createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func); - } - - if (auto addOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, addOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, addOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::AddIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto subOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, subOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, subOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::SubIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto mulOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::MulIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto divOp = value.getDefiningOp()) { - FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, divOp.getLhs(), loc, mapper); - FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, divOp.getRhs(), loc, mapper); - if (failed(lhs) || failed(rhs)) - return failure(); - return arith::DivUIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); - } - - if (auto extractOp = value.getDefiningOp()) { - SmallVector remappedIndices; - remappedIndices.reserve(extractOp.getIndices().size()); - for (Value index : extractOp.getIndices()) { - FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, index, loc, mapper); - if (failed(remapped)) - return failure(); - remappedIndices.push_back(*remapped); - } - - Value tensor = extractOp.getTensor(); - if (!isConstantLike(tensor) && !isValueLegalInMaterializedClassBody(tensor, targetClass)) - return targetClass.op->emitError("cannot rematerialize indexed table lookup from external non-constant tensor") - << " tensorType=" << tensor.getType(); - return tensor::ExtractOp::create(state.rewriter, loc, tensor, remappedIndices).getResult(); - } - - if (auto blockArg = dyn_cast(value)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body"); - diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType() - << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; - if (Operation* owner = blockArg.getOwner()->getParentOp()) { - diagnostic << " ownerOp='" << owner->getName() << "'"; - diagnostic << " ownerIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(owner)) << "\""; - diagnostic << " ownerChain=\"" << formatMaterializerParentChainInline(owner) << "\""; - } - diagnostic << " targetIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(targetClass.op)) << "\""; - if (mapper) { - diagnostic << " mapperPresent=1 mapperHadOriginal=" << (mapperHadOriginalValue ? 1 : 0); - if (mapperHadOriginalValue) - diagnostic << " mappedType=" << mappedOriginalValue.getType(); - } else { - diagnostic << " mapperPresent=0"; - } - attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); - if (value != originalValue) - attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); - if (mapperHadOriginalValue && mappedOriginalValue != value) - attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value"); - if (Operation* owner = blockArg.getOwner()->getParentOp()) { - attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op"); - attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain"); - } - attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return failure(); - } - - InFlightDiagnostic diagnostic = - targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body"); - diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='" - << targetClass.op->getName() << "'"; - attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); - if (value != originalValue) - attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return failure(); -} - -InFlightDiagnostic emitNonLocalMaterializedClassValueDiagnostic(Operation* anchor, - const MaterializedClass& targetClass, - StringRef context, - Value value, - std::optional producer = std::nullopt) { - InFlightDiagnostic diagnostic = anchor->emitError(context) << " into target class " << targetClass.id; - - if (producer) { - diagnostic << " from '" << producer->instance.op->getName() << "' resultIndex=" << producer->resultIndex - << " laneStart=" << producer->instance.laneStart << " laneCount=" << producer->instance.laneCount; - } else if (auto result = dyn_cast(value)) { - diagnostic << " from '" << result.getOwner()->getName() << "' resultIndex=" << result.getResultNumber(); - } else if (auto blockArg = dyn_cast(value)) { - diagnostic << " from block argument #" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - diagnostic << " of '" << owner->getName() << "'"; - } - - if (Operation* definingOp = value.getDefiningOp()) - diagnostic.attachNote(definingOp->getLoc()) << "offending tensor producer is '" << definingOp->getName() << "'"; - return diagnostic; -} - -FailureOr rematerializeTensorValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - IRMapping* mapper) { - auto extractSlice = value.getDefiningOp(); - if (extractSlice) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, targetClass, extractSlice.getSource(), anchor, context, std::nullopt, mapper); - if (failed(localizedSource)) - return failure(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(extractSlice.getMixedOffsets().size()); - sizes.reserve(extractSlice.getMixedSizes().size()); - strides.reserve(extractSlice.getMixedStrides().size()); - - for (OpFoldResult offset : extractSlice.getMixedOffsets()) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, offset, anchor->getLoc(), mapper); - if (failed(localized)) - return failure(); - offsets.push_back(*localized); - } - for (OpFoldResult size : extractSlice.getMixedSizes()) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, size, anchor->getLoc(), mapper); - if (failed(localized)) - return failure(); - sizes.push_back(*localized); - } - for (OpFoldResult stride : extractSlice.getMixedStrides()) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, stride, anchor->getLoc(), mapper); - if (failed(localized)) - return failure(); - strides.push_back(*localized); - } - - return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides) - .getResult(); - } - - if (auto collapseShape = value.getDefiningOp()) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, targetClass, collapseShape.getSrc(), anchor, context, std::nullopt, mapper); - if (failed(localizedSource)) - return failure(); - return tensor::CollapseShapeOp::create( - state.rewriter, anchor->getLoc(), *localizedSource, collapseShape.getReassociationIndices()) - .getResult(); - } - - return failure(); -} - -FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef context, - std::optional producer, - IRMapping* mapper) { - if (mapper && mapper->contains(value)) - value = mapper->lookup(value); - - if (!isa(value.getType()) || isConstantLike(value) || isTensorValueLocalToMaterializedClass(value, targetClass)) - return value; - - if (value.getDefiningOp() || value.getDefiningOp()) { - FailureOr rematerialized = rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); - if (failed(rematerialized)) - return failure(); - return *rematerialized; - } - - if (isTensorValueDefinedInDifferentMaterializedClass(value, targetClass)) { - emitNonLocalMaterializedClassValueDiagnostic(anchor, targetClass, context, value, producer); - return failure(); - } - - return appendInput(state, targetClass, value); -} - -std::optional mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass, - Operation* anchor, - BlockArgument externalArg) { - Block* sourceBlock = externalArg.getOwner(); - Region* sourceRegion = sourceBlock ? sourceBlock->getParent() : nullptr; - Operation* sourceParent = sourceRegion ? sourceRegion->getParentOp() : nullptr; - if (!sourceParent || !anchor) - return std::nullopt; - - std::optional sourceRegionIndex = getRegionIndexInParentOp(sourceRegion); - std::optional sourceBlockIndex = getBlockIndexInRegion(sourceBlock); - if (!sourceRegionIndex || !sourceBlockIndex) - return std::nullopt; - - for (Operation* current = anchor->getParentOp(); current && current != targetClass.op; - current = current->getParentOp()) { - if (current->getName() != sourceParent->getName()) - continue; - if (current->getNumRegions() <= *sourceRegionIndex) - continue; - - Region& localRegion = current->getRegion(*sourceRegionIndex); - Block* localBlock = getBlockByIndex(localRegion, *sourceBlockIndex); - if (!localBlock || localBlock->getNumArguments() <= externalArg.getArgNumber()) - continue; - - BlockArgument localArg = localBlock->getArgument(externalArg.getArgNumber()); - if (localArg.getType() != externalArg.getType()) - continue; - if (!isValueLegalInMaterializedClassBody(localArg, targetClass)) - continue; - return localArg; - } - - return std::nullopt; -} - -FailureOr localizeMaterializedClassOperand(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Operation* anchor, - StringRef tensorContext, - StringRef genericContext, - IRMapping* mapper) { - if (mapper && mapper->contains(value)) - value = mapper->lookup(value); - - if (auto blockArg = dyn_cast(value)) - if (std::optional localArg = mapExternalRegionBlockArgumentToLocalClone(targetClass, anchor, blockArg)) - return *localArg; - - if (isa(value.getType())) - return materializeTensorValueForMaterializedClassUse(state, targetClass, value, anchor, tensorContext, std::nullopt, mapper); - - if (isValueLegalInMaterializedClassBody(value, targetClass)) - return value; - - if (value.getType().isIndex()) - return rematerializeIndexValueInClass(state, targetClass, value, anchor->getLoc(), mapper); - - InFlightDiagnostic diagnostic = anchor->emitError(genericContext); - diagnostic << " type=" << value.getType(); - if (auto blockArg = dyn_cast(value)) { - diagnostic << " blockArg#" << blockArg.getArgNumber(); - if (Operation* owner = blockArg.getOwner()->getParentOp()) - diagnostic.attachNote(owner->getLoc()) << "block argument belongs to '" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { - diagnostic.attachNote(definingOp->getLoc()) << "unsupported external operand producer is '" << definingOp->getName() - << "'"; - } - return failure(); -} - -// ----------------------------------------------------------------------------- -// Tensor packing helpers. -// ----------------------------------------------------------------------------- - -struct Dim0SliceParams { - SmallVector offsets; - SmallVector sizes; - SmallVector strides; -}; - -Dim0SliceParams -buildDim0SliceParams(OpBuilder& builder, RankedTensorType referenceType, OpFoldResult firstOffset, int64_t firstSize) { - Dim0SliceParams params; - params.offsets.reserve(referenceType.getRank()); - params.sizes.reserve(referenceType.getRank()); - params.strides.reserve(referenceType.getRank()); - - params.offsets.push_back(firstOffset); - params.sizes.push_back(builder.getIndexAttr(firstSize)); - params.strides.push_back(builder.getIndexAttr(1)); - - for (int64_t dim = 1; dim < referenceType.getRank(); ++dim) { - params.offsets.push_back(builder.getIndexAttr(0)); - params.sizes.push_back(builder.getIndexAttr(referenceType.getDimSize(dim))); - params.strides.push_back(builder.getIndexAttr(1)); - } - - return params; -} - -Value createDim0ExtractSlice( - MaterializerState& state, Location loc, Value source, OpFoldResult firstOffset, int64_t firstSize) { - auto sourceType = cast(source.getType()); - Dim0SliceParams params = buildDim0SliceParams(state.rewriter, sourceType, firstOffset, firstSize); - return tensor::ExtractSliceOp::create(state.rewriter, loc, source, params.offsets, params.sizes, params.strides) - .getResult(); -} - -FailureOr createDim0ExtractSliceInClass(MaterializerState& state, - MaterializedClass& targetClass, - Location loc, - Value source, - OpFoldResult firstOffset, - int64_t firstSize) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - source, - targetClass.op, - "createDim0ExtractSliceInClass tried to reuse a tensor from another materialized class"); - if (failed(localizedSource)) - return failure(); - FailureOr localizedOffset = - rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); - if (failed(localizedOffset)) - return failure(); - return createDim0ExtractSlice(state, loc, *localizedSource, *localizedOffset, firstSize); -} - -Value createStaticExtractSlice(MaterializerState& state, - Location loc, - Value source, - ArrayRef sliceOffsets, - ArrayRef resultShape) { - auto sourceType = cast(source.getType()); - assert(sliceOffsets.size() == static_cast(sourceType.getRank()) && "offset rank mismatch"); - assert(resultShape.size() == static_cast(sourceType.getRank()) && "result rank mismatch"); - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(sourceType.getRank()); - sizes.reserve(sourceType.getRank()); - strides.reserve(sourceType.getRank()); - - for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { - offsets.push_back(sliceOffsets[dim]); - sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim])); - strides.push_back(state.rewriter.getIndexAttr(1)); - } - - return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult(); -} - -FailureOr createStaticExtractSliceInClass(MaterializerState& state, - MaterializedClass& targetClass, - Location loc, - Value source, - ArrayRef sliceOffsets, - ArrayRef resultShape) { - FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - source, - targetClass.op, - "createStaticExtractSliceInClass tried to reuse a tensor from another materialized class"); - if (failed(localizedSource)) - return failure(); - - SmallVector localizedOffsets; - localizedOffsets.reserve(sliceOffsets.size()); - for (OpFoldResult offset : sliceOffsets) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, offset, loc); - if (failed(localized)) - return failure(); - localizedOffsets.push_back(*localized); - } - return createStaticExtractSlice(state, loc, *localizedSource, localizedOffsets, resultShape); -} - -Value createIndexedIndexValue(MaterializerState& state, - Operation* anchor, - ArrayRef values, - Value index, - Location loc, - std::optional preferredPeriod = std::nullopt, - bool allowExhaustiveTiledSearch = true); - -FailureOr> buildProjectedFragmentOffsetsInClass(MaterializerState& state, - MaterializedClass& targetClass, - const ProjectedTransferDescriptor& descriptor, - Value flatFragmentIndex, - Location loc) { - FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, flatFragmentIndex, loc); - if (failed(localizedIndex)) - return failure(); - SmallVector fragmentOffsets; - fragmentOffsets.reserve(descriptor.layout.fragmentShape.size()); - for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) - fragmentOffsets.push_back(createIndexedIndexValue(state, - targetClass.op, - dimOffsets, - *localizedIndex, - loc, - static_cast(descriptor.layout.payloadFragmentCount), - /*allowExhaustiveTiledSearch=*/false)); - return fragmentOffsets; -} - -Value createDim0InsertSlice( - MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { - auto fragmentType = cast(fragment.getType()); - Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); - return tensor::InsertSliceOp::create( - state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides) - .getResult(); -} - -FailureOr createDim0InsertSliceInClass(MaterializerState& state, - MaterializedClass& targetClass, - Location loc, - Value fragment, - Value destination, - OpFoldResult firstOffset) { - FailureOr localizedFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fragment, - targetClass.op, - "createDim0InsertSliceInClass tried to reuse a fragment tensor from another materialized class"); - if (failed(localizedFragment)) - return failure(); - FailureOr localizedDestination = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - destination, - targetClass.op, - "createDim0InsertSliceInClass tried to reuse a destination tensor from another materialized class"); - if (failed(localizedDestination)) - return failure(); - FailureOr localizedOffset = - rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); - if (failed(localizedOffset)) - return failure(); - return createDim0InsertSlice(state, loc, *localizedFragment, *localizedDestination, *localizedOffset); -} - -void createDim0ParallelInsertSlice( - MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { - auto fragmentType = cast(fragment.getType()); - Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); - tensor::ParallelInsertSliceOp::create( - state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides); -} - -Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc) { - if (dim0Size == 1) - return index; - - Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size); - return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); -} - -FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value index, - int64_t dim0Size, - Location loc) { - FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, index, loc); - if (failed(localizedIndex)) - return failure(); - if (dim0Size == 1) - return *localizedIndex; - - Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size); - return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult(); -} - -bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) { - return lhs.instance.op == rhs.instance.op && lhs.resultIndex == rhs.resultIndex; -} - -bool containsProducerKey(ProducerKey outer, ProducerKey inner) { - if (!sameProducerResult(outer, inner)) - return false; - if (!isa(outer.instance.op)) - return false; - if (outer.instance.laneCount == 0 || inner.instance.laneCount == 0) - return false; - - uint32_t outerStart = outer.instance.laneStart; - uint32_t outerEnd = outerStart + outer.instance.laneCount; - uint32_t innerStart = inner.instance.laneStart; - uint32_t innerEnd = innerStart + inner.instance.laneCount; - - return outerStart <= innerStart && innerEnd <= outerEnd; -} - -std::optional extractPackedProducerSlice(MaterializerState& state, - MaterializedClass& materializedClass, - ProducerKey packedKey, - Value packed, - ProducerKey requestedKey) { - if (!containsProducerKey(packedKey, requestedKey)) - return std::nullopt; - - auto packedType = dyn_cast(packed.getType()); - if (!packedType || !packedType.hasStaticShape() || packedType.getRank() == 0) - return std::nullopt; - - if (packedKey.instance.laneCount == 0) - return std::nullopt; - - int64_t packedRows = packedType.getDimSize(0); - if (packedRows % static_cast(packedKey.instance.laneCount) != 0) - return std::nullopt; - - int64_t rowsPerLane = packedRows / static_cast(packedKey.instance.laneCount); - int64_t rowOffset = - static_cast(requestedKey.instance.laneStart - packedKey.instance.laneStart) * rowsPerLane; - int64_t rowCount = static_cast(requestedKey.instance.laneCount) * rowsPerLane; - - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); - - Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); - return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); -} - -std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { - auto producerIt = exactValues.find(key); - if (producerIt == exactValues.end()) - return std::nullopt; - - auto valueIt = producerIt->second.find(classId); - if (valueIt == producerIt->second.end()) - return std::nullopt; - - return valueIt->second; -} - -Value getPackedSliceForRunIndex(MaterializerState& state, - Operation* anchor, - Value packed, - RankedTensorType fragmentType, - size_t index, - Location loc) { - int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); - Value firstOffset = getOrCreateIndexConstant(state.constantFolder, anchor, rowOffset); - return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); -} - -FailureOr createReceiveConcatLoop(MaterializerState& state, - MaterializedClass& targetClass, - RankedTensorType concatType, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc); - -using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; -using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; - -FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc); - -bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { - return run.kind == PackedScalarRunKind::DeferredLocalCompute; -} - -size_t getPackedScalarRunReceiveCount(const PackedScalarRunValue& run) { - size_t count = 0; - for (const PackedScalarRunSlot& slot : run.slots) - count += slot.keys.size(); - return count; -} - -LogicalResult validatePackedScalarRunMetadata(Operation* anchor, const PackedScalarRunValue& run) { - if (run.kind == PackedScalarRunKind::DeferredLocalCompute) - return success(); - - size_t receiveCount = getPackedScalarRunReceiveCount(run); - - if (receiveCount == 0) - return anchor->emitError("packed scalar run has no receives"); - - if (failed(run.messages.verify(anchor))) - return failure(); - - if (run.messages.size() != receiveCount) - return anchor->emitError("packed scalar run receive metadata count is inconsistent"); - - return success(); -} - -FailureOr materializePackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc) { - if (run.packed) - return run.packed; - - if (run.kind == PackedScalarRunKind::Materialized) - return targetClass.op->emitError("materialized packed scalar run has no packed value"); - - if (isDeferredLocalPackedScalarRun(run)) - return materializeDeferredLocalPackedScalarRunValue(state, targetClass, run, loc); - - if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) - return failure(); - - FailureOr fullPackedType = - getPackedBatchTensorType(run.fragmentType, getPackedScalarRunReceiveCount(run)); - if (failed(fullPackedType)) - return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); - - auto packed = createReceiveConcatLoop(state, targetClass, *fullPackedType, run.fragmentType, run.messages, loc); - if (failed(packed)) - return failure(); - run.packed = *packed; - return run.packed; -} - -std::optional AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { - for (PackedScalarRunValue& run : packedScalarRuns) { - if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) - continue; - - for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { - std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); - auto exactKeyIt = llvm::find(slot.keys, key); - if ((!contiguousKey || !containsProducerKey(*contiguousKey, key)) && exactKeyIt == slot.keys.end()) - continue; - - FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); - if (failed(slotPackedType)) - return std::nullopt; - - MaterializedClass& materializedClass = state.classes[classId]; - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); - - FailureOr packed = - materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); - if (failed(packed)) - return std::nullopt; - - Value slotPacked = - getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); - - if (contiguousKey && *contiguousKey == key) { - record(key, classId, slotPacked); - return slotPacked; - } - - if (contiguousKey && containsProducerKey(*contiguousKey, key)) { - std::optional sliced = - extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key); - if (!sliced) - return std::nullopt; - - record(key, classId, *sliced); - return *sliced; - } - - if (exactKeyIt != slot.keys.end() && key.instance.laneCount == 1) { - size_t keyIndex = static_cast(std::distance(slot.keys.begin(), exactKeyIt)); - Value sliced = getPackedSliceForRunIndex( - state, materializedClass.op, slotPacked, run.fragmentType, keyIndex, (*packed).getLoc()); - record(key, classId, sliced); - return sliced; - } - } - } - - return std::nullopt; -} - -IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) { - for (IndexedBatchRunValue& run : indexedBatchRuns) { - if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) - continue; - for (const PackedScalarRunSlot& slot : run.slots) { - if (!llvm::is_contained(slot.keys, key)) - continue; - return &run; - } - } - return nullptr; -} - -std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { - - if (std::optional exact = lookupExact(key, classId)) { - return exact; - } - - if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) - return packedRunValue; - - MaterializedClass& materializedClass = state.classes[classId]; - - for (const auto& [candidateKey, classValues] : exactValues) { - if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key)) - continue; - - auto valueIt = classValues.find(classId); - if (valueIt == classValues.end()) - continue; - - std::optional slice = - extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key); - if (!slice) - return std::nullopt; - - record(key, classId, *slice); - return *slice; - } - return std::nullopt; -} - -Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { - SmallVector elements; - elements.reserve(values.size()); - for (int64_t value : values) - elements.push_back(APInt(64, value)); - - auto type = RankedTensorType::get({static_cast(values.size())}, state.rewriter.getIndexType()); - auto attr = DenseIntElementsAttr::get(type, elements); - return getOrCreateConstant(state.constantFolder, anchor, attr, type); -} - -bool allEqual(ArrayRef values) { - assert(!values.empty() && "expected at least one value"); - for (int64_t value : values.drop_front()) - if (value != values.front()) - return false; - return true; -} - -struct IndexedIndexPattern { - int64_t base = 0; - int64_t step = 0; - int64_t period = 1; - int64_t innerStep = 0; - int64_t outerStep = 0; - bool isTiled = false; -}; - -bool matchAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { - assert(!values.empty() && "expected at least one value"); - - pattern.base = values.front(); - pattern.step = values.size() == 1 ? 0 : values[1] - values[0]; - pattern.isTiled = false; - - for (auto [index, value] : llvm::enumerate(values)) { - int64_t expected = pattern.base + pattern.step * static_cast(index); - if (value != expected) - return false; - } - - return true; -} - -bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern, int64_t period) { - assert(!values.empty() && "expected at least one value"); - if (period < 2 || period > static_cast(values.size() / 2)) - return false; - - int64_t base = values.front(); - int64_t innerStep = values[1] - values[0]; - int64_t outerStep = values[period] - values[0]; - - for (auto [index, value] : llvm::enumerate(values)) { - int64_t i = static_cast(index); - int64_t expected = base + outerStep * (i / period) + innerStep * (i % period); - if (value != expected) - return false; - } - - pattern.base = base; - pattern.period = period; - pattern.innerStep = innerStep; - pattern.outerStep = outerStep; - pattern.isTiled = true; - return true; -} - -bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { - assert(!values.empty() && "expected at least one value"); - - for (int64_t period = 2; period <= static_cast(values.size() / 2); ++period) - if (matchTiledAffineSequence(values, pattern, period)) - return true; - - return false; -} - -std::optional getIndexedIndexPattern(ArrayRef values, - std::optional preferredPeriod = std::nullopt, - bool allowExhaustiveTiledSearch = true) { - assert(!values.empty() && "expected at least one value"); - - IndexedIndexPattern pattern; - if (matchAffineSequence(values, pattern)) - return pattern; - if (preferredPeriod && matchTiledAffineSequence(values, pattern, *preferredPeriod)) - return pattern; - if (allowExhaustiveTiledSearch && values.size() <= 256 && matchTiledAffineSequence(values, pattern)) - return pattern; - - return std::nullopt; -} - -Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) { - MLIRContext* context = state.func.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - - AffineExpr expr; - if (!pattern.isTiled) { - expr = getAffineConstantExpr(pattern.base, context) + d0 * pattern.step; - } - else { - expr = getAffineConstantExpr(pattern.base, context) + d0.floorDiv(pattern.period) * pattern.outerStep - + (d0 % pattern.period) * pattern.innerStep; - } - - AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, state.func); -} - -Value createIndexedIndexValue(MaterializerState& state, - Operation* anchor, - ArrayRef values, - Value index, - Location loc, - std::optional preferredPeriod, - bool allowExhaustiveTiledSearch) { - assert(!values.empty() && "expected at least one indexed value"); - - if (allEqual(values)) { - return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); - } - - if (std::optional pattern = - getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch)) - return createAffineIndexValue(state, *pattern, index, loc); - Value table = createIndexTensorConstant(state, anchor, values); - return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); -} - -Value createIndexedIndexValue( - MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { - assert(!values.empty() && "expected at least one indexed value"); - - SmallVector widened; - widened.reserve(values.size()); - for (int32_t value : values) - widened.push_back(value); - - return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, std::nullopt, true); -} - -OpFoldResult createIndexedOrStaticIndex(MaterializerState& state, - Operation* anchor, - ArrayRef values, - Value index, - Location loc) { - assert(!values.empty() && "expected at least one indexed value"); - if (allEqual(values)) - return state.rewriter.getIndexAttr(values.front()); - return createIndexedIndexValue(state, anchor, values, index, loc); -} - -Value createIndexedChannelId( - MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { - return createIndexedIndexValue(state, anchor, ArrayRef(messages.channelIds), index, loc); -} - -Value createIndexedChannelId(MaterializerState& state, - Operation* anchor, - const MessageVector& messages, - Value index, - Location loc, - std::optional preferredPeriod) { - return createIndexedIndexValue( - state, anchor, ArrayRef(messages.channelIds), index, loc, preferredPeriod, true); -} - -Value createIndexedSourceCoreId( - MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { - return createIndexedIndexValue(state, anchor, ArrayRef(messages.sourceCoreIds), index, loc); -} - -Value createIndexedSourceCoreId(MaterializerState& state, - Operation* anchor, - const MessageVector& messages, - Value index, - Location loc, - std::optional preferredPeriod) { - SmallVector widened(messages.sourceCoreIds.begin(), messages.sourceCoreIds.end()); - return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); -} - -Value createIndexedTargetCoreId( - MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { - return createIndexedIndexValue(state, anchor, ArrayRef(messages.targetCoreIds), index, loc); -} - -Value createIndexedTargetCoreId(MaterializerState& state, - Operation* anchor, - const MessageVector& messages, - Value index, - Location loc, - std::optional preferredPeriod) { - SmallVector widened(messages.targetCoreIds.begin(), messages.targetCoreIds.end()); - return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); -} - -Value createLaneIndexedIndexValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef values, - Location loc) { - assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); - assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); - - auto batch = cast(materializedClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected compute_batch lane argument"); - - return createIndexedIndexValue(state, materializedClass.op, values, *laneArg, loc); -} - -Value createLaneIndexedIndexValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef values, - Location loc) { - assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); - assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); - - SmallVector widened; - widened.reserve(values.size()); - for (int32_t value : values) - widened.push_back(value); - - return createLaneIndexedIndexValue(state, materializedClass, ArrayRef(widened), loc); -} - -FailureOr> -getPeerLogicalInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId logicalSlot) { - SmallVector peers; - peers.reserve(materializedClass.cpus.size()); - for (CpuId cpu : materializedClass.cpus) { - auto streamIt = state.logicalInstancesByCpu.find(cpu); - if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) - return failure(); - peers.push_back(streamIt->second[logicalSlot]); - } - return peers; -} - -Value createOriginalLaneValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef peers, - Location loc) { - assert(!peers.empty() && "expected at least one peer instance"); - if (!materializedClass.isBatch) - return getOrCreateIndexConstant(state.constantFolder, materializedClass.op, peers.front().laneStart); - - auto batch = cast(materializedClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected materialized compute_batch lane argument"); - - SmallVector laneValues; - laneValues.reserve(peers.size()); - for (const ComputeInstance& peer : peers) - laneValues.push_back(peer.laneStart); - - return createIndexedIndexValue(state, materializedClass.op, ArrayRef(laneValues), *laneArg, loc); -} - -bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) { - SmallVector worklist {value}; - DenseSet visited; - - while (!worklist.empty()) { - Value current = worklist.pop_back_val(); - if (!visited.insert(current).second) - continue; - - for (OpOperand& use : current.getUses()) { - Operation* owner = use.getOwner(); - if (isInsideOldCompute(owner, oldComputeOps)) - continue; - if (isa(owner)) { - for (Value result : owner->getResults()) - worklist.push_back(result); - continue; - } - return true; - } - } - - return false; -} - -bool hasRealComputeConsumer(Value value, const DenseSet& oldComputeOps) { - SmallVector worklist {value}; - DenseSet visited; - - while (!worklist.empty()) { - Value current = worklist.pop_back_val(); - if (!visited.insert(current).second) - continue; - - for (OpOperand& use : current.getUses()) { - Operation* owner = use.getOwner(); - if (isInsideOldCompute(owner, oldComputeOps)) - continue; - if (isa(owner)) { - for (Value result : owner->getResults()) - worklist.push_back(result); - continue; - } - if (isa(owner)) - continue; - return true; - } - } - - return false; -} - -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); - -bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps) { - auto batch = dyn_cast_or_null(output.getDefiningOp()); - if (!batch || batch.getNumResults() == 0) - return false; - if (!hasLiveExternalUse(output, oldComputeOps)) - return false; - return !hasRealComputeConsumer(output, oldComputeOps); -} - - -void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) { - SmallVector& destinations = state.producerDestClasses[key]; - if (!llvm::is_contained(destinations, classId)) - destinations.push_back(classId); -} - -void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet& oldComputeOps) { - SmallVector uses; - for (OpOperand& use : oldValue.getUses()) - uses.push_back(&use); - - for (OpOperand* use : uses) { - Operation* owner = use->getOwner(); - if (isInsideOldCompute(owner, oldComputeOps)) - continue; - use->set(replacement); - } -} - -LogicalResult collectProducerDestinations(MaterializerState& state) { - return forEachLogicalConsumerInMaterializationOrder( - state, - [&](CpuId, ClassId targetClass, ComputeInstance scheduledConsumer, ComputeInstance logicalConsumer, SlotId) - -> LogicalResult { - SmallVector consumerInputs = getComputeInstanceInputs(scheduledConsumer); - for (auto [inputIndex, input] : llvm::enumerate(consumerInputs)) { - SmallVector producerKeys; - if (auto batchConsumer = dyn_cast(logicalConsumer.op)) - producerKeys = collectProducerKeysForBatchInputDestinations( - state, batchConsumer, static_cast(inputIndex), input, logicalConsumer); - else - producerKeys = collectProducerKeysForDestinations(input, logicalConsumer); - - for (ProducerKey producerKey : producerKeys) { - ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producerKey.instance); - auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); - if (producerCpuIt == state.schedule.computeToCpuMap.end()) - return logicalConsumer.op->emitError( - "schedule materialization found an input produced by an unscheduled compute"); - - ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); - if (sourceClass == targetClass) { - SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass}; - SmallVector& bucket = state.sameClassConsumerIndex[lookupKey]; - if (!llvm::is_contained(bucket, producerKey)) - bucket.push_back(producerKey); - continue; - } - - appendDestinationClass(state, producerKey, targetClass); - } - } - - return success(); - }); -} - -bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceType, RankedTensorType fragmentType) { - if (offsets.size() != static_cast(sourceType.getRank()) - || offsets.size() != static_cast(fragmentType.getRank())) - return false; - - for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { - int64_t offset = offsets[dim]; - if (offset < 0) - return false; - - int64_t sourceDimSize = sourceType.getDimSize(dim); - int64_t fragmentDimSize = fragmentType.getDimSize(dim); - if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize) - return false; - if (offset > sourceDimSize - fragmentDimSize) - return false; - } - - return true; -} - - -bool isStaticSliceContainedIn(ArrayRef innerOffsets, - ArrayRef innerSizes, - ArrayRef outerOffsets, - ArrayRef outerSizes) { - if (innerOffsets.size() != innerSizes.size() || outerOffsets.size() != outerSizes.size() - || innerOffsets.size() != outerOffsets.size()) - return false; - - for (size_t dim = 0; dim < innerOffsets.size(); ++dim) { - if (innerSizes[dim] < 0 || outerSizes[dim] < 0) - return false; - - int64_t innerBegin = innerOffsets[dim]; - int64_t innerEnd = innerBegin + innerSizes[dim]; - int64_t outerBegin = outerOffsets[dim]; - int64_t outerEnd = outerBegin + outerSizes[dim]; - if (innerBegin < outerBegin || innerEnd > outerEnd) - return false; - } - - return true; -} - -bool areAllUnitStrides(ArrayRef strides) { - return llvm::all_of(strides, [](int64_t stride) { return stride == 1; }); -} - -static std::optional getStaticForTripCount(scf::ForOp loop) { - std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); - std::optional upperBound = matchConstantIndexValue(loop.getUpperBound()); - std::optional step = matchConstantIndexValue(loop.getStep()); - if (!lowerBound || !upperBound || !step || *step <= 0 || *upperBound < *lowerBound) - return std::nullopt; - - int64_t distance = *upperBound - *lowerBound; - return (distance + *step - 1) / *step; -} - -static SmallVector collectEnclosingStaticProjectedLoops(Operation* op) { - SmallVector loops; - SmallVector reversedLoops; - for (Operation* current = op->getParentOp(); current; current = current->getParentOp()) - if (auto loop = dyn_cast(current)) - reversedLoops.push_back(loop); - - for (scf::ForOp loop : llvm::reverse(reversedLoops)) { - std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); - std::optional step = matchConstantIndexValue(loop.getStep()); - std::optional tripCount = getStaticForTripCount(loop); - if (!lowerBound || !step || !tripCount) - return {}; - loops.push_back(StaticProjectedLoopInfo {.iv = cast(loop.getInductionVar()), - .lowerBound = *lowerBound, - .step = *step, - .tripCount = *tripCount}); - } - return loops; -} - -static bool -isProjectedOffsetValue(Value value, Value laneArg, ArrayRef loops, bool& usesDynamicBinding) { - if (value == laneArg) { - usesDynamicBinding = true; - return true; - } - - for (const StaticProjectedLoopInfo& loop : loops) { - if (value == loop.iv) { - usesDynamicBinding = true; - return true; - } - } - - if (matchPattern(value, m_Constant())) - return true; - - auto affineApply = value.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return false; - - bool nestedUsesDynamicBinding = false; - for (Value operand : affineApply.getMapOperands()) { - bool operandUsesDynamicBinding = false; - if (!isProjectedOffsetValue(operand, laneArg, loops, operandUsesDynamicBinding)) - return false; - nestedUsesDynamicBinding = nestedUsesDynamicBinding || operandUsesDynamicBinding; - } - - usesDynamicBinding = usesDynamicBinding || nestedUsesDynamicBinding; - return true; -} - -static std::optional getConstantIndex(OpFoldResult value); - -static unsigned getProjectedFragmentsPerLogicalSlot(ArrayRef loopTripCounts) { - unsigned fragmentsPerLogicalSlot = 1; - for (int64_t tripCount : loopTripCounts) { - assert(tripCount > 0 && "projected loop trip counts must be positive"); - fragmentsPerLogicalSlot *= static_cast(tripCount); - } - return fragmentsPerLogicalSlot; -} - -LogicalResult verifyProjectedFragmentLayout(Operation* anchor, const ProjectedFragmentLayout& layout) { - if (!layout.fragmentType || layout.fragmentShape.empty()) - return anchor->emitError("projected fragment layout is missing fragment type metadata"); - if (layout.fragmentShape.size() != static_cast(layout.fragmentType.getRank())) - return anchor->emitError("projected fragment layout rank does not match fragment type"); - if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0) - return anchor->emitError("projected fragment layout has an invalid fragment count"); - if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0) - return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots"); - return success(); -} - -FailureOr -getProjectedPayloadType(Operation* anchor, RankedTensorType fragmentType, unsigned payloadFragmentCount) { - if (failed( - verifyPackableFragmentType(anchor, fragmentType, payloadFragmentCount, "cannot create projected payload type"))) - return failure(); - return getPackedBatchTensorType(fragmentType, payloadFragmentCount); -} - -SmallVector, 4> -buildProjectedFragmentOffsetsByDim(ArrayRef> fragmentOffsets, size_t rank) { - SmallVector, 4> fragmentOffsetsByDim(rank); - for (ArrayRef offsets : fragmentOffsets) { - assert(offsets.size() == rank && "projected offset rank mismatch"); - for (size_t dim = 0; dim < rank; ++dim) - fragmentOffsetsByDim[dim].push_back(offsets[dim]); - } - return fragmentOffsetsByDim; -} - -LogicalResult verifyProjectedTransferDescriptor(Operation* anchor, const ProjectedTransferDescriptor& descriptor) { - if (failed(verifyProjectedFragmentLayout(anchor, descriptor.layout))) - return failure(); - if (!descriptor.payloadType) - return anchor->emitError("projected transfer descriptor is missing payload type"); - if (descriptor.fragmentOffsets.empty()) - return anchor->emitError("projected transfer descriptor expected at least one fragment offset"); - if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size()) - return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); - for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) - if (dimOffsets.size() != descriptor.fragmentOffsets.size()) - return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); - for (ArrayRef offsets : descriptor.fragmentOffsets) - if (offsets.size() != descriptor.layout.fragmentShape.size()) - return anchor->emitError("projected transfer offset rank does not match fragment rank"); - return success(); -} - -LogicalResult verifyProjectedSendDescriptor(Operation* anchor, - const ProjectedTransferDescriptor& descriptor, - const MessageVector& messages) { - if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) - return failure(); - if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size()) - return anchor->emitError("projected send descriptor metadata is inconsistent"); - return success(); -} - -LogicalResult finalizeProjectedTransferDescriptor(Operation* anchor, ProjectedTransferDescriptor& descriptor) { - descriptor.fragmentOffsetsByDim = - buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size()); - - FailureOr payloadType = - getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount); - if (failed(payloadType)) - return failure(); - if (descriptor.payloadType && descriptor.payloadType != *payloadType) - return anchor->emitError("projected transfer descriptor payload type does not match projected layout"); - descriptor.payloadType = *payloadType; - - return verifyProjectedTransferDescriptor(anchor, descriptor); -} - -static FailureOr evaluateProjectedOffsetValue(OpFoldResult value, - Value laneArg, - uint32_t lane, - ArrayRef loops, - ArrayRef loopIterationIndices) { - if (std::optional constant = getConstantIndex(value)) - return *constant; - - Value current = dyn_cast(value); - if (!current) - return failure(); - if (current == laneArg) - return static_cast(lane); - - for (auto [index, loop] : llvm::enumerate(loops)) { - if (current != loop.iv) - continue; - if (index >= loopIterationIndices.size()) - return failure(); - return loop.lowerBound + loopIterationIndices[index] * loop.step; - } - - if (auto affineApply = current.getDefiningOp()) { - return evaluateAffineApply(affineApply, [&](Value operand) { - return evaluateProjectedOffsetValue(operand, laneArg, lane, loops, loopIterationIndices); - }); - } - - return failure(); -} - -static std::optional getConstantIndex(OpFoldResult value) { - if (auto attr = dyn_cast(value)) { - auto intAttr = dyn_cast(attr); - if (!intAttr) - return std::nullopt; - return intAttr.getInt(); - } - - Value operand = dyn_cast(value); - if (!operand) - return std::nullopt; - - if (auto constantIndex = operand.getDefiningOp()) - return constantIndex.value(); - - APInt apInt; - if (matchPattern(operand, m_ConstantInt(&apInt))) { - if (apInt.isNegative()) - return std::nullopt; - return static_cast(apInt.getSExtValue()); - } - - return std::nullopt; -} - -static std::optional matchAffineProjectedInputSlice(SpatComputeBatch batch, - unsigned inputIndex) { - const auto fail = [&](StringRef) -> std::optional { return std::nullopt; }; - - std::optional inputArg = batch.getInputArgument(inputIndex); - std::optional laneArg = batch.getLaneArgument(); - if (!inputArg || !laneArg) - return fail("missing-input-or-lane-arg"); - - if (!inputArg->hasOneUse()) - return fail("input-arg-not-one-use"); - - Operation* user = *inputArg->getUsers().begin(); - auto extract = dyn_cast(user); - if (!extract || extract.getSource() != *inputArg) - return fail("input-user-is-not-direct-extract-slice"); - - auto inputType = dyn_cast(inputArg->getType()); - auto fragmentType = dyn_cast(extract.getResult().getType()); - if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape()) - return fail("non-static-ranked-input-or-fragment"); - - if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank()) - return fail("rank-mismatch-or-rank-zero"); - - SmallVector offsets = extract.getMixedOffsets(); - SmallVector sizes = extract.getMixedSizes(); - SmallVector strides = extract.getMixedStrides(); - - if (offsets.size() != static_cast(inputType.getRank()) - || sizes.size() != static_cast(inputType.getRank()) - || strides.size() != static_cast(inputType.getRank())) - return fail("slice-rank-mismatch"); - - SmallVector loops = collectEnclosingStaticProjectedLoops(extract.getOperation()); - if (extract->getParentOfType() && loops.empty()) - return fail("unsupported-enclosing-loop"); - - bool hasDynamicProjection = false; - for (auto [dim, offset] : llvm::enumerate(offsets)) { - bool usesDynamicBinding = false; - if (auto value = dyn_cast(offset)) { - if (!isProjectedOffsetValue(value, *laneArg, loops, usesDynamicBinding)) - return std::nullopt; - } - else if (!isa(offset)) - return std::nullopt; - if (std::optional stride = getConstantIndex(strides[dim]); !stride || *stride != 1) - return std::nullopt; - std::optional size = getConstantIndex(sizes[dim]); - if (!size || *size != fragmentType.getDimSize(dim)) - return std::nullopt; - hasDynamicProjection = hasDynamicProjection || usesDynamicBinding; - } - - if (!hasDynamicProjection) - return fail("no-dynamic-projection"); - - for (int64_t dim = 0; dim < inputType.getRank(); ++dim) - if (fragmentType.getDimSize(dim) <= 0 || fragmentType.getDimSize(dim) > inputType.getDimSize(dim)) - return std::nullopt; - - AffineProjectedInputSliceMatch match; - match.extract = extract; - match.sourceType = inputType; - match.fragmentType = fragmentType; - match.offsets.assign(offsets.begin(), offsets.end()); - match.fragmentShape.assign(fragmentType.getShape().begin(), fragmentType.getShape().end()); - match.loops = std::move(loops); - return match; -} - -std::optional -getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex) { - ProjectedBatchInputKey key {batch.getOperation(), inputIndex}; - auto cached = state.projectedInputMatches.find(key); - if (cached != state.projectedInputMatches.end()) - return cached->second; - if (state.nonProjectedInputs.contains(key)) - return std::nullopt; - - std::optional match = matchAffineProjectedInputSlice(batch, inputIndex); - if (!match) { - state.nonProjectedInputs.insert(key); - return std::nullopt; - } - - state.projectedInputMatches.insert({key, *match}); - return match; -} - -FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane); - -FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) { - if (value == laneArg) - return static_cast(lane); - - if (std::optional constant = matchConstantIndexValue(value)) - return *constant; - - auto affineApply = value.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return failure(); - - SmallVector operands; - operands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - FailureOr evaluated = evaluateProjectionIndexLike(operand, laneArg, lane); - if (failed(evaluated)) - return failure(); - operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *evaluated)); - } - - SmallVector results; - if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) - return failure(); - - auto intAttr = dyn_cast(results.front()); - if (!intAttr) - return failure(); - return intAttr.getInt(); -} - -FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane) { - if (auto attr = llvm::dyn_cast(value)) { - auto intAttr = dyn_cast(attr); - if (!intAttr) - return failure(); - return intAttr.getInt(); - } - return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane); -} - -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { - auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); - if (!inParallel) - return failure(); - - auto firstOutputArg = batch.getOutputArgument(0); - if (!firstOutputArg) - return failure(); - - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert) - continue; - - auto outputArg = dyn_cast(insert.getDest()); - if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) - continue; - - unsigned candidateIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); - if (candidateIndex == resultIndex) - return insert; - } - - return failure(); -} - -FailureOr> -evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, uint32_t lane) { - SmallVector evaluated; - evaluated.reserve(values.size()); - for (OpFoldResult value : values) { - FailureOr index = evaluateProjectionIndexLike(value, laneArg, lane); - if (failed(index)) - return failure(); - evaluated.push_back(*index); - } - return evaluated; -} - - -bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, - const AffineProjectedInputSliceMatch& match, - ProducerKey producer, - uint32_t consumerLane) { - auto producerBatch = dyn_cast_or_null(producer.instance.op); - if (!producerBatch) - return true; - - FailureOr producerProjection = - getBatchResultProjectionInsert(producerBatch, producer.resultIndex); - if (failed(producerProjection)) - return true; - - std::optional producerLaneArg = producerBatch.getLaneArgument(); - std::optional consumerLaneArg = consumerBatch.getLaneArgument(); - if (!producerLaneArg || !consumerLaneArg) - return false; - - SmallVector consumerSizes(match.fragmentShape.begin(), match.fragmentShape.end()); - SmallVector loopIterationIndices(match.loops.size(), 0); - - const auto consumerSliceFitsOneProducerFragment = [&]() -> bool { - SmallVector consumerOffsets; - consumerOffsets.reserve(match.offsets.size()); - for (OpFoldResult offset : match.offsets) { - FailureOr evaluated = - evaluateProjectedOffsetValue(offset, *consumerLaneArg, consumerLane, match.loops, loopIterationIndices); - if (failed(evaluated)) - return false; - consumerOffsets.push_back(*evaluated); - } - - uint32_t producerLaneEnd = producer.instance.laneStart + producer.instance.laneCount; - for (uint32_t producerLane = producer.instance.laneStart; producerLane < producerLaneEnd; ++producerLane) { - FailureOr> producerOffsets = - evaluateStaticProjectionIndices(producerProjection->getMixedOffsets(), *producerLaneArg, producerLane); - FailureOr> producerSizes = - evaluateStaticProjectionIndices(producerProjection->getMixedSizes(), *producerLaneArg, producerLane); - FailureOr> producerStrides = - evaluateStaticProjectionIndices(producerProjection->getMixedStrides(), *producerLaneArg, producerLane); - if (failed(producerOffsets) || failed(producerSizes) || failed(producerStrides)) - return false; - if (!areAllUnitStrides(*producerStrides)) - return false; - if (isStaticSliceContainedIn(consumerOffsets, consumerSizes, *producerOffsets, *producerSizes)) - return true; - } - - return false; - }; - - if (match.loops.empty()) - return consumerSliceFitsOneProducerFragment(); - - const auto recurse = [&](auto&& self, size_t loopIndex) -> bool { - if (loopIndex == match.loops.size()) - return consumerSliceFitsOneProducerFragment(); - - for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { - loopIterationIndices[loopIndex] = iteration; - if (!self(self, loopIndex + 1)) - return false; - } - return true; - }; - - return recurse(recurse, 0); -} - - -LogicalResult collectProjectedTransfers(MaterializerState& state) { - struct PendingProjectedTransferDescriptor { - ProjectedBatchInputKey inputKey; - Operation* extractOp = nullptr; - RankedTensorType sourceType; - RankedTensorType fragmentType; - SmallVector fragmentShape; - SmallVector, 16>, 8> fragmentOffsetsByLane; - SmallVector loopLowerBounds; - SmallVector loopSteps; - SmallVector loopTripCounts; - bool invalid = false; - }; - - DenseMap, ProducerKeyInfo> pending; - - const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) { - if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType) - return false; - - if (descriptor.fragmentOffsetsByLane.size() != 1) - return false; - - ArrayRef> fragments = descriptor.fragmentOffsetsByLane.front(); - if (fragments.size() != 1) - return false; - - return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; }); - }; - - const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor, - unsigned targetLane, - const AffineProjectedInputSliceMatch& match, - Value laneArg, - uint32_t lane) -> LogicalResult { - SmallVector loopIterationIndices; - loopIterationIndices.resize(match.loops.size(), 0); - - const auto appendOneFragment = [&]() -> LogicalResult { - SmallVector evaluatedOffsets; - evaluatedOffsets.reserve(match.offsets.size()); - for (OpFoldResult offset : match.offsets) { - FailureOr evaluated = - evaluateProjectedOffsetValue(offset, laneArg, lane, match.loops, loopIterationIndices); - if (failed(evaluated)) - return failure(); - evaluatedOffsets.push_back(*evaluated); - } - - if (!isStaticSliceInBounds(evaluatedOffsets, match.sourceType, match.fragmentType)) - return failure(); - - descriptor.fragmentOffsetsByLane[targetLane].push_back(std::move(evaluatedOffsets)); - return success(); - }; - - if (match.loops.empty()) - return appendOneFragment(); - - const auto recurse = [&](auto&& self, size_t loopIndex) -> LogicalResult { - if (loopIndex == match.loops.size()) - return appendOneFragment(); - - for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { - loopIterationIndices[loopIndex] = iteration; - if (failed(self(self, loopIndex + 1))) - return failure(); - } - return success(); - }; - - return recurse(recurse, 0); - }; - - if (failed(forEachLogicalConsumerInMaterializationOrder( - state, - [&](CpuId cpu, - ClassId targetClassId, - ComputeInstance consumer, - ComputeInstance logicalConsumer, - SlotId logicalSlot) -> LogicalResult { - auto batch = dyn_cast(consumer.op); - if (!batch) - return success(); - - MaterializedClass& targetClass = state.classes[targetClassId]; - unsigned targetLane = 0; - if (targetClass.isBatch) { - auto targetLaneIt = targetClass.cpuToLane.find(cpu); - if (targetLaneIt == targetClass.cpuToLane.end()) - return consumer.op->emitError("projected transfer collection could not recover target lane"); - targetLane = targetLaneIt->second; - } - - for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { - SmallVector producers = collectProducerKeysForDestinations(input, logicalConsumer); - if (producers.size() != 1) - continue; - ProducerKey producer = producers.front(); - - ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producer.instance); - auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); - if (producerCpuIt == state.schedule.computeToCpuMap.end()) - continue; - - ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second); - if (sourceClassId == targetClassId) - continue; - - std::optional match = - getProjectedInputSliceMatch(state, batch, static_cast(inputIndex)); - if (!match) - continue; - if (!isProjectedInputSliceCompatibleWithProducerFragments( - batch, *match, producer, logicalConsumer.laneStart)) - continue; - - PendingProjectedTransferDescriptor& descriptor = pending[producer][targetClassId]; - if (descriptor.fragmentOffsetsByLane.empty()) { - descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; - descriptor.extractOp = match->extract.getOperation(); - descriptor.sourceType = match->sourceType; - descriptor.fragmentType = match->fragmentType; - descriptor.fragmentShape = match->fragmentShape; - descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1); - descriptor.loopLowerBounds.reserve(match->loops.size()); - descriptor.loopSteps.reserve(match->loops.size()); - descriptor.loopTripCounts.reserve(match->loops.size()); - for (const StaticProjectedLoopInfo& loop : match->loops) { - descriptor.loopLowerBounds.push_back(loop.lowerBound); - descriptor.loopSteps.push_back(loop.step); - descriptor.loopTripCounts.push_back(loop.tripCount); - } - } - - ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; - if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation() - || descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType - || descriptor.fragmentShape != match->fragmentShape - || descriptor.loopLowerBounds.size() != match->loops.size()) { - descriptor.invalid = true; - continue; - } - for (auto [index, loop] : llvm::enumerate(match->loops)) { - if (descriptor.loopLowerBounds[index] != loop.lowerBound || descriptor.loopSteps[index] != loop.step - || descriptor.loopTripCounts[index] != loop.tripCount) { - descriptor.invalid = true; - break; - } - } - if (descriptor.invalid) - continue; - - if (targetLane >= descriptor.fragmentOffsetsByLane.size()) { - descriptor.invalid = true; - continue; - } - - if (failed(appendEvaluatedFragments( - descriptor, targetLane, *match, *batch.getLaneArgument(), logicalConsumer.laneStart))) { - descriptor.invalid = true; - continue; - } - - (void) logicalSlot; - } - - return success(); - }))) - return failure(); - - for (auto& producerEntry : pending) { - ProducerKey producer = producerEntry.first; - for (auto& classEntry : producerEntry.second) { - ClassId targetClassId = classEntry.first; - PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second; - - if (pendingDescriptor.invalid) - continue; - if (pendingDescriptor.fragmentOffsetsByLane.empty()) - continue; - if (isIdentityProjectedTransfer(pendingDescriptor)) - continue; - - MaterializedClass& targetClass = state.classes[targetClassId]; - ProjectedTransferDescriptor descriptor; - descriptor.inputKey = pendingDescriptor.inputKey; - descriptor.extractOp = pendingDescriptor.extractOp; - descriptor.layout.fragmentType = pendingDescriptor.fragmentType; - descriptor.layout.fragmentShape = pendingDescriptor.fragmentShape; - descriptor.layout.loopLowerBounds = pendingDescriptor.loopLowerBounds; - descriptor.layout.loopSteps = pendingDescriptor.loopSteps; - descriptor.layout.loopTripCounts = pendingDescriptor.loopTripCounts; - descriptor.layout.fragmentsPerLogicalSlot = getProjectedFragmentsPerLogicalSlot(descriptor.layout.loopTripCounts); - if (targetClass.isBatch) { - unsigned payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); - if (payloadFragmentCount == 0) - continue; - - // Batch-target projected replacements currently select fragments with the - // local materialization-run slot index. That is only unambiguous when each - // target lane receives one projected fragment. Multi-fragment payloads - // need an explicit producer-key to payload-slot mapping; otherwise two - // independently materialized runs can both select fragment 0 from the same - // packed receive and duplicate rows. - if (payloadFragmentCount != 1) - continue; - - bool uniform = true; - for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) { - if (laneFragments.size() != payloadFragmentCount) { - uniform = false; - break; - } - } - if (!uniform) - continue; - - descriptor.layout.payloadFragmentCount = payloadFragmentCount; - descriptor.fragmentOffsets.reserve(pendingDescriptor.fragmentOffsetsByLane.size() * payloadFragmentCount); - for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) - llvm::append_range(descriptor.fragmentOffsets, laneFragments); - } - else { - if (pendingDescriptor.fragmentOffsetsByLane.size() != 1) - return targetClass.op->emitError("scalar projected transfer descriptor expected one local offset stream"); - if (pendingDescriptor.fragmentOffsetsByLane.front().empty()) - continue; - - descriptor.layout.payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); - llvm::append_range(descriptor.fragmentOffsets, pendingDescriptor.fragmentOffsetsByLane.front()); - if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) - return targetClass.op->emitError("scalar projected transfer offset count does not match the local run"); - } - if (failed(finalizeProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - - state.projectedTransfers[producer][targetClassId] = std::move(descriptor); - } - } - - return success(); -} - -static std::optional -collectScalarTargetProjectedDescriptor(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef keys, - bool requirePackedRunOffsetCountMatch) { - assert(!targetClass.isBatch && "scalar target projected descriptor helper expects a scalar class"); - - std::optional combined; - for (ProducerKey key : keys) { - auto producerIt = state.projectedTransfers.find(key); - if (producerIt == state.projectedTransfers.end()) - return std::nullopt; - - auto descriptorIt = producerIt->second.find(targetClass.id); - if (descriptorIt == producerIt->second.end()) - return std::nullopt; - - const ProjectedTransferDescriptor& descriptor = descriptorIt->second; - if (descriptor.fragmentOffsets.empty()) - return std::nullopt; - if (descriptor.layout.payloadFragmentCount == 0 || descriptor.layout.fragmentsPerLogicalSlot == 0) - return std::nullopt; - if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) - return std::nullopt; - if (descriptor.layout.payloadFragmentCount % descriptor.layout.fragmentsPerLogicalSlot != 0) - return std::nullopt; - - if (!combined) { - combined = descriptor; - continue; - } - - if (!(combined->inputKey == descriptor.inputKey) || combined->extractOp != descriptor.extractOp - || combined->layout.fragmentType != descriptor.layout.fragmentType - || combined->layout.fragmentShape != descriptor.layout.fragmentShape - || combined->layout.loopLowerBounds != descriptor.layout.loopLowerBounds - || combined->layout.loopSteps != descriptor.layout.loopSteps - || combined->layout.loopTripCounts != descriptor.layout.loopTripCounts - || combined->layout.fragmentsPerLogicalSlot != descriptor.layout.fragmentsPerLogicalSlot) - return std::nullopt; - - combined->layout.payloadFragmentCount += descriptor.layout.payloadFragmentCount; - llvm::append_range(combined->fragmentOffsets, descriptor.fragmentOffsets); - } - - if (!combined) - return std::nullopt; - - if (combined->fragmentOffsets.size() != combined->layout.payloadFragmentCount) - return std::nullopt; - - if (requirePackedRunOffsetCountMatch) { - if (combined->layout.payloadFragmentCount != keys.size() * combined->layout.fragmentsPerLogicalSlot) - return std::nullopt; - } - if (failed(finalizeProjectedTransferDescriptor(targetClass.op, *combined))) - return std::nullopt; - return combined; -} - -bool haveSameDestinationClasses(MaterializerState& state, ArrayRef keys) { - if (keys.empty()) - return true; - - auto firstIt = state.producerDestClasses.find(keys.front()); - ArrayRef first = firstIt == state.producerDestClasses.end() ? ArrayRef() : firstIt->second; - for (ProducerKey key : keys.drop_front()) { - auto it = state.producerDestClasses.find(key); - ArrayRef current = it == state.producerDestClasses.end() ? ArrayRef() : it->second; - if (first.size() != current.size()) - return false; - for (auto [lhs, rhs] : llvm::zip(first, current)) - if (lhs != rhs) - return false; - } - return true; -} - -ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey key) { - auto it = state.producerDestClasses.find(key); - if (it == state.producerDestClasses.end()) - return {}; - return it->second; -} - -// ----------------------------------------------------------------------------- -// Communication materialization helpers. -// ----------------------------------------------------------------------------- - -void appendScalarSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc) { - assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); - - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId); - Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId); - Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId); - SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); -} - -LogicalResult appendScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); - assert(messages.size() > 1 && "send loop is only useful for multiple sends"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - - auto sendLoop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - return success(); - }); - if (failed(sendLoop)) - return failure(); - return success(); -} - -FailureOr buildProjectedPackedPayload(MaterializerState& state, - MaterializedClass& targetClass, - Value fullPayload, - const ProjectedTransferDescriptor& descriptor, - Value messageIndex, - Location loc) { - if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - if (descriptor.layout.payloadFragmentCount == 1) - return targetClass.op->emitError("projected packed payload builder expects a packed payload"); - - FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fullPayload, - targetClass.op, - "projected packed payload tried to reuse a tensor from another materialized class"); - if (failed(localizedPayload)) - return failure(); - FailureOr localizedMessageIndex = rematerializeIndexValueInClass(state, targetClass, messageIndex, loc); - if (failed(localizedMessageIndex)) - return failure(); - - Value init = tensor::EmptyOp::create( - state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) - .getResult(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {init}, - [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value payloadFragmentCount = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); - Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localizedMessageIndex, payloadFragmentCount).getResult(); - Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - - FailureOr> fragmentOffsets = - buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc); - if (failed(fragmentOffsets)) - return failure(); - FailureOr fragment = createStaticExtractSliceInClass( - state, targetClass, loc, *localizedPayload, *fragmentOffsets, descriptor.layout.fragmentShape); - if (failed(fragment)) - return failure(); - - FailureOr packedOffset = - scaleIndexByDim0SizeInClass(state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); - if (failed(packedOffset)) - return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, *fragment, acc, *packedOffset); - if (failed(next)) - return failure(); - yielded.push_back(*next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -FailureOr buildProjectedPayloadForMessage(MaterializerState& state, - MaterializedClass& targetClass, - Value fullPayload, - const ProjectedTransferDescriptor& descriptor, - Value messageIndex, - Location loc) { - if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - - FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fullPayload, - targetClass.op, - "projected payload tried to reuse a tensor from another materialized class"); - if (failed(localizedPayload)) - return failure(); - - if (descriptor.layout.payloadFragmentCount == 1) { - FailureOr> fragmentOffsets = - buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, messageIndex, loc); - if (failed(fragmentOffsets)) - return failure(); - return createStaticExtractSliceInClass( - state, targetClass, loc, *localizedPayload, *fragmentOffsets, descriptor.layout.fragmentShape); - } - - return buildProjectedPackedPayload(state, targetClass, *localizedPayload, descriptor, messageIndex, loc); -} - -LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const ProjectedTransferDescriptor& descriptor, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - if (failed(verifyProjectedSendDescriptor(sourceClass.op, descriptor, messages))) - return failure(); - - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - - if (messages.size() == 1) { - Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); - Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); - Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); - Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass, payload, descriptor, messageIndex, loc); - if (failed(sendPayload)) - return failure(); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); - return success(); - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - - auto projectedSendLoop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); - FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass, payload, descriptor, index, loc); - if (failed(sendPayload)) - return failure(); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); - return success(); - }); - if (failed(projectedSendLoop)) - return failure(); - return success(); -} - -LogicalResult appendSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - assert(!messages.empty() && "expected at least one send"); - - if (sourceClass.isBatch) { - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - - Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); - Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); - Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - return success(); - } - - if (messages.size() == 1) { - appendScalarSend(state, - sourceClass, - payload, - messages.channelIds.front(), - messages.sourceCoreIds.front(), - messages.targetCoreIds.front(), - loc); - return success(); - } - - return appendScalarSendLoop(state, sourceClass, payload, messages, loc); -} - -Value appendScalarReceive(MaterializerState& state, - MaterializedClass& targetClass, - Type type, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc) { - assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); - Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); - Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); - return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue) - .getOutput(); -} - -Value appendReceive( - MaterializerState& state, MaterializedClass& targetClass, Type type, const MessageVector& messages, Location loc) { - assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); - assert(!messages.empty() && "expected at least one receive"); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - if (targetClass.isBatch) { - Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); - Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); - Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); - return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput(); - } - - assert(messages.size() == 1 && "scalar target class can only receive one message at a time"); - return appendScalarReceive(state, - targetClass, - type, - messages.channelIds.front(), - messages.sourceCoreIds.front(), - messages.targetCoreIds.front(), - loc); -} - -LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - ArrayRef keys, - Type fragmentType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds) { - if (!sourceClass.isBatch) - return sourceClass.op->emitError("lazy packed scalar receives expect a batch source class"); - - if (targetClass.isBatch) - return targetClass.op->emitError("lazy packed scalar receives expect a scalar target class"); - - if (keys.empty()) - return sourceClass.op->emitError("lazy packed scalar receive expects at least one producer key"); - - if (keys.size() != sourceClass.cpus.size()) - return sourceClass.op->emitError("lazy packed scalar receive expects one producer key per source lane"); - - MessageVector messages; - messages.append(channelIds, sourceCoreIds, targetCoreIds); - if (failed(messages.verify(targetClass.op))) - return failure(); - - if (keys.size() != messages.size()) - return targetClass.op->emitError("lazy packed scalar receive metadata is inconsistent"); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return targetClass.op->emitError("lazy packed scalar receive expects a static ranked fragment type"); - - if (failed(verifyPackableFragmentType( - targetClass.op, fragmentType, keys.size(), "cannot create lazy packed scalar receive type"))) - return failure(); - - Operation* sourceOp = keys.front().instance.op; - size_t resultIndex = keys.front().resultIndex; - - for (ProducerKey key : keys) { - if (key.instance.op != sourceOp || key.resultIndex != resultIndex) - return sourceClass.op->emitError("lazy packed scalar receive expects one producer result"); - - if (key.instance.laneCount != 1) - return sourceClass.op->emitError("lazy packed scalar receive expects one lane per producer key"); - } - - PackedScalarRunValue packedRun; - packedRun.targetClass = targetClass.id; - packedRun.sourceOp = sourceOp; - packedRun.resultIndex = resultIndex; - packedRun.kind = PackedScalarRunKind::DeferredReceive; - packedRun.fragmentType = rankedFragmentType; - - packedRun.messages = std::move(messages); - - PackedScalarRunSlot slot; - llvm::append_range(slot.keys, keys); - packedRun.slots.push_back(std::move(slot)); - - if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) - return failure(); - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -struct ScalarSourceReceivePlan { - ClassId targetClass = 0; - MessageVector messages; - Type receiveType; - Operation* projectedExtractOp = nullptr; - ProjectedFragmentLayout projectedLayout; -}; - -struct ProjectedScalarSendGroup { - MessageVector messages; - ProjectedTransferDescriptor descriptor; -}; - -struct ScalarSourceFanoutPlan { - SmallVector receivePlans; - std::optional ordinaryMessages; - SmallVector projectedSendGroups; -}; - -bool hasSameProjectedSendCompatibility(const ProjectedTransferDescriptor& lhs, const ProjectedTransferDescriptor& rhs) { - return lhs.layout.fragmentType == rhs.layout.fragmentType && lhs.layout.fragmentShape == rhs.layout.fragmentShape - && lhs.layout.fragmentsPerLogicalSlot == rhs.layout.fragmentsPerLogicalSlot - && lhs.layout.payloadFragmentCount == rhs.layout.payloadFragmentCount - && lhs.layout.loopLowerBounds == rhs.layout.loopLowerBounds && lhs.layout.loopSteps == rhs.layout.loopSteps - && lhs.layout.loopTripCounts == rhs.layout.loopTripCounts && lhs.payloadType == rhs.payloadType; -} - -SmallVector collectDestinationClassesForKeys(MaterializerState& state, ArrayRef keys) { - SmallVector destinations; - - for (ProducerKey key : keys) - for (ClassId destinationClass : getDestinationClasses(state, key)) - destinations.push_back(destinationClass); - - llvm::sort(destinations); - destinations.erase(std::unique(destinations.begin(), destinations.end()), destinations.end()); - return destinations; -} - -FailureOr buildScalarSourceFanoutPlan(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - ArrayRef destinationClasses, - Value payload) { - assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); - - auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "scalar source core id"); - if (failed(sourceCpu)) - return failure(); - - ScalarSourceFanoutPlan fanoutPlan; - fanoutPlan.receivePlans.reserve(destinationClasses.size()); - - const auto getProjectedDescriptor = - [&](ClassId destinationClass) -> FailureOr> { - MaterializedClass& targetClass = state.classes[destinationClass]; - if (!targetClass.isBatch) { - bool hasAnyProjectedDescriptor = llvm::any_of(keys, [&](ProducerKey key) { - auto producerIt = state.projectedTransfers.find(key); - return producerIt != state.projectedTransfers.end() && producerIt->second.count(destinationClass) != 0; - }); - - std::optional descriptor = collectScalarTargetProjectedDescriptor( - state, targetClass, keys, /*requirePackedRunOffsetCountMatch=*/keys.size() > 1); - if (hasAnyProjectedDescriptor && !descriptor) - return targetClass.op->emitError("incomplete scalar projected transfer descriptor for local run"); - return descriptor; - } - - if (keys.size() != 1) - return std::optional {}; - - auto producerIt = state.projectedTransfers.find(keys.front()); - if (producerIt == state.projectedTransfers.end()) - return std::optional {}; - - auto descriptorIt = producerIt->second.find(destinationClass); - if (descriptorIt == producerIt->second.end()) - return std::optional {}; - - const ProjectedTransferDescriptor& descriptor = descriptorIt->second; - if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) - return failure(); - if (descriptor.fragmentOffsets.size() - != targetClass.cpus.size() * static_cast(descriptor.layout.payloadFragmentCount)) - return targetClass.op->emitError("inconsistent batch projected transfer descriptor"); - - return std::optional {descriptor}; - }; - - for (ClassId destinationClass : destinationClasses) { - if (destinationClass == sourceClass.id) - continue; - - MaterializedClass& targetClass = state.classes[destinationClass]; - - ScalarSourceReceivePlan receivePlan; - receivePlan.targetClass = destinationClass; - receivePlan.receiveType = payload.getType(); - - auto appendMessage = [&](CpuId targetCpu) -> LogicalResult { - auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetCpu, "scalar target core id"); - if (failed(checkedTargetCpu)) - return failure(); - int64_t channelId = state.nextChannelId++; - - receivePlan.messages.append(channelId, *sourceCpu, *checkedTargetCpu); - return success(); - }; - - if (!targetClass.isBatch) { - if (failed(appendMessage(targetClass.cpus.front()))) - return failure(); - } - else { - for (CpuId targetCpu : targetClass.cpus) - if (failed(appendMessage(targetCpu))) - return failure(); - } - - FailureOr> descriptor = getProjectedDescriptor(destinationClass); - if (failed(descriptor)) - return failure(); - - if (*descriptor) { - const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; - - if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) - return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type"); - - receivePlan.receiveType = projectedDescriptor.payloadType; - receivePlan.projectedExtractOp = projectedDescriptor.extractOp; - receivePlan.projectedLayout = projectedDescriptor.layout; - - auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ProjectedScalarSendGroup& group) { - return hasSameProjectedSendCompatibility(group.descriptor, projectedDescriptor); - }); - if (groupIt == fanoutPlan.projectedSendGroups.end()) { - ProjectedScalarSendGroup group; - group.descriptor.layout = projectedDescriptor.layout; - group.descriptor.payloadType = projectedDescriptor.payloadType; - fanoutPlan.projectedSendGroups.push_back(std::move(group)); - groupIt = std::prev(fanoutPlan.projectedSendGroups.end()); - } - - groupIt->messages.append( - receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); - llvm::append_range(groupIt->descriptor.fragmentOffsets, projectedDescriptor.fragmentOffsets); - } - else { - if (!fanoutPlan.ordinaryMessages) - fanoutPlan.ordinaryMessages = MessageVector {}; - fanoutPlan.ordinaryMessages->append( - receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); - } - - fanoutPlan.receivePlans.push_back(std::move(receivePlan)); - } - - for (ProjectedScalarSendGroup& group : fanoutPlan.projectedSendGroups) { - if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, group.descriptor))) - return failure(); - if (failed(verifyProjectedSendDescriptor(sourceClass.op, group.descriptor, group.messages))) - return failure(); - } - - return fanoutPlan; -} - -LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const ScalarSourceFanoutPlan& plan, - Location loc) { - if (plan.ordinaryMessages && failed(appendSend(state, sourceClass, payload, *plan.ordinaryMessages, loc))) - return failure(); - - for (const ProjectedScalarSendGroup& group : plan.projectedSendGroups) - if (failed(appendProjectedScalarSendLoop(state, sourceClass, payload, group.descriptor, group.messages, loc))) - return failure(); - - return success(); -} - -LogicalResult emitScalarSourceCommunication( - MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Location loc) { - assert(!sourceClass.isBatch && "scalar-source communication expects a scalar source class"); - - for (ProducerKey key : keys) - state.availableValues.record(key, sourceClass.id, payload); - - SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); - auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, payload); - if (failed(fanoutPlan)) - return failure(); - if (failed(emitScalarSourceFanoutSends(state, sourceClass, payload, *fanoutPlan, loc))) - return failure(); - - for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { - MaterializedClass& targetClass = state.classes[plan.targetClass]; - - Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); - - if (plan.projectedExtractOp) { - state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = - ProjectedExtractReplacement {received, plan.projectedLayout}; - continue; - } - - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, received); - } - - return success(); -} - -LogicalResult emitClassToClassCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - ArrayRef keys, - Value payload, - Location loc) { - if (sourceClass.id == targetClass.id) { - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, payload); - return success(); - } - - if (!sourceClass.isBatch) - return sourceClass.op->emitError("scalar-source communication must be emitted through the scalar fanout planner"); - - if (!targetClass.isBatch) { - if (keys.size() != sourceClass.cpus.size()) - return sourceClass.op->emitError( - "cannot materialize batch-to-scalar communication without one producer key per source lane") - << " keyCount=" << keys.size() << " laneCount=" << sourceClass.cpus.size(); - - Operation* sourceOp = keys.front().instance.op; - size_t sourceResultIndex = keys.front().resultIndex; - for (ProducerKey key : keys) { - if (key.instance.op != sourceOp || key.resultIndex != sourceResultIndex || key.instance.laneCount != 1) - return sourceClass.op->emitError( - "cannot materialize batch-to-scalar communication for incompatible source keys"); - } - - MessageVector messages; - messages.channelIds.reserve(sourceClass.cpus.size()); - messages.sourceCoreIds.reserve(sourceClass.cpus.size()); - messages.targetCoreIds.reserve(sourceClass.cpus.size()); - - auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id"); - if (failed(targetCpu)) - return failure(); - for (CpuId sourceCpu : sourceClass.cpus) { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); - if (failed(checkedSourceCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *targetCpu); - } - - if (failed(appendSend(state, sourceClass, payload, messages, loc))) - return failure(); - return registerLazyPackedScalarReceives(state, - sourceClass, - targetClass, - keys, - payload.getType(), - messages.channelIds, - messages.sourceCoreIds, - messages.targetCoreIds); - } - - if (sourceClass.cpus.size() != targetClass.cpus.size()) - return sourceClass.op->emitError( - "cannot materialize batch communication between equivalence classes of different sizes"); - - MessageVector messages; - messages.channelIds.reserve(sourceClass.cpus.size()); - messages.sourceCoreIds.reserve(sourceClass.cpus.size()); - messages.targetCoreIds.reserve(targetClass.cpus.size()); - - for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch source core id"); - if (failed(checkedSourceCpu)) - return failure(); - auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus[lane], "batch target core id"); - if (failed(checkedTargetCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - } - - if (failed(appendSend(state, sourceClass, payload, messages, loc))) - return failure(); - Value received = appendReceive(state, targetClass, payload.getType(), messages, loc); - - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, received); - - return success(); -} - -FailureOr recordProjectedBatchHostFragmentsFromBatchValue(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& ownerClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc); - -LogicalResult -setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { - auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); - if (resultIt == sourceClass.hostOutputToResultIndex.end()) - return sourceClass.op->emitError("missing host result slot for materialized output") - << " ownerKind=" << (sourceClass.isBatch ? "batch" : "scalar") - << " hostOutputs=" << sourceClass.hostOutputs.size() - << " originalDef=" << (originalOutput.getDefiningOp() ? originalOutput.getDefiningOp()->getName().getStringRef() - : StringRef("")); - - unsigned resultIndex = resultIt->second; - state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); - - if (!sourceClass.isBatch) { - auto yieldOp = dyn_cast(sourceClass.body->getTerminator()); - if (!yieldOp) - return sourceClass.op->emitError("expected spat.yield terminator in materialized compute"); - if (resultIndex >= yieldOp.getNumOperands()) - return sourceClass.op->emitError("host result index out of range for materialized compute"); - if (payload.getType() != originalOutput.getType()) - return sourceClass.op->emitError("cannot set scalar host output from fragment payload") - << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); - - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); - return success(); - } - - auto batch = cast(sourceClass.op); - auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); - if (!inParallelOp) - return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape()) - return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor"); - - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("expected compute_batch lane block argument while materializing batch output"); - - auto outputArg = batch.getOutputArgument(resultIndex); - if (!outputArg) - return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); - - state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); - return success(); -} - -LogicalResult -emitHostCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - Value originalOutput, - ArrayRef keys = {}) { - if (!hasLiveExternalUseCached(state, originalOutput)) - return success(); - - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for live external output"); - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (sourceClass.id == ownerClass.id) - return setHostOutputValue(state, ownerClass, originalOutput, payload); - - // Keep the old deadlock-free communication discipline: only scalar-to-scalar - // host-owner forwarding is introduced here. Batch host publication remains on - // the owning batch path; projected terminal batch publication must use the - // explicit projected whole-batch path instead of generic host forwarding. - if (sourceClass.isBatch) { - FailureOr recordedProjectedHostFragments = recordProjectedBatchHostFragmentsFromBatchValue( - state, sourceClass, ownerClass, keys, payload, originalOutput, payload.getLoc()); - if (failed(recordedProjectedHostFragments)) - return failure(); - if (*recordedProjectedHostFragments) - return success(); - return sourceClass.op->emitError("batch host publication must be routed through the owning/projection-aware path"); - } - if (ownerClass.isBatch) - return ownerClass.op->emitError("generic host publication does not support batch host owners"); - if (payload.getType() != originalOutput.getType()) - return sourceClass.op->emitError("cannot forward fragment payload to scalar host owner") - << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); - - MessageVector messages; - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "host source core id"); - auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "host target core id"); - if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - - if (failed(appendSend(state, sourceClass, payload, messages, payload.getLoc()))) - return failure(); - Value ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, payload.getLoc()); - return setHostOutputValue(state, ownerClass, originalOutput, ownerPayload); -} - -LogicalResult emitOutputFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { - if (keys.empty()) - return success(); - - if (!sourceClass.isBatch) { - if (failed(emitScalarSourceCommunication(state, sourceClass, keys, payload, loc))) - return failure(); - - return emitHostCommunication(state, sourceClass, payload, originalOutput); - } - - if (!haveSameDestinationClasses(state, keys)) - return sourceClass.op->emitError( - "cannot materialize batched output whose lanes have different destination equivalence classes"); - - for (ClassId destinationClass : getDestinationClasses(state, keys.front())) - if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) - return failure(); - - if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput, keys))) - return failure(); - - for (ProducerKey key : keys) - state.availableValues.record(key, sourceClass.id, payload); - - return success(); -} - -struct DirectWholeBatchFragment { - ProducerKey key; - Value fragment; -}; - -enum class WholeBatchFragmentSourceKind { - DeferredReceive, - DeferredLocalCompute, - PackedValue, - DirectValue -}; - -struct WholeBatchFragmentGroup { - WholeBatchFragmentSourceKind kind = WholeBatchFragmentSourceKind::DirectValue; - RankedTensorType fragmentType; - SmallVector outputOffsets; - MessageVector messages; - Operation* sourceOp = nullptr; - size_t resultIndex = 0; - SmallVector sourceLanes; - Value packed; - RankedTensorType slotPackedType; - SmallVector slotIndices; - SmallVector, 16> directFragments; - SmallVector redundantReceives; -}; - -enum class ProjectedWholeBatchFragmentSourceKind { - DeferredReceive, - PackedValue, - DirectValue -}; - -struct ProjectedWholeBatchDirectFragment { - Value fragment; - SmallVector offsets; - SmallVector sizes; - SmallVector strides; -}; - -struct ProjectedWholeBatchFragmentGroup { - ProjectedWholeBatchFragmentSourceKind kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; - RankedTensorType fragmentType; - SmallVector, 4> offsetsByDim; - SmallVector, 4> sizesByDim; - SmallVector, 4> stridesByDim; - MessageVector messages; - SmallVector redundantOps; - Value packed; - RankedTensorType packedSourceType; - SmallVector packedIndices; - SmallVector directFragments; -}; - -struct WholeBatchAssemblyPlan { - RankedTensorType resultType; - int64_t rowsPerLane = 0; - uint32_t batchLaneCount = 0; - uint32_t coveredLaneCount = 0; - - SmallVector coveredLanes; - SmallVector packedRuns; - SmallVector directFragments; -}; - -bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) { - return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0; -} - -bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { - if (laneCount == 0) - return false; - if (laneStart >= plan.coveredLanes.size()) - return false; - - uint32_t laneEnd = std::min(laneStart + laneCount, plan.coveredLanes.size()); - for (uint32_t lane = laneStart; lane < laneEnd; ++lane) - if (plan.coveredLanes[lane] != 0) - return true; - return false; -} - -void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { - assert(laneCount != 0 && "cannot cover an empty whole-batch range"); - assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds"); - - for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) { - if (plan.coveredLanes[lane] != 0) - continue; - plan.coveredLanes[lane] = 1; - ++plan.coveredLaneCount; - } -} - -bool localLaneRangeOverlaps(ArrayRef covered, uint32_t laneStart, uint32_t laneCount) { - if (laneCount == 0) - return false; - if (laneStart >= covered.size()) - return false; - - uint32_t laneEnd = std::min(laneStart + laneCount, covered.size()); - for (uint32_t lane = laneStart; lane < laneEnd; ++lane) - if (covered[lane] != 0) - return true; - return false; -} - -void markLocalLaneRangeCovered(MutableArrayRef covered, uint32_t laneStart, uint32_t laneCount) { - assert(laneStart + laneCount <= covered.size() && "local coverage out of bounds"); - for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) - covered[lane] = 1; -} - -LogicalResult -validateWholeBatchFragmentType(RankedTensorType resultType, RankedTensorType fragmentType, int64_t expectedRows) { - if (!fragmentType.hasStaticShape()) - return failure(); - if (fragmentType.getRank() != resultType.getRank()) - return failure(); - if (fragmentType.getDimSize(0) != expectedRows) - return failure(); - - for (int64_t dim = 1; dim < resultType.getRank(); ++dim) - if (fragmentType.getDimSize(dim) != resultType.getDimSize(dim)) - return failure(); - - return success(); -} - -// ----------------------------------------------------------------------------- -// Packed run tensor assembly helpers. -// ----------------------------------------------------------------------------- - -FailureOr insertFragmentIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value fragment, - Value destination, - OpFoldResult firstOffset, - Location loc) { - return createDim0InsertSliceInClass(state, targetClass, loc, fragment, destination, firstOffset); -} - -FailureOr extractPackedSlotForIndex(MaterializerState& state, - MaterializedClass& targetClass, - Value packed, - RankedTensorType slotPackedType, - Value slotIndex, - Location loc) { - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0)); -} - -SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run) { - SmallVector keys; - for (const PackedScalarRunSlot& slot : run.slots) - llvm::append_range(keys, slot.keys); - return keys; -} - -bool packedScalarRunSlotsMatch(const PackedScalarRunValue& lhs, const PackedScalarRunValue& rhs) { - if (lhs.slots.size() != rhs.slots.size()) - return false; - - for (auto [lhsSlot, rhsSlot] : llvm::zip(lhs.slots, rhs.slots)) { - if (lhsSlot.keys.size() != rhsSlot.keys.size()) - return false; - if (!llvm::equal(lhsSlot.keys, rhsSlot.keys)) - return false; - } - - return true; -} - - -std::optional getConstantIndexValue(Value value) { - APInt constant; - if (matchPattern(value, m_ConstantInt(&constant))) - return constant.getSExtValue(); - return std::nullopt; -} - -bool appendConstantChannelReceiveMessage(MessageVector& messages, SpatChannelReceiveOp receive) { - std::optional channelId = getConstantIndexValue(receive.getChannelId()); - std::optional sourceCoreId = getConstantIndexValue(receive.getSourceCoreId()); - std::optional targetCoreId = getConstantIndexValue(receive.getTargetCoreId()); - if (!channelId || !sourceCoreId || !targetCoreId) - return false; - messages.append(*channelId, static_cast(*sourceCoreId), static_cast(*targetCoreId)); - return true; -} - -PackedScalarRunValue* findDeferredReceiveAlternativeForPackedRun(MaterializerState& state, - const MaterializedClass& targetClass, - const PackedScalarRunValue& run) { - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(run.sourceOp, run.resultIndex, targetClass.id); - ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); - - for (size_t runIndex : runIndices) { - PackedScalarRunValue& candidate = state.availableValues.getPackedRun(runIndex); - if (&candidate == &run || candidate.kind != PackedScalarRunKind::DeferredReceive) - continue; - if (candidate.fragmentType != run.fragmentType) - continue; - if (!packedScalarRunSlotsMatch(candidate, run)) - continue; - return &candidate; - } - - return nullptr; -} - -FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - int64_t itemCount, - IndexedFragmentBuilder buildFragment, - IndexedInsertOffsetBuilder buildOffset, - Location loc) { - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, itemCount); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - Operation* insertionPoint = targetClass.body->getTerminator(); - - state.rewriter.setInsertionPoint(insertionPoint); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {destination}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - FailureOr fragment = buildFragment(flatIndex); - if (failed(fragment)) - return failure(); - FailureOr offset = buildOffset(flatIndex); - if (failed(offset)) - return failure(); - FailureOr next = - insertFragmentIntoWholeBatch(state, targetClass, *fragment, iterArgs.front(), *offset, loc); - if (failed(next)) - return failure(); - yielded.push_back(*next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -FailureOr> cloneBatchBodyForLane(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - Value laneValue, - ArrayRef resultIndices, - CloneIndexingContext indexing = {}); - -Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc); -FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, - MaterializedClass& targetClass, - IndexedBatchRunValue& run, - Value runSlotIndex, - Location loc); - -FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, - MaterializedClass& targetClass, - PackedScalarRunValue& run, - Location loc) { - assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); - - SmallVector keys = flattenPackedScalarRunKeys(run); - if (keys.empty()) - return failure(); - FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); - if (failed(packedType)) - return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); - - SmallVector sourceLanes; - sourceLanes.reserve(keys.size()); - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return failure(); - sourceLanes.push_back(key.instance.laneStart); - } - - SmallVector resultIndices {run.resultIndex}; - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value init = - tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {init}, - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - keys.front().instance, - sourceLane, - resultIndices, - CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); - if (failed(produced) || produced->size() != 1) - return failure(); - - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); - if (failed(next)) - return failure(); - yielded.push_back(*next); - return success(); - }); - if (failed(loop)) - return failure(); - run.packed = loop->results.front(); - return run.packed; -} - -LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - WholeBatchAssemblyPlan& plan) { - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); - ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); - - for (size_t runIndex : runIndices) { - PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); - - SmallVector runKeys; - SmallVector runCoveredLanes(plan.batchLaneCount, 0); - - for (const PackedScalarRunSlot& slot : run.slots) { - for (ProducerKey fragmentKey : slot.keys) { - if (fragmentKey.instance.op != key.instance.op || fragmentKey.resultIndex != key.resultIndex) - return failure(); - - if (fragmentKey.instance.laneCount == 0) - return failure(); - - if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) - return failure(); - - if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) - return failure(); - - markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount); - runKeys.push_back(fragmentKey); - } - } - - if (runKeys.empty()) - continue; - - plan.packedRuns.push_back(&run); - - for (ProducerKey runKey : runKeys) - recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount); - } - - return success(); -} - -LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - SpatComputeBatch batch, - ProducerKey key, - WholeBatchAssemblyPlan& plan) { - struct CandidateFragment { - ProducerKey key; - Value value; - }; - - uint32_t batchLaneCount = static_cast(batch.getLaneCount()); - if (plan.coveredLaneCount == plan.batchLaneCount) { - return success(); - } - - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); - ArrayRef indexedFragments = - state.availableValues.getExactFragmentsForWholeBatch(lookupKey); - - SmallVector candidates; - candidates.reserve(indexedFragments.size()); - for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) { - ProducerKey candidateKey = record.key; - if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex - || candidateKey.instance.laneCount == 0) - continue; - if (!isTensorValueLocalToMaterializedClass(record.value, targetClass)) - continue; - if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount)) - continue; - - auto fragmentType = dyn_cast(record.value.getType()); - if (!fragmentType) - continue; - - int64_t expectedRows = plan.rowsPerLane * static_cast(candidateKey.instance.laneCount); - if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows))) - continue; - - candidates.push_back({candidateKey, record.value}); - } - - llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) { - if (lhs.key.instance.laneStart != rhs.key.instance.laneStart) - return lhs.key.instance.laneStart < rhs.key.instance.laneStart; - return lhs.key.instance.laneCount > rhs.key.instance.laneCount; - }); - - size_t candidateCursor = 0; - uint32_t lane = 0; - while (lane < batchLaneCount) { - while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) { - ++lane; - } - - if (lane >= batchLaneCount) - break; - - while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane) - ++candidateCursor; - - size_t candidateIndex = candidateCursor; - const CandidateFragment* best = nullptr; - while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) { - const CandidateFragment& candidate = candidates[candidateIndex]; - if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) { - best = &candidate; - break; - } - ++candidateIndex; - } - - if (!best) - return failure(); - - plan.directFragments.push_back({best->key, best->value}); - recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount); - lane += best->key.instance.laneCount; - } - - return success(); -} - -LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, - MaterializedClass& targetClass, - const WholeBatchAssemblyPlan& plan, - SmallVectorImpl& groups) { - for (PackedScalarRunValue* run : plan.packedRuns) { - if (!run || run->slots.empty()) - continue; - if (run->fragmentType.getDimSize(0) != plan.rowsPerLane) - return failure(); - - if (run->kind == PackedScalarRunKind::Materialized && run->packed - && !isTensorValueLocalToMaterializedClass(run->packed, targetClass)) { - if (PackedScalarRunValue* deferredRun = findDeferredReceiveAlternativeForPackedRun(state, targetClass, *run)) - run = deferredRun; - else { - SmallVector keys = flattenPackedScalarRunKeys(*run); - std::optional packedKey = getContiguousProducerRangeForKeys(keys); - emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, - targetClass, - "whole-batch assembly tried to reuse non-local PackedValue", - run->packed, - packedKey); - return failure(); - } - } - - if (run->kind == PackedScalarRunKind::DeferredReceive) { - if (failed(validatePackedScalarRunMetadata(targetClass.op, *run))) - return failure(); - - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == run->fragmentType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DeferredReceive; - group.fragmentType = run->fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - - groupIt->messages.append(run->messages.channelIds, run->messages.sourceCoreIds, run->messages.targetCoreIds); - for (const PackedScalarRunSlot& slot : run->slots) - for (ProducerKey fragmentKey : slot.keys) - groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); - continue; - } - - if (run->kind == PackedScalarRunKind::DeferredLocalCompute) { - SmallVector keys = flattenPackedScalarRunKeys(*run); - if (keys.empty()) - return failure(); - - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DeferredLocalCompute - && group.fragmentType == run->fragmentType && group.sourceOp == run->sourceOp - && group.resultIndex == run->resultIndex; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DeferredLocalCompute; - group.fragmentType = run->fragmentType; - group.sourceOp = run->sourceOp; - group.resultIndex = run->resultIndex; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - - for (ProducerKey fragmentKey : keys) { - if (fragmentKey.instance.laneCount != 1) - return failure(); - groupIt->sourceLanes.push_back(fragmentKey.instance.laneStart); - groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); - } - continue; - } - - auto sourceBatch = dyn_cast_or_null(run->sourceOp); - if (!sourceBatch || !run->packed) - return failure(); - - auto getOrCreatePackedValueGroup = [&](RankedTensorType slotPackedType) -> WholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::PackedValue && group.fragmentType == run->fragmentType - && group.packed == run->packed && group.slotPackedType == slotPackedType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::PackedValue; - group.fragmentType = run->fragmentType; - group.packed = run->packed; - group.slotPackedType = slotPackedType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - size_t flattenedIndexBase = 0; - for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { - std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); - if (contiguousKey) { - FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slot.keys.size()); - if (failed(slotPackedType)) - return failure(); - WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType); - group.slotIndices.push_back(slotIndex); - group.outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); - flattenedIndexBase += slot.keys.size(); - continue; - } - - WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(run->fragmentType); - for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) { - group.slotIndices.push_back(flattenedIndexBase + keyIndex); - group.outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); - } - flattenedIndexBase += slot.keys.size(); - } - } - - auto getOrCreateDeferredReceiveGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DeferredReceive; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - auto getOrCreateDirectValueGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::DirectValue; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - for (const DirectWholeBatchFragment& fragment : plan.directFragments) { - if (!isTensorValueLocalToMaterializedClass(fragment.fragment, targetClass)) { - emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, - targetClass, - "whole-batch assembly tried to reuse non-local DirectValue", - fragment.fragment, - fragment.key); - return failure(); - } - - auto fragmentType = dyn_cast(fragment.fragment.getType()); - if (!fragmentType) - return failure(); - - int64_t outputOffset = static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane; - - if (auto receive = fragment.fragment.getDefiningOp()) { - if (fragment.fragment.use_empty()) { - WholeBatchFragmentGroup& group = getOrCreateDeferredReceiveGroup(fragmentType); - if (appendConstantChannelReceiveMessage(group.messages, receive)) { - group.outputOffsets.push_back(outputOffset); - group.redundantReceives.push_back(receive.getOperation()); - continue; - } - } - } - - WholeBatchFragmentGroup& group = getOrCreateDirectValueGroup(fragmentType); - group.directFragments.push_back({fragment.fragment, outputOffset}); - } - - return success(); -} - -FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchFragmentGroup& group, - Location loc) { - switch (group.kind) { - case WholeBatchFragmentSourceKind::DeferredReceive: { - FailureOr updated = emitIndexedFragmentInsertLoop( - state, - targetClass, - destination, - static_cast(group.outputOffsets.size()), - [&](Value flatIndex) -> FailureOr { - Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); - return SpatChannelReceiveOp::create( - state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - }, - [&](Value flatIndex) -> FailureOr { - return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); - }, - loc); - if (failed(updated)) - return failure(); - - for (Operation* receive : group.redundantReceives) - if (receive && receive->use_empty()) - receive->erase(); - - return *updated; - } - case WholeBatchFragmentSourceKind::DeferredLocalCompute: { - SmallVector resultIndices {group.resultIndex}; - return emitIndexedFragmentInsertLoop( - state, - targetClass, - destination, - static_cast(group.outputOffsets.size()), - [&](Value flatIndex) -> FailureOr { - Value sourceLane = createIndexedIndexValue(state, targetClass.op, group.sourceLanes, flatIndex, loc); - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - ComputeInstance {group.sourceOp, 0, 1}, - sourceLane, - resultIndices, - CloneIndexingContext {.runSlotIndex = flatIndex, .projectionSlotIndex = flatIndex}); - if (failed(produced) || produced->size() != 1) - return failure(); - return produced->front(); - }, - [&](Value flatIndex) -> FailureOr { - return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); - }, - loc); - } - case WholeBatchFragmentSourceKind::PackedValue: - return emitIndexedFragmentInsertLoop( - state, - targetClass, - destination, - static_cast(group.slotIndices.size()), - [&](Value flatIndex) -> FailureOr { - Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc); - FailureOr packed = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - group.packed, - targetClass.op, - "whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); - if (failed(packed)) - return failure(); - return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc); - }, - [&](Value flatIndex) -> FailureOr { - return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); - }, - loc); - case WholeBatchFragmentSourceKind::DirectValue: - for (const auto& [fragment, offset] : group.directFragments) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - FailureOr localFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fragment, - targetClass.op, - "whole-batch direct fragment assembly tried to reuse a tensor from another materialized class"); - if (failed(localFragment)) - return failure(); - FailureOr updated = createDim0InsertSliceInClass(state, - targetClass, - loc, - *localFragment, - destination, - getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset)); - if (failed(updated)) - return failure(); - destination = *updated; - } - return destination; - } - - return failure(); -} - -FailureOr emitProjectedWholeBatchFragmentInsertLoop( - MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const ProjectedWholeBatchFragmentGroup& group, - llvm::function_ref(Value)> buildFragment, - Location loc) { - assert(group.fragmentType && "expected projected fragment type"); - assert(!group.offsetsByDim.empty() && "expected projected insert coordinates"); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, group.offsetsByDim.front().size()); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {destination}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - FailureOr fragment = buildFragment(flatIndex); - if (failed(fragment)) - return failure(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - unsigned rank = group.offsetsByDim.size(); - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (unsigned dim = 0; dim < rank; ++dim) { - offsets.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.offsetsByDim[dim], flatIndex, loc)); - sizes.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.sizesByDim[dim], flatIndex, loc)); - strides.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.stridesByDim[dim], flatIndex, loc)); - } - - Value updated = - tensor::InsertSliceOp::create(state.rewriter, loc, *fragment, iterArgs.front(), offsets, sizes, strides) - .getResult(); - yielded.push_back(updated); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -std::optional getStaticProjectedPackedFragmentIndex(tensor::ExtractSliceOp extract) { - auto sourceType = dyn_cast(extract.getSource().getType()); - auto resultType = dyn_cast(extract.getResult().getType()); - if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() - || sourceType.getRank() == 0 || sourceType.getRank() != resultType.getRank()) - return std::nullopt; - - std::optional firstOffset = getConstantIndex(extract.getMixedOffsets().front()); - if (!firstOffset) - return std::nullopt; - - for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { - std::optional offset = getConstantIndex(extract.getMixedOffsets()[dim]); - std::optional size = getConstantIndex(extract.getMixedSizes()[dim]); - std::optional stride = getConstantIndex(extract.getMixedStrides()[dim]); - if (!offset || !size || !stride || *stride != 1 || *size != resultType.getDimSize(dim)) - return std::nullopt; - if (dim != 0 && *offset != 0) - return std::nullopt; - } - - return *firstOffset; -} - -void appendProjectedInsertCoordinates(ProjectedWholeBatchFragmentGroup& group, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - if (group.offsetsByDim.empty()) { - size_t rank = offsets.size(); - group.offsetsByDim.resize(rank); - group.sizesByDim.resize(rank); - group.stridesByDim.resize(rank); - } - - for (size_t dim = 0; dim < offsets.size(); ++dim) { - group.offsetsByDim[dim].push_back(offsets[dim]); - group.sizesByDim[dim].push_back(sizes[dim]); - group.stridesByDim[dim].push_back(strides[dim]); - } -} - -FailureOr buildWholeBatchAssemblyPlan(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - Type resultType) { - auto batch = dyn_cast_or_null(key.instance.op); - auto resultTensorType = dyn_cast(resultType); - if (!batch || !resultTensorType || !resultTensorType.hasStaticShape() || resultTensorType.getRank() == 0) - return failure(); - - uint32_t batchLaneCount = static_cast(batch.getLaneCount()); - if (batchLaneCount == 0 || resultTensorType.getDimSize(0) % static_cast(batchLaneCount) != 0) - return failure(); - - WholeBatchAssemblyPlan plan; - plan.resultType = resultTensorType; - plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast(batchLaneCount); - plan.batchLaneCount = batchLaneCount; - plan.coveredLanes.assign(batchLaneCount, 0); - - if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan))) - return failure(); - - if (plan.coveredLaneCount == plan.batchLaneCount) - return plan; - - if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan))) - return failure(); - - return plan; -} - -FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - WholeBatchAssemblyPlan& plan, - Location loc) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value result = - tensor::EmptyOp::create(state.rewriter, loc, plan.resultType.getShape(), plan.resultType.getElementType()) - .getResult(); - - SmallVector groups; - if (failed(collectWholeBatchFragmentGroups(state, targetClass, plan, groups))) - return failure(); - - for (const WholeBatchFragmentGroup& group : groups) { - FailureOr updated = emitWholeBatchFragmentGroup(state, targetClass, result, group, loc); - if (failed(updated)) - return failure(); - result = *updated; - } - - state.availableValues.record(key, targetClass.id, result); - return result; -} - -// ----------------------------------------------------------------------------- -// Run materialization helpers. -// ----------------------------------------------------------------------------- - -FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - Type resultType, - Location loc) { - auto batch = dyn_cast_or_null(key.instance.op); - auto resultTensorType = dyn_cast(resultType); - if (!batch || !resultTensorType || !resultTensorType.hasStaticShape()) - return failure(); - - FailureOr projection = getBatchResultProjectionInsert(batch, key.resultIndex); - if (failed(projection)) - return failure(); - - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("missing compute_batch lane argument while materializing projected whole-batch input"); - - uint32_t laneEnd = key.instance.laneStart + key.instance.laneCount; - if (laneEnd > static_cast(batch.getLaneCount())) - return failure(); - - if (targetClass.isBatch) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value result = - tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) - .getResult(); - - for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { - ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); - std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); - if (!fragment) - return failure(); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); - if (failed(offsets) || failed(sizes) || failed(strides)) - return failure(); - - SmallVector offsetAttrs; - SmallVector sizeAttrs; - SmallVector strideAttrs; - offsetAttrs.reserve(offsets->size()); - sizeAttrs.reserve(sizes->size()); - strideAttrs.reserve(strides->size()); - for (auto [offset, size, stride] : llvm::zip(*offsets, *sizes, *strides)) { - offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); - sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); - strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); - } - - FailureOr localFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - *fragment, - targetClass.op, - "projected whole-batch assembly tried to reuse a tensor from another materialized class", - laneKey); - if (failed(localFragment)) - return failure(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - result = tensor::InsertSliceOp::create( - state.rewriter, loc, *localFragment, result, offsetAttrs, sizeAttrs, strideAttrs) - .getResult(); - } - - state.availableValues.record(key, targetClass.id, result); - return result; - } - - SmallVector groups; - auto getOrCreateReceiveGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { - return group.kind == ProjectedWholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - ProjectedWholeBatchFragmentGroup group; - group.kind = ProjectedWholeBatchFragmentSourceKind::DeferredReceive; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - auto getOrCreatePackedGroup = [&](Value packed, - RankedTensorType packedSourceType, - RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { - return group.kind == ProjectedWholeBatchFragmentSourceKind::PackedValue && group.fragmentType == fragmentType - && group.packed == packed && group.packedSourceType == packedSourceType; - }); - if (groupIt == groups.end()) { - ProjectedWholeBatchFragmentGroup group; - group.kind = ProjectedWholeBatchFragmentSourceKind::PackedValue; - group.fragmentType = fragmentType; - group.packed = packed; - group.packedSourceType = packedSourceType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - auto getOrCreateDirectGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { - auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { - return group.kind == ProjectedWholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; - }); - if (groupIt == groups.end()) { - ProjectedWholeBatchFragmentGroup group; - group.kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; - group.fragmentType = fragmentType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } - return *groupIt; - }; - - for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { - ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); - if (failed(offsets) || failed(sizes) || failed(strides)) - return failure(); - - bool grouped = false; - if (std::optional exact = state.availableValues.lookupExact(laneKey, targetClass.id)) { - if (auto receive = exact->getDefiningOp()) { - auto fragmentType = dyn_cast(receive.getOutput().getType()); - if (fragmentType && receive.getOutput().use_empty()) { - ProjectedWholeBatchFragmentGroup& group = getOrCreateReceiveGroup(fragmentType); - if (appendConstantChannelReceiveMessage(group.messages, receive)) { - appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); - group.redundantOps.push_back(receive.getOperation()); - grouped = true; - } - } - } - } - - if (grouped) - continue; - - std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); - if (!fragment) - return failure(); - - auto fragmentType = dyn_cast(fragment->getType()); - if (!fragmentType) - return failure(); - - if (auto extract = fragment->getDefiningOp()) { - if (std::optional packedIndex = getStaticProjectedPackedFragmentIndex(extract)) { - auto packedSourceType = dyn_cast(extract.getSource().getType()); - if (packedSourceType) { - ProjectedWholeBatchFragmentGroup& group = - getOrCreatePackedGroup(extract.getSource(), packedSourceType, fragmentType); - group.packedIndices.push_back(*packedIndex); - appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); - group.redundantOps.push_back(extract.getOperation()); - continue; - } - } - } - - ProjectedWholeBatchFragmentGroup& group = getOrCreateDirectGroup(fragmentType); - group.directFragments.push_back({*fragment, *offsets, *sizes, *strides}); - } - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value result = - tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) - .getResult(); - - for (const ProjectedWholeBatchFragmentGroup& group : groups) { - FailureOr updated = failure(); - switch (group.kind) { - case ProjectedWholeBatchFragmentSourceKind::DeferredReceive: - updated = emitProjectedWholeBatchFragmentInsertLoop( - state, - targetClass, - result, - group, - [&](Value flatIndex) -> FailureOr { - Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); - return SpatChannelReceiveOp::create( - state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - }, - loc); - break; - case ProjectedWholeBatchFragmentSourceKind::PackedValue: - updated = emitProjectedWholeBatchFragmentInsertLoop( - state, - targetClass, - result, - group, - [&](Value flatIndex) -> FailureOr { - SmallVector extractOffsets; - SmallVector extractSizes; - SmallVector extractStrides; - extractOffsets.reserve(group.packedSourceType.getRank()); - extractSizes.reserve(group.packedSourceType.getRank()); - extractStrides.reserve(group.packedSourceType.getRank()); - extractOffsets.push_back(createIndexedOrStaticIndex( - state, targetClass.op, group.packedIndices, flatIndex, loc)); - extractSizes.push_back(state.rewriter.getIndexAttr(1)); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < group.packedSourceType.getRank(); ++dim) { - extractOffsets.push_back(state.rewriter.getIndexAttr(0)); - extractSizes.push_back(state.rewriter.getIndexAttr(group.packedSourceType.getDimSize(dim))); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - } - - FailureOr packed = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - group.packed, - targetClass.op, - "projected whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); - if (failed(packed)) - return failure(); - - return tensor::ExtractSliceOp::create( - state.rewriter, - loc, - group.fragmentType, - *packed, - extractOffsets, - extractSizes, - extractStrides) - .getResult(); - }, - loc); - break; - case ProjectedWholeBatchFragmentSourceKind::DirectValue: { - updated = result; - for (const ProjectedWholeBatchDirectFragment& fragment : group.directFragments) { - FailureOr localFragment = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - fragment.fragment, - targetClass.op, - "projected whole-batch assembly tried to reuse a tensor from another materialized class"); - if (failed(localFragment)) - return failure(); - - SmallVector offsetAttrs; - SmallVector sizeAttrs; - SmallVector strideAttrs; - for (auto [offset, size, stride] : llvm::zip(fragment.offsets, fragment.sizes, fragment.strides)) { - offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); - sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); - strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); - } - updated = tensor::InsertSliceOp::create( - state.rewriter, loc, *localFragment, *updated, offsetAttrs, sizeAttrs, strideAttrs) - .getResult(); - } - break; - } - } - if (failed(updated)) - return failure(); - result = *updated; - } - - for (const ProjectedWholeBatchFragmentGroup& group : groups) - for (Operation* redundantOp : group.redundantOps) - if (redundantOp && redundantOp->use_empty()) - redundantOp->erase(); - - state.availableValues.record(key, targetClass.id, result); - return result; -} - -FailureOr materializeWholeBatchInput( - MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { - FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType); - if (succeeded(plan)) - return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc); - - return materializeProjectedWholeBatchInputFromFragments(state, targetClass, key, resultType, loc); -} - -FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerState& state, - MaterializedClass& sourceClass, - SpatComputeBatch sourceBatch, - size_t resultIndex, - ArrayRef run, - Value packed, - RankedTensorType fragmentType, - Value originalOutput, - Location loc) { - if (!hasLiveExternalUseCached(state, originalOutput)) - return false; - if (packed.getType() == originalOutput.getType() || fragmentType == originalOutput.getType()) - return false; - - auto resultType = dyn_cast(originalOutput.getType()); - if (!resultType || !resultType.hasStaticShape()) - return false; - - FailureOr projection = getBatchResultProjectionInsert(sourceBatch, resultIndex); - if (failed(projection)) - return false; - - std::optional laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) { - sourceBatch.emitOpError("missing compute_batch lane argument while recording projected host fragments"); - return failure(); - } - - for (auto [runIndex, slot] : llvm::enumerate(run)) { - if (slot.peers.size() != 1) { - sourceClass.op->emitError("projected scalar host output publication expects scalar one-peer run slots"); - return failure(); - } - - const ComputeInstance& peer = slot.peers.front(); - if (peer.op != sourceBatch.getOperation()) { - sourceClass.op->emitError("projected scalar host output run changed source operation"); - return failure(); - } - if (peer.laneCount != 1) { - sourceClass.op->emitError("projected scalar host output publication expects one logical lane per packed slot") - << " laneStart=" << peer.laneStart << " laneCount=" << peer.laneCount; - return failure(); - } - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, peer.laneStart); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, peer.laneStart); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, peer.laneStart); - if (failed(offsets) || failed(sizes) || failed(strides)) { - sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") - << peer.laneStart; - return failure(); - } - - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value fragment = getPackedSliceForRunIndex(state, sourceClass.op, packed, fragmentType, runIndex, loc); - - state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { - originalOutput, - sourceClass.id, - fragment, - fragmentType, - SmallVector(*offsets), - SmallVector(*sizes), - SmallVector(*strides), - peer.laneStart, - loc}); - } - - return true; -} - -std::optional getOriginalOutputResultIndex(Value originalOutput) { - auto result = dyn_cast(originalOutput); - if (!result) - return std::nullopt; - return static_cast(result.getResultNumber()); -} - -FailureOr recordProjectedBatchHostFragmentsFromBatchValue(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& ownerClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { - if (!sourceClass.isBatch) - return false; - if (ownerClass.isBatch) - return false; - if (!hasLiveExternalUseCached(state, originalOutput)) - return false; - if (payload.getType() == originalOutput.getType()) - return false; - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - if (!sourceBatch || sourceBatch.getNumResults() == 0) - return false; - - auto resultType = dyn_cast(originalOutput.getType()); - auto fragmentType = dyn_cast(payload.getType()); - if (!resultType || !resultType.hasStaticShape() || !fragmentType || !fragmentType.hasStaticShape()) - return false; - - std::optional resultIndex = getOriginalOutputResultIndex(originalOutput); - if (!resultIndex) - return false; - - FailureOr projection = getBatchResultProjectionInsert(sourceBatch, *resultIndex); - if (failed(projection)) - return false; - - std::optional laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) { - sourceBatch.emitOpError("missing compute_batch lane argument while recording projected batch host fragments"); - return failure(); - } - - if (keys.size() != sourceClass.cpus.size()) { - sourceClass.op->emitError("projected batch host publication expects one producer key per materialized batch lane") - << " keyCount=" << keys.size() << " laneCount=" << sourceClass.cpus.size(); - return failure(); - } - - MessageVector messages; - messages.channelIds.reserve(sourceClass.cpus.size()); - messages.sourceCoreIds.reserve(sourceClass.cpus.size()); - messages.targetCoreIds.reserve(sourceClass.cpus.size()); - - auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected batch host output target core id"); - if (failed(checkedTargetCpu)) - return failure(); - - for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "projected batch host output source core id"); - if (failed(checkedSourceCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - (void) lane; - } - - if (failed(appendSend(state, sourceClass, payload, messages, loc))) - return failure(); - - for (auto [lane, key] : llvm::enumerate(keys)) { - if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != *resultIndex || key.instance.laneCount != 1) { - sourceClass.op->emitError("projected batch host publication received an incompatible producer key") - << " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount - << " resultIndex=" << key.resultIndex; - return failure(); - } - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); - if (failed(offsets) || failed(sizes) || failed(strides)) { - sourceClass.op->emitError("failed to evaluate projected batch host output slice") - << " laneStart=" << key.instance.laneStart; - return failure(); - } - - state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { - originalOutput, - sourceClass.id, - payload, - fragmentType, - SmallVector(*offsets), - SmallVector(*sizes), - SmallVector(*strides), - key.instance.laneStart, - loc, - true, - messages.slice(lane, 1)}); - } - - return true; -} - -LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { - if (state.pendingProjectedHostOutputFragments.empty()) - return success(); - - DenseMap> byOutput; - for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) - byOutput[fragment.originalOutput].push_back(&fragment); - - SmallVector outputs; - outputs.reserve(byOutput.size()); - for (const auto& entry : byOutput) - outputs.push_back(entry.first); - llvm::sort(outputs, [](Value lhs, Value rhs) { - return reinterpret_cast(lhs.getAsOpaquePointer()) - < reinterpret_cast(rhs.getAsOpaquePointer()); - }); - - for (Value originalOutput : outputs) { - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) { - Operation* anchor = originalOutput.getDefiningOp() ? originalOutput.getDefiningOp() : state.func.getOperation(); - return anchor->emitError("missing host owner for projected host output fragments"); - } - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (ownerClass.isBatch) - return ownerClass.op->emitError("projected scalar host output finalization expected a scalar host owner"); - - auto resultType = dyn_cast(originalOutput.getType()); - if (!resultType || !resultType.hasStaticShape()) - return ownerClass.op->emitError("projected host output must have static ranked tensor type"); - - SmallVector& fragments = byOutput[originalOutput]; - llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, - const PendingProjectedHostOutputFragment* rhs) { - if (lhs->sourceLane != rhs->sourceLane) - return lhs->sourceLane < rhs->sourceLane; - if (lhs->sourceClass != rhs->sourceClass) - return lhs->sourceClass < rhs->sourceClass; - return std::lexicographical_compare(lhs->offsets.begin(), - lhs->offsets.end(), - rhs->offsets.begin(), - rhs->offsets.end()); - }); - - state.rewriter.setInsertionPoint(ownerClass.body->getTerminator()); - Location loc = fragments.front()->loc; - Value assembled = tensor::EmptyOp::create( - state.rewriter, loc, resultType.getShape(), resultType.getElementType()) - .getResult(); - - for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { - Value fragment = fragmentRecord->fragment; - MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; - - if (fragmentRecord->sourceClass != ownerClass.id) { - MessageVector messages; - if (fragmentRecord->sendAlreadyEmitted) { - messages = fragmentRecord->messages; - } else { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, - sourceClass.cpus.front(), - "projected host output source core id"); - auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, - ownerClass.cpus.front(), - "projected host output target core id"); - if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - if (failed(appendSend(state, sourceClass, fragment, messages, fragmentRecord->loc))) - return failure(); - } - fragment = appendReceive(state, ownerClass, fragmentRecord->fragmentType, messages, fragmentRecord->loc); - } else { - FailureOr localFragment = materializeTensorValueForMaterializedClassUse( - state, - ownerClass, - fragment, - ownerClass.op, - "projected host output assembly tried to reuse a non-local fragment tensor"); - if (failed(localFragment)) - return failure(); - fragment = *localFragment; - } - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(fragmentRecord->offsets.size()); - sizes.reserve(fragmentRecord->sizes.size()); - strides.reserve(fragmentRecord->strides.size()); - for (auto [offset, size, stride] : llvm::zip(fragmentRecord->offsets, - fragmentRecord->sizes, - fragmentRecord->strides)) { - offsets.push_back(state.rewriter.getIndexAttr(offset)); - sizes.push_back(state.rewriter.getIndexAttr(size)); - strides.push_back(state.rewriter.getIndexAttr(stride)); - } - - state.rewriter.setInsertionPoint(ownerClass.body->getTerminator()); - assembled = tensor::InsertSliceOp::create( - state.rewriter, fragmentRecord->loc, fragment, assembled, offsets, sizes, strides) - .getResult(); - } - - if (failed(setHostOutputValue(state, ownerClass, originalOutput, assembled))) - return failure(); - } - - return success(); -} - -FailureOr resolveInputValue(MaterializerState& state, - MaterializedClass& targetClass, - Value input, - const ComputeInstance& consumerInstance, - CloneIndexingContext indexing) { - auto rejectNonLocalResolvedValue = [&](Value resolved) -> FailureOr { - if (!isTensorValueDefinedInDifferentMaterializedClass(resolved, targetClass)) - return resolved; - - std::optional producer = getInputRequestProducerKey(input, consumerInstance); - emitNonLocalMaterializedClassValueDiagnostic(consumerInstance.op, - targetClass, - "input resolution tried to reuse a tensor from another materialized class", - resolved, - producer); - return failure(); - }; - - if (isConstantLike(input)) - return input; - - if (std::optional producer = getInputRequestProducerKey(input, consumerInstance)) { - if (indexing.runSlotIndex) { - if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { - FailureOr received = materializeIndexedBatchRunReceive( - state, targetClass, *indexedRun, *indexing.runSlotIndex, consumerInstance.op->getLoc()); - if (failed(received)) - return failure(); - return rejectNonLocalResolvedValue(*received); - } - } - - if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) - return rejectNonLocalResolvedValue(*value); - - - if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { - size_t laneCount = targetClass.cpus.size(); - for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) { - if (!llvm::is_contained(slot.keys, *producer)) - continue; - - MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); - Value received = - appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); - for (ProducerKey slotKey : slot.keys) - state.availableValues.record(slotKey, targetClass.id, received); - return rejectNonLocalResolvedValue(received); - } - } - - if (isWholeBatchProducerKey(*producer)) { - FailureOr wholeBatch = - materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); - if (failed(wholeBatch)) - consumerInstance.op->emitError("failed to materialize whole-batch input") - << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart - << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; - if (failed(wholeBatch)) - return failure(); - return rejectNonLocalResolvedValue(*wholeBatch); - } - - consumerInstance.op->emitError("failed to resolve producer value") - << " from op '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart - << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; - return failure(); - } - - if (isTensorValueDefinedInDifferentMaterializedClass(input, targetClass)) { - emitNonLocalMaterializedClassValueDiagnostic( - consumerInstance.op, - targetClass, - "input resolution tried to append a tensor from another materialized class as a normal input", - input); - return failure(); - } - - return appendInput(state, targetClass, input); -} - -bool hasProjectedInputReplacement(MaterializerState& state, - SpatComputeBatch batch, - unsigned inputIndex, - ClassId classId) { - std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex); - if (!match) - return false; - - auto replacementIt = state.projectedExtractReplacements.find(match->extract.getOperation()); - if (replacementIt == state.projectedExtractReplacements.end()) - return false; - - return replacementIt->second.find(classId) != replacementIt->second.end(); -} - -void mapWeights(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper) { - Operation* op = instance.op; - if (auto compute = dyn_cast(op)) { - for (auto [index, weight] : llvm::enumerate(compute.getWeights())) { - auto weightArg = compute.getWeightArgument(index); - assert(weightArg && "expected compute weight block argument"); - mapper.map(*weightArg, appendWeight(state, targetClass, weight)); - } - return; - } - - auto batch = cast(op); - for (auto [index, weight] : llvm::enumerate(batch.getWeights())) { - auto weightArg = batch.getWeightArgument(index); - assert(weightArg && "expected compute_batch weight block argument"); - mapper.map(*weightArg, appendWeight(state, targetClass, weight)); - } -} - -LogicalResult mapInputs(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper, - CloneIndexingContext indexing) { - auto mapResolvedInput = [&](Value resolved) -> FailureOr { - return materializeTensorValueForMaterializedClassUse( - state, - targetClass, - resolved, - targetClass.op, - "input mapping tried to reuse a tensor from another materialized class"); - }; - - Operation* op = instance.op; - if (auto compute = dyn_cast(op)) { - for (auto [index, input] : llvm::enumerate(compute.getInputs())) { - FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); - if (failed(mapped)) { - std::optional producer = getInputRequestProducerKey(input, instance); - auto diagnostic = compute.emitOpError("failed to resolve materialized compute input") << " #" << index; - if (producer) { - diagnostic << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart - << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; - } - return failure(); - } - auto inputArg = compute.getInputArgument(index); - if (!inputArg) - return compute.emitOpError("expected compute input block argument while materializing inputs"); - FailureOr remapped = mapResolvedInput(*mapped); - if (failed(remapped)) { - emitNonLocalMaterializedClassValueDiagnostic(compute, - targetClass, - "mapInputs tried to append a tensor from another materialized class", - *mapped, - getInputRequestProducerKey(input, instance)); - return failure(); - } - mapper.map(*inputArg, *remapped); - } - return success(); - } - - auto batch = cast(op); - for (auto [index, input] : llvm::enumerate(batch.getInputs())) { - if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) - continue; - - FailureOr mapped = failure(); - if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input); - wholeBatchProducer && !canUseProjectedLaneInput(state, batch, static_cast(index), input, instance)) { - mapped = materializeWholeBatchInput( - state, targetClass, *wholeBatchProducer, input.getType(), batch.getOperation()->getLoc()); - if (failed(mapped)) - return batch.emitOpError("failed to materialize whole-batch compute_batch input") - << " #" << index << " from '" << wholeBatchProducer->instance.op->getName() - << "' laneStart=" << wholeBatchProducer->instance.laneStart - << " laneCount=" << wholeBatchProducer->instance.laneCount - << " resultIndex=" << wholeBatchProducer->resultIndex; - } else { - mapped = resolveInputValue(state, targetClass, input, instance, indexing); - if (failed(mapped)) - return batch.emitOpError("failed to resolve materialized compute_batch input"); - } - - auto inputArg = batch.getInputArgument(index); - if (!inputArg) - return batch.emitOpError("expected compute_batch input block argument while materializing inputs"); - FailureOr remapped = mapResolvedInput(*mapped); - if (failed(remapped)) { - emitNonLocalMaterializedClassValueDiagnostic(batch, - targetClass, - "mapInputs tried to append a tensor from another materialized class", - *mapped, - getInputRequestProducerKey(input, instance)); - return failure(); - } - mapper.map(*inputArg, *remapped); - } - return success(); -} - -SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMapping& mapper) { - SmallVector outputs(batch.getNumResults(), Value {}); - auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); - if (!inParallel) - return outputs; - - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert) - continue; - - auto outputArg = dyn_cast(insert.getDest()); - if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) - continue; - - auto firstOutputArg = batch.getOutputArgument(0); - if (!firstOutputArg) - return outputs; - unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); - if (resultIndex >= outputs.size()) - continue; - outputs[resultIndex] = mapper.lookupOrDefault(insert.getSource()); - } - - return outputs; -} - -SmallVector collectBatchOutputFragmentTypes(SpatComputeBatch batch) { - SmallVector types(batch.getNumResults(), Type {}); - auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); - if (!inParallel) - return types; - - auto firstOutputArg = batch.getOutputArgument(0); - if (!firstOutputArg) - return types; - - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert) - continue; - - auto outputArg = dyn_cast(insert.getDest()); - if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) - continue; - - unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); - if (resultIndex >= types.size()) - continue; - - types[resultIndex] = insert.getSource().getType(); - } - - return types; -} - -SmallVector& getBatchOutputFragmentTypesCached(MaterializerState& state, SpatComputeBatch batch) { - auto [it, inserted] = state.batchOutputFragmentTypesCache.try_emplace(batch.getOperation(), SmallVector {}); - if (inserted) - it->second = collectBatchOutputFragmentTypes(batch); - return it->second; -} - -ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance) { - auto [it, inserted] = state.computeInstanceOutputsCache.try_emplace(instance, SmallVector {}); - if (inserted) - it->second = getComputeInstanceOutputValues(instance); - return it->second; -} - -std::optional lookupProjectedExtractReplacement(MaterializerState& state, - MaterializedClass& targetClass, - tensor::ExtractSliceOp extract) { - auto replacementIt = state.projectedExtractReplacements.find(extract.getOperation()); - if (replacementIt == state.projectedExtractReplacements.end()) - return std::nullopt; - - auto classIt = replacementIt->second.find(targetClass.id); - if (classIt == replacementIt->second.end()) - return std::nullopt; - - return classIt->second; -} - -LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state, - MaterializedClass& targetClass, - Operation& originalOp, - Operation& clonedOp, - CloneIndexingContext indexing, - IRMapping& mapper) { - if (auto originalExtract = dyn_cast(&originalOp)) { - if (std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, originalExtract)) { - auto clonedExtract = dyn_cast(&clonedOp); - if (!clonedExtract) - return targetClass.op->emitError("projected replacement lost extract structure during cloning"); - - state.rewriter.setInsertionPoint(clonedExtract); - FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, clonedExtract, *replacement, indexing.projectionSlotIndex, &mapper); - if (failed(projected)) - return failure(); - - clonedExtract.getResult().replaceAllUsesWith(*projected); - state.rewriter.eraseOp(clonedExtract); - return success(); - } - } - - if (originalOp.getNumRegions() != clonedOp.getNumRegions()) - return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned regions"); - - for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { - if (std::distance(originalRegion.begin(), originalRegion.end()) - != std::distance(clonedRegion.begin(), clonedRegion.end())) - return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned blocks"); - - for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { - auto originalIt = originalBlock.begin(); - auto clonedIt = clonedBlock.begin(); - while (originalIt != originalBlock.end() && clonedIt != clonedBlock.end()) { - Operation& originalNestedOp = *originalIt++; - Operation* currentClonedOp = &*clonedIt++; - if (failed(applyProjectedExtractReplacementsInClonedOp( - state, targetClass, originalNestedOp, *currentClonedOp, indexing, mapper))) - return failure(); - } - if (originalIt != originalBlock.end() || clonedIt != clonedBlock.end()) - return targetClass.op->emitError("projected replacement traversal found mismatched cloned operations"); - } - } - - return success(); -} - -LogicalResult mapClonedRegionBlockArguments(Operation& originalOp, Operation& clonedOp, IRMapping& mapper) { - if (originalOp.getNumRegions() != clonedOp.getNumRegions()) - return clonedOp.emitError("cloned operation has a different number of regions than the source operation"); - - for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { - if (std::distance(originalRegion.begin(), originalRegion.end()) - != std::distance(clonedRegion.begin(), clonedRegion.end())) - return clonedOp.emitError("cloned operation has a different number of blocks than the source operation"); - - for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { - if (originalBlock.getNumArguments() != clonedBlock.getNumArguments()) - return clonedOp.emitError("cloned operation block has a different number of arguments than the source block"); - - for (auto [originalArg, clonedArg] : llvm::zip(originalBlock.getArguments(), clonedBlock.getArguments())) - if (!mapper.contains(originalArg)) - mapper.map(originalArg, clonedArg); - - if (std::distance(originalBlock.begin(), originalBlock.end()) != std::distance(clonedBlock.begin(), clonedBlock.end())) - return clonedOp.emitError("cloned operation block has a different number of operations than the source block"); - - auto originalIt = originalBlock.begin(); - auto clonedIt = clonedBlock.begin(); - while (originalIt != originalBlock.end()) { - Operation& originalNestedOp = *originalIt++; - Operation& clonedNestedOp = *clonedIt++; - if (failed(mapClonedRegionBlockArguments(originalNestedOp, clonedNestedOp, mapper))) - return failure(); - } - } - } - - return success(); -} - -LogicalResult cloneComputeTemplateBody(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper, - CloneIndexingContext indexing) { - Block& sourceBlock = getComputeInstanceTemplateBlock(instance); - for (Operation& op : sourceBlock.without_terminator()) { - if (auto extract = dyn_cast(&op)) { - if (std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, extract)) { - FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, extract, *replacement, indexing.projectionSlotIndex, &mapper); - if (failed(projected)) - return failure(); - - mapper.map(extract.getResult(), *projected); - continue; - } - } - - for (Value operand : op.getOperands()) { - if (mapper.contains(operand)) - continue; - - FailureOr localized = localizeMaterializedClassOperand( - state, - targetClass, - operand, - &op, - "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", - "cloneComputeTemplateBody produced an unsupported external non-tensor operand", - &mapper); - if (failed(localized)) - return failure(); - if (*localized != operand) - mapper.map(operand, *localized); - } - - Operation* cloned = state.rewriter.clone(op, mapper); - if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper))) - return failure(); - if (failed(localizeCapturesInClonedOp(state, targetClass, *cloned, &mapper))) - return failure(); - if (op.getNumRegions() != 0 - && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing, mapper))) - return failure(); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(oldResult, newResult); - } - - return success(); -} - -FailureOr materializeProjectedExtractReplacement(MaterializerState& state, - MaterializedClass& targetClass, - tensor::ExtractSliceOp extract, - const ProjectedExtractReplacement& replacement, - std::optional projectionSlotIndex, - IRMapping* mapper) { - if (failed(verifyProjectedFragmentLayout(targetClass.op, replacement.layout))) - return failure(); - - FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( - state, - targetClass, - replacement.payload, - targetClass.op, - "projected extract replacement tried to reuse a tensor from another materialized class", - std::nullopt, - mapper); - if (failed(localizedPayload)) - return failure(); - Value payload = *localizedPayload; - - if (replacement.layout.payloadFragmentCount == 1) - return payload; - - if (replacement.layout.payloadFragmentCount < replacement.layout.fragmentsPerLogicalSlot) - return targetClass.op->emitError("projected replacement payload is smaller than one logical slot"); - - Value intraSlotFragmentIndex = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - const auto linearizeProjectedLoopIndices = [&]() -> FailureOr { - if (replacement.layout.loopTripCounts.empty()) - return intraSlotFragmentIndex; - - SmallVector surroundingLoops; - for (Operation* current = extract->getParentOp(); current; current = current->getParentOp()) { - if (auto loop = dyn_cast(current)) - surroundingLoops.push_back(loop); - if (current == targetClass.op) - break; - } - std::reverse(surroundingLoops.begin(), surroundingLoops.end()); - - if (surroundingLoops.size() != replacement.layout.loopTripCounts.size()) - return targetClass.op->emitError("projected replacement loop structure does not match the collected descriptor"); - - Value linearizedIndex = intraSlotFragmentIndex; - for (auto [index, loop] : llvm::enumerate(surroundingLoops)) { - FailureOr localizedIv = - rematerializeIndexValueInClass(state, targetClass, loop.getInductionVar(), extract.getLoc(), mapper); - if (failed(localizedIv)) - return failure(); - Value iv = *localizedIv; - Value lowerBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]); - Value tripCount = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]); - - Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult(); - if (replacement.layout.loopSteps[index] != 1) - normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult(); - linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult(); - linearizedIndex = - arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult(); - } - return linearizedIndex; - }; - - FailureOr linearizedIndex = linearizeProjectedLoopIndices(); - if (failed(linearizedIndex)) - return failure(); - intraSlotFragmentIndex = *linearizedIndex; - - const auto computeProjectedPayloadFragmentIndex = [&]() -> FailureOr { - if (replacement.layout.payloadFragmentCount == replacement.layout.fragmentsPerLogicalSlot) { - if (replacement.layout.loopTripCounts.empty() && replacement.layout.fragmentsPerLogicalSlot != 1) - return targetClass.op->emitError("projected replacement is missing loop metadata for packed logical slot"); - return intraSlotFragmentIndex; - } - - if (!projectionSlotIndex) - return targetClass.op->emitError("packed projected extract replacement requires a fragment slot index"); - - FailureOr localProjectionSlotIndex = - rematerializeIndexValueInClass(state, targetClass, *projectionSlotIndex, extract.getLoc(), mapper); - if (failed(localProjectionSlotIndex)) - return failure(); - - Value fragmentsPerLogicalSlot = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot); - Value base = - arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot) - .getResult(); - return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult(); - }; - - FailureOr packedFragmentIndex = computeProjectedPayloadFragmentIndex(); - if (failed(packedFragmentIndex)) - return failure(); - - FailureOr packedOffset = scaleIndexByDim0SizeInClass( - state, targetClass, *packedFragmentIndex, replacement.layout.fragmentType.getDimSize(0), extract.getLoc()); - if (failed(packedOffset)) - return failure(); - return createDim0ExtractSliceInClass( - state, targetClass, extract.getLoc(), payload, *packedOffset, replacement.layout.fragmentType.getDimSize(0)); -} - -FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, - MaterializedClass& targetClass, - IndexedBatchRunValue& run, - Value runSlotIndex, - Location loc) { - if (!targetClass.isBatch) - return targetClass.op->emitError("indexed batch run receive requires a batch target class"); - if (failed(run.messages.verify(targetClass.op))) - return failure(); - - Value flatIndex = createBatchRunFlatIndex(state, targetClass, runSlotIndex, loc); - std::optional preferredPeriod = static_cast(targetClass.cpus.size()); - Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); -} - -LogicalResult localizeCapturesInOperationTree(MaterializerState& state, - MaterializedClass& targetClass, - Operation& root, - StringRef tensorContext, - StringRef genericContext, - IRMapping* mapper = nullptr) { - WalkResult walkResult = root.walk([&](Operation* nestedOp) -> WalkResult { - for (OpOperand& operand : nestedOp->getOpOperands()) { - Value current = operand.get(); - if (isValueLegalInMaterializedClassBody(current, targetClass)) - continue; - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(nestedOp); - FailureOr localized = - localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper); - if (failed(localized)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand"); - diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() - << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() - << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) - << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) - << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; - diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; - attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR"); - attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands"); - attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain"); - attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); - attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return WalkResult::interrupt(); - } - operand.set(*localized); - } - return WalkResult::advance(); - }); - - return walkResult.wasInterrupted() ? failure() : success(); -} - -LogicalResult localizeCapturesInClonedOp(MaterializerState& state, - MaterializedClass& targetClass, - Operation& clonedOp, - IRMapping* mapper) { - return localizeCapturesInOperationTree( - state, - targetClass, - clonedOp, - "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", - "cloneComputeTemplateBody produced an unsupported external non-tensor operand", - mapper); -} - -LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass) { - SmallVector bodyOps; - for (Operation& op : *targetClass.body) - op.walk([&](Operation* nestedOp) { bodyOps.push_back(nestedOp); }); - - for (Operation* nestedOp : bodyOps) { - if (nestedOp->getBlock() == nullptr) - continue; - for (OpOperand& operand : nestedOp->getOpOperands()) { - Value current = operand.get(); - if (isValueLegalInMaterializedClassBody(current, targetClass)) - continue; - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(nestedOp); - FailureOr localized = localizeMaterializedClassOperand( - state, - targetClass, - current, - nestedOp, - "final scheduled body capture localization tried to reuse a tensor from another materialized class", - "final scheduled body capture localization found an unsupported external non-tensor operand"); - if (failed(localized)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "RAPTOR_MATERIALIZER_DEBUG failed to localize final scheduled-body operand"); - diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() - << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() - << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) - << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) - << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; - diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; - attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); - attachMaterializedClassBodySummary(diagnostic, targetClass); - return failure(); - } - operand.set(*localized); - } - } - - return success(); -} - -FailureOr> cloneInstanceBody(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef peers, - CloneIndexingContext indexing) { - assert(!peers.empty() && "expected at least one peer instance"); - const ComputeInstance& instance = peers.front(); - Operation* sourceOp = instance.op; - Location loc = sourceOp->getLoc(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - IRMapping mapper; - if (auto batch = dyn_cast(sourceOp)) { - for (const ComputeInstance& peer : peers) { - if (peer.op != sourceOp) { - sourceOp->emitError("equivalence class slot contains different source compute_batch operations"); - return failure(); - } - } - auto laneArg = batch.getLaneArgument(); - if (!laneArg) { - sourceOp->emitError("expected source compute_batch lane block argument"); - return failure(); - } - mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc)); - } - - OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); - - mapWeights(state, targetClass, instance, mapper); - if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) - return failure(); - - state.rewriter.restoreInsertionPoint(cloneInsertionPoint); - if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) - return failure(); - - if (auto compute = dyn_cast(sourceOp)) { - Block& sourceBlock = getComputeInstanceTemplateBlock(instance); - auto yield = dyn_cast_or_null(sourceBlock.getTerminator()); - if (!yield) { - compute.emitOpError("expected spat.yield terminator while materializing compute"); - return failure(); - } - - SmallVector outputs; - outputs.reserve(yield.getNumOperands()); - for (Value yielded : yield.getOutputs()) - outputs.push_back(mapper.lookupOrDefault(yielded)); - return outputs; - } - - auto batch = cast(sourceOp); - if (batch.getNumResults() == 0) - return SmallVector {}; - - SmallVector outputs = collectMappedBatchOutputs(batch, mapper); - for (Value output : outputs) - if (!output) { - batch.emitOpError("failed to recover yielded per-lane value for compute_batch result"); - return failure(); - } - return outputs; -} - -bool sameDestinationClasses(ArrayRef lhs, ArrayRef rhs) { - if (lhs.size() != rhs.size()) - return false; - for (auto [lhsClass, rhsClass] : llvm::zip(lhs, rhs)) - if (lhsClass != rhsClass) - return false; - return true; -} - -SmallVector -collectDestinationClassesForRun(MaterializerState& state, ArrayRef run, size_t resultIndex) { - SmallVector destinations; - - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) { - ProducerKey key {peer, resultIndex}; - for (ClassId destinationClass : getDestinationClasses(state, key)) - if (!llvm::is_contained(destinations, destinationClass)) - destinations.push_back(destinationClass); - } - } - - llvm::sort(destinations); - return destinations; -} - -SmallVector groupBatchRunOutputsByDestination(MaterializerState& state, - ArrayRef run) { - assert(!run.empty() && "expected non-empty materialization run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - - SmallVector groups; - ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); - - for (auto [resultIndex, output] : llvm::enumerate(outputs)) { - SmallVector destinations = collectDestinationClassesForRun(state, run, resultIndex); - - auto existingGroup = llvm::find_if(groups, [&](const OutputDestinationGroup& group) { - return sameDestinationClasses(group.destinationClasses, destinations); - }); - - if (existingGroup != groups.end()) { - existingGroup->resultIndices.push_back(resultIndex); - continue; - } - - OutputDestinationGroup group; - group.resultIndices.push_back(resultIndex); - group.destinationClasses = std::move(destinations); - groups.push_back(std::move(group)); - } - - return groups; -} - -FailureOr getPackedRunTensorType(Type elementType, size_t runSize) { - auto tensorType = dyn_cast(elementType); - if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) - return failure(); - - SmallVector shape(tensorType.getShape()); - shape[0] *= static_cast(runSize); - return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); -} - -LogicalResult registerDeferredLocalPackedRunValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef keys, - Type fragmentType, - Location loc) { - if (keys.empty()) - return success(); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return materializedClass.op->emitError("deferred local packed run expects static ranked fragment type"); - - Operation* sourceOp = keys.front().instance.op; - size_t resultIndex = keys.front().resultIndex; - - for (ProducerKey key : keys) { - if (key.instance.op != sourceOp || key.resultIndex != resultIndex) - return materializedClass.op->emitError("deferred local packed run expects one producer result"); - - if (key.instance.laneCount != 1) - return materializedClass.op->emitError("deferred local packed run expects one lane per fragment"); - } - - PackedScalarRunValue packedRun; - packedRun.targetClass = materializedClass.id; - packedRun.sourceOp = sourceOp; - packedRun.resultIndex = resultIndex; - packedRun.kind = PackedScalarRunKind::DeferredLocalCompute; - packedRun.fragmentType = rankedFragmentType; - - packedRun.slots.reserve(keys.size()); - for (ProducerKey key : keys) { - PackedScalarRunSlot slot; - slot.keys.push_back(key); - packedRun.slots.push_back(std::move(slot)); - } - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -LogicalResult registerPackedRunValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef keys, - Value packed, - Type fragmentType, - Location loc) { - if (keys.empty()) - return success(); - - FailureOr expectedPackedType = getPackedRunTensorType(fragmentType, keys.size()); - if (failed(expectedPackedType)) - return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); - - if (packed.getType() != *expectedPackedType) - return materializedClass.op->emitError("packed run value has unexpected tensor type"); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); - - Operation* sourceOp = keys.front().instance.op; - size_t resultIndex = keys.front().resultIndex; - - for (ProducerKey key : keys) { - if (key.instance.op != sourceOp || key.resultIndex != resultIndex) - return materializedClass.op->emitError("packed run registration expects one producer result"); - if (key.instance.laneCount != 1) - return materializedClass.op->emitError("packed run registration expects one lane per packed fragment"); - } - - if (std::optional contiguousKey = getContiguousProducerRangeForKeys(keys)) { - state.availableValues.record(*contiguousKey, materializedClass.id, packed); - return success(); - } - - PackedScalarRunValue packedRun; - packedRun.targetClass = materializedClass.id; - packedRun.sourceOp = sourceOp; - packedRun.resultIndex = resultIndex; - packedRun.packed = packed; - packedRun.kind = PackedScalarRunKind::Materialized; - packedRun.fragmentType = rankedFragmentType; - - packedRun.slots.reserve(keys.size()); - for (ProducerKey key : keys) { - PackedScalarRunSlot slot; - slot.keys.push_back(key); - packedRun.slots.push_back(std::move(slot)); - } - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -LogicalResult emitPackedRunFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef destinationClasses, - ArrayRef keys, - Value packed, - Type fragmentType, - Location loc) { - assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); - - auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, packed); - if (failed(fanoutPlan)) - return failure(); - if (failed(emitScalarSourceFanoutSends(state, sourceClass, packed, *fanoutPlan, loc))) - return failure(); - - for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { - MaterializedClass& targetClass = state.classes[plan.targetClass]; - - Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); - - if (plan.projectedExtractOp) { - state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = - ProjectedExtractReplacement {received, plan.projectedLayout}; - continue; - } - - if (failed(registerPackedRunValue(state, targetClass, keys, received, fragmentType, loc))) - return failure(); - } - - return success(); -} - -FailureOr> cloneBatchBodyForLane(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - Value laneValue, - ArrayRef resultIndices, - CloneIndexingContext indexing) { - auto batch = dyn_cast(instance.op); - if (!batch) - return failure(); - - IRMapping mapper; - auto sourceLaneArg = batch.getLaneArgument(); - if (!sourceLaneArg) - return batch.emitOpError("expected source compute_batch lane block argument"); - - mapper.map(*sourceLaneArg, laneValue); - - OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); - - mapWeights(state, targetClass, instance, mapper); - if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) - return failure(); - - state.rewriter.restoreInsertionPoint(cloneInsertionPoint); - if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) - return failure(); - - SmallVector allOutputs = collectMappedBatchOutputs(batch, mapper); - if (allOutputs.empty() && !resultIndices.empty()) - return batch.emitOpError("failed to recover source compute_batch outputs"); - - SmallVector selectedOutputs; - selectedOutputs.reserve(resultIndices.size()); - for (size_t resultIndex : resultIndices) { - if (resultIndex >= allOutputs.size() || !allOutputs[resultIndex]) - return batch.emitOpError("failed to recover selected compute_batch output"); - selectedOutputs.push_back(allOutputs[resultIndex]); - } - - return selectedOutputs; -} - -FailureOr> materializeBatchOutputGroupLoop(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run, - const OutputDestinationGroup& group) { - assert(!run.empty() && "expected non-empty batch run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - - Operation* sourceOp = run.front().peers.front().op; - Location loc = sourceOp->getLoc(); - - if (run.size() == 1) { - if (run.front().peers.size() != 1) - return sourceOp->emitError("scalar batch output loop expects exactly one peer in singleton slot"); - - const ComputeInstance& item = run.front().peers.front(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart); - return cloneBatchBodyForLane(state, targetClass, item, laneValue, group.resultIndices, {}); - } - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - auto sourceBatch = cast(sourceOp); - SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); - SmallVector initValues; - for (size_t resultIndex : group.resultIndices) { - if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) - return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); - - Type fragmentType = fragmentTypes[resultIndex]; - FailureOr packedType = getPackedRunTensorType(fragmentType, run.size()); - if (failed(packedType)) - return sourceBatch.emitOpError("cannot materialize packed batch run for non-static ranked output"); - - initValues.push_back( - tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); - } - - SmallVector logicalLanes; - logicalLanes.reserve(run.size()); - for (const MaterializationRunSlot& slot : run) { - if (slot.peers.size() != 1) - return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); - - const ComputeInstance& item = slot.peers.front(); - if (item.op != sourceOp) - return sourceOp->emitError("materialization run contains different source operations"); - - logicalLanes.push_back(item.laneStart); - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange(initValues), - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - run.front().peers.front(), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); - if (failed(produced)) - return failure(); - - yielded.reserve(produced->size()); - for (auto [outputIndex, output] : llvm::enumerate(*produced)) { - auto fragmentType = cast(output.getType()); - Value acc = iterArgs[outputIndex]; - Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); - yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); - } - return success(); - }); - if (failed(loop)) - return failure(); - - SmallVector results; - results.reserve(loop->results.size()); - for (Value result : loop->results) - results.push_back(result); - return results; -} - -SmallVector getMaterializationRunSlotOutputKeys(const MaterializationRunSlot& slot, - size_t resultIndex) { - SmallVector keys; - keys.reserve(slot.peers.size()); - for (const ComputeInstance& peer : slot.peers) - keys.push_back({peer, resultIndex}); - return keys; -} - -FailureOr> -getMaterializationRunSlotPeers(MaterializerState& state, MaterializedClass& targetClass, SlotId logicalSlot) { - if (targetClass.isBatch) - return getPeerLogicalInstances(state, targetClass, logicalSlot); - - auto streamIt = state.logicalInstancesByCpu.find(targetClass.cpus.front()); - if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) - return failure(); - - return SmallVector {streamIt->second[logicalSlot]}; -} - -FailureOr collectBatchMaterializationRun(MaterializerState& state, - MaterializedClass& targetClass, - SlotId startSlot, - Operation* sourceOp) { - MaterializationRun run; - - for (SlotId slot = startSlot;; ++slot) { - ClassSlotKey classSlot {targetClass.id, slot}; - if (state.materializedLogicalSlots.contains(classSlot)) - break; - - FailureOr> peers = getMaterializationRunSlotPeers(state, targetClass, slot); - if (failed(peers) || peers->empty()) - break; - - bool validSlot = true; - for (const ComputeInstance& peer : *peers) { - if (peer.op != sourceOp || !isa(peer.op)) { - validSlot = false; - break; - } - } - - if (!validSlot) - break; - - MaterializationRunSlot runSlot; - runSlot.peers = std::move(*peers); - run.push_back(std::move(runSlot)); - } - - if (run.empty()) - return failure(); - - return run; -} - -SmallVector getMaterializationRunOutputKeys(ArrayRef run, size_t resultIndex) { - SmallVector keys; - for (const MaterializationRunSlot& slot : run) - llvm::append_range(keys, getMaterializationRunSlotOutputKeys(slot, resultIndex)); - return keys; -} - -ArrayRef getFirstMaterializationRunOriginalOutputs(MaterializerState& state, - ArrayRef run) { - assert(!run.empty() && "expected non-empty materialization run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - return getComputeInstanceOutputValuesCached(state, run.front().peers.front()); -} - -Operation* getMaterializationRunSourceOp(ArrayRef run) { - assert(!run.empty() && "expected non-empty materialization run"); - assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - return run.front().peers.front().op; -} - -Location getMaterializationRunLoc(ArrayRef run) { - return getMaterializationRunSourceOp(run)->getLoc(); -} - -bool hasMaterializationRunResultLiveExternalUse(MaterializerState& state, - ArrayRef run, - size_t resultIndex) { - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) { - ArrayRef outputs = getComputeInstanceOutputValuesCached(state, peer); - if (resultIndex >= outputs.size()) - return true; - - if (hasLiveExternalUseCached(state, outputs[resultIndex])) - return true; - } - } - - return false; -} - -bool hasMaterializationRunGroupLiveExternalUse(MaterializerState& state, - ArrayRef run, - const OutputDestinationGroup& group) { - for (size_t resultIndex : group.resultIndices) - if (hasMaterializationRunResultLiveExternalUse(state, run, resultIndex)) - return true; - - return false; -} - -bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId); - -bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, - ClassId classId, - ArrayRef run, - const OutputDestinationGroup& group) { - for (size_t resultIndex : group.resultIndices) { - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) - if (hasSameClassConsumer(state, {peer, resultIndex}, classId)) - return true; - } - } - - return false; -} - -void markMaterializationRunSlots(MaterializerState& state, - ClassId classId, - SlotId startSlot, - ArrayRef run) { - for (auto slotIndex : llvm::seq(0, run.size())) - state.materializedLogicalSlots.insert({classId, startSlot + static_cast(slotIndex)}); -} - -LogicalResult materializeScalarBatchRun(MaterializerState& state, - MaterializedClass& targetClass, - SlotId startSlot, - ArrayRef run) { - assert(!targetClass.isBatch && "scalar batch run materialization expects scalar target class"); - assert(!run.empty() && "expected non-empty batch run"); - - markMaterializationRunSlots(state, targetClass.id, startSlot, run); - - SmallVector groups = groupBatchRunOutputsByDestination(state, run); - ArrayRef firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); - - auto sourceBatch = cast(getMaterializationRunSourceOp(run)); - SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); - Location loc = getMaterializationRunLoc(run); - - for (const OutputDestinationGroup& group : groups) { - if (run.size() > 1 && group.destinationClasses.empty() - && !hasMaterializationRunGroupLiveExternalUse(state, run, group) - && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group)) { - for (size_t resultIndex : group.resultIndices) { - if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) - return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); - - SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); - if (failed(registerDeferredLocalPackedRunValue(state, targetClass, keys, fragmentTypes[resultIndex], loc))) - return failure(); - } - - continue; - } - - FailureOr> packedOutputs = materializeBatchOutputGroupLoop(state, targetClass, run, group); - if (failed(packedOutputs)) - return failure(); - - for (auto [groupOutputIndex, resultIndex] : llvm::enumerate(group.resultIndices)) { - Value packed = (*packedOutputs)[groupOutputIndex]; - if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) - return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); - - Type fragmentType = fragmentTypes[resultIndex]; - SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); - - auto rankedFragmentType = cast(fragmentType); - Value representativeOriginalOutput = firstOriginalOutputs[resultIndex]; - FailureOr recordedProjectedHostFragments = recordProjectedScalarHostFragmentsFromPackedRun( - state, targetClass, sourceBatch, resultIndex, run, packed, rankedFragmentType, representativeOriginalOutput, loc); - if (failed(recordedProjectedHostFragments)) - return failure(); - - if (run.size() == 1) { - if (*recordedProjectedHostFragments) { - if (failed(emitScalarSourceCommunication(state, targetClass, keys, packed, loc))) - return failure(); - continue; - } - - if (failed(emitOutputFanout(state, targetClass, keys, packed, representativeOriginalOutput, loc))) - return failure(); - continue; - } - - if (failed(emitPackedRunFanout(state, targetClass, group.destinationClasses, keys, packed, fragmentType, loc))) - return failure(); - - if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) - return failure(); - - if (*recordedProjectedHostFragments) - continue; - - for (auto [runIndex, slot] : llvm::enumerate(run)) { - assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); - - ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, slot.peers.front()); - Value originalOutput = originalOutputs[resultIndex]; - - if (!hasLiveExternalUseCached(state, originalOutput)) - continue; - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value fragment = getPackedSliceForRunIndex(state, targetClass.op, packed, rankedFragmentType, runIndex, loc); - - if (failed(emitHostCommunication(state, targetClass, fragment, originalOutput))) - return failure(); - } - } - } - - return success(); -} - -bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { - SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId}; - auto it = state.sameClassConsumerIndex.find(lookupKey); - if (it == state.sameClassConsumerIndex.end()) - return false; - - for (ProducerKey existing : it->second) - if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing)) - return true; - return false; -} - -bool canCompactBatchClassRun(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run) { - if (run.size() < 2) - return false; - if (run.front().peers.empty()) - return false; - - ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); - - for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { - (void) ignored; - for (const MaterializationRunSlot& slot : run) { - if (slot.peers.empty()) - return false; - - for (const ComputeInstance& peer : slot.peers) { - ArrayRef peerOutputs = getComputeInstanceOutputValuesCached(state, peer); - if (resultIndex >= peerOutputs.size()) - return false; - - Value originalOutput = peerOutputs[resultIndex]; - if (hasLiveExternalUseCached(state, originalOutput)) - return false; - - ProducerKey key {peer, resultIndex}; - if (hasSameClassConsumer(state, key, targetClass.id)) - return false; - } - } - } - - return true; -} - -Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { - auto batch = cast(targetClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected materialized compute_batch lane argument"); - - MLIRContext* context = state.func.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - AffineExpr d1 = getAffineDimExpr(1, context); - - int64_t laneCount = static_cast(targetClass.cpus.size()); - AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); - return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); -} - -Value createBatchClassRunSourceLane(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run, - Value slotIndex, - Location loc) { - SmallVector sourceLanes; - sourceLanes.reserve(run.size() * targetClass.cpus.size()); - - for (auto [runSlotIndex, slot] : llvm::enumerate(run)) { - (void) runSlotIndex; - assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane"); - for (const ComputeInstance& peer : slot.peers) - sourceLanes.push_back(peer.laneStart); - } - - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - return createIndexedIndexValue(state, - targetClass.op, - sourceLanes, - flatIndex, - loc, - static_cast(targetClass.cpus.size()), - /*allowExhaustiveTiledSearch=*/false); -} - -LogicalResult buildBatchRunSendPlans(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef run, - const OutputDestinationGroup& group, - SmallVectorImpl& plans) { - assert(sourceClass.isBatch && "batch run send planning expects a materialized batch source"); - - for (size_t resultIndex : group.resultIndices) { - for (ClassId destinationClass : group.destinationClasses) { - if (destinationClass == sourceClass.id) - return sourceClass.op->emitError("batch-target run compaction cannot handle same-class consumers"); - - MaterializedClass& targetClass = state.classes[destinationClass]; - - if (targetClass.isBatch && targetClass.cpus.size() != sourceClass.cpus.size()) - return sourceClass.op->emitError( - "cannot compact batch run communication between batch classes of different sizes"); - - BatchRunSendPlan plan; - plan.resultIndex = resultIndex; - plan.destinationClass = destinationClass; - - size_t messageCount = run.size() * sourceClass.cpus.size(); - plan.messages.channelIds.reserve(messageCount); - plan.messages.sourceCoreIds.reserve(messageCount); - plan.messages.targetCoreIds.reserve(messageCount); - - for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) { - for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); - if (failed(checkedSourceCpu)) - return failure(); - auto checkedTargetCpu = - getCheckedCoreId(targetClass.op, - targetClass.isBatch ? targetClass.cpus[lane] : targetClass.cpus.front(), - "batch run target core id"); - if (failed(checkedTargetCpu)) - return failure(); - plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - } - (void) slotIndex; - } - - plans.push_back(std::move(plan)); - } - } - - return success(); -} - -void appendBatchRunSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const BatchRunSendPlan& plan, - Value flatIndex, - Location loc) { - assert(sourceClass.isBatch && "batch run send expects a materialized batch source"); - - std::optional preferredPeriod = static_cast(sourceClass.cpus.size()); - Value channelId = createIndexedChannelId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); - - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); -} - -LogicalResult appendPackedScalarRunReceives(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef run, - const BatchRunSendPlan& plan, - Type fragmentType, - Location loc) { - MaterializedClass& targetClass = state.classes[plan.destinationClass]; - assert(!targetClass.isBatch && "packed scalar run receives expect a scalar target class"); - - size_t laneCount = sourceClass.cpus.size(); - size_t receiveCount = run.size() * laneCount; - - if (failed(plan.messages.verify(targetClass.op))) - return failure(); - - if (receiveCount != plan.messages.size()) - return targetClass.op->emitError("inconsistent flattened batch run receive plan"); - - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return targetClass.op->emitError("packed scalar run receive expects static ranked fragment type"); - - PackedScalarRunValue packedRun; - packedRun.targetClass = targetClass.id; - packedRun.sourceOp = run.front().peers.front().op; - packedRun.resultIndex = plan.resultIndex; - packedRun.kind = PackedScalarRunKind::DeferredReceive; - packedRun.fragmentType = rankedFragmentType; - - packedRun.messages = plan.messages; - - packedRun.slots.reserve(run.size()); - for (const MaterializationRunSlot& slot : run) { - PackedScalarRunSlot packedSlot; - packedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); - packedRun.slots.push_back(std::move(packedSlot)); - } - - if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) - return failure(); - - state.availableValues.recordPackedRun(std::move(packedRun)); - return success(); -} - -LogicalResult recordIndexedBatchRunReceives(MaterializerState& state, - ArrayRef run, - const BatchRunSendPlan& plan, - Type fragmentType) { - MaterializedClass& targetClass = state.classes[plan.destinationClass]; - auto rankedFragmentType = dyn_cast(fragmentType); - if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) - return targetClass.op->emitError("indexed batch run receive expects static ranked fragment type"); - - IndexedBatchRunValue indexedRun; - indexedRun.targetClass = targetClass.id; - indexedRun.sourceOp = run.front().peers.front().op; - indexedRun.resultIndex = plan.resultIndex; - indexedRun.fragmentType = rankedFragmentType; - indexedRun.messages = plan.messages; - indexedRun.slots.reserve(run.size()); - for (const MaterializationRunSlot& slot : run) { - PackedScalarRunSlot indexedSlot; - indexedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); - indexedRun.slots.push_back(std::move(indexedSlot)); - } - - state.availableValues.recordIndexedBatchRun(std::move(indexedRun)); - return success(); -} - -LogicalResult appendBatchRunReceives(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef run, - const BatchRunSendPlan& plan, - Type fragmentType, - Location loc) { - MaterializedClass& targetClass = state.classes[plan.destinationClass]; - - if (!targetClass.isBatch) - return appendPackedScalarRunReceives(state, sourceClass, run, plan, fragmentType, loc); - return recordIndexedBatchRunReceives(state, run, plan, fragmentType); -} - -LogicalResult materializeBatchClassRun(MaterializerState& state, - MaterializedClass& targetClass, - SlotId startSlot, - ArrayRef run) { - assert(targetClass.isBatch && "batch-target run materialization expects a materialized batch class"); - assert(!run.empty() && "expected non-empty batch-target run"); - - if (!canCompactBatchClassRun(state, targetClass, run)) - return failure(); - - markMaterializationRunSlots(state, targetClass.id, startSlot, run); - - SmallVector groups = groupBatchRunOutputsByDestination(state, run); - - auto sourceBatch = cast(run.front().peers.front().op); - SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); - Location loc = sourceBatch.getLoc(); - - for (const OutputDestinationGroup& group : groups) { - SmallVector sendPlans; - if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) - return failure(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { - Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - getScheduledChunkForLogicalInstance(state, run.front().peers.front()), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); - if (failed(produced)) - return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); - if (resultIt == group.resultIndices.end()) - return failure(); - - size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); - appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); - } - return success(); - }); - if (failed(loop)) - return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) - return failure(); - - if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) - return failure(); - } - } - - return success(); -} - -LogicalResult materializeInstanceSlot(MaterializerState& state, - const ComputeInstance& instance) { - auto cpuIt = state.schedule.computeToCpuMap.find(instance); - if (cpuIt == state.schedule.computeToCpuMap.end()) - return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); - auto logicalRangeIt = state.scheduledInstanceToLogicalSlots.find(instance); - if (logicalRangeIt == state.scheduledInstanceToLogicalSlots.end()) - return instance.op->emitError("schedule materialization expected logical slots for every compute instance"); - - ClassId classId = state.cpuToClass.lookup(cpuIt->second); - MaterializedClass& targetClass = state.classes[classId]; - - LogicalSlotRange logicalRange = logicalRangeIt->second; - SlotId startLogicalSlot = logicalRange.start; - while (startLogicalSlot < logicalRange.start + logicalRange.count - && state.materializedLogicalSlots.contains({classId, startLogicalSlot})) { - ++startLogicalSlot; - } - if (startLogicalSlot == logicalRange.start + logicalRange.count) - return success(); - - if (isa(instance.op)) { - FailureOr run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); - - if (succeeded(run)) { - if (!targetClass.isBatch) - return materializeScalarBatchRun(state, targetClass, startLogicalSlot, *run); - - if (succeeded(materializeBatchClassRun(state, targetClass, startLogicalSlot, *run))) - return success(); - } - } - - if (!state.materializedLogicalSlots.insert({classId, startLogicalSlot}).second) - return success(); - - FailureOr> peers = - getMaterializationRunSlotPeers(state, targetClass, startLogicalSlot); - if (failed(peers)) - return instance.op->emitError("failed to collect peer compute instances for equivalence class logical slot"); - - Value projectionSlotIndex = getOrCreateIndexConstant( - state.constantFolder, targetClass.op, static_cast(startLogicalSlot - logicalRange.start)); - FailureOr> materializedOutputs = - cloneInstanceBody(state, - targetClass, - *peers, - CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = projectionSlotIndex}); - if (failed(materializedOutputs)) - return failure(); - - ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, instance); - if (materializedOutputs->size() != originalOutputs.size()) - return instance.op->emitError("materialized output count does not match original compute instance output count"); - - for (auto [resultIndex, zipped] : llvm::enumerate(llvm::zip(*materializedOutputs, originalOutputs))) { - Value materializedOutput = std::get<0>(zipped); - Value originalOutput = std::get<1>(zipped); - MaterializationRunSlot slot; - slot.peers = *peers; - SmallVector keys = getMaterializationRunSlotOutputKeys(slot, resultIndex); - if (failed(emitOutputFanout(state, targetClass, keys, materializedOutput, originalOutput, instance.op->getLoc()))) - return failure(); - } - - return success(); -} - -FailureOr createReceiveConcatLoop(MaterializerState& state, - MaterializedClass& targetClass, - RankedTensorType concatType, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); - assert(!messages.empty() && "expected at least one receive"); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value init = - tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); - return emitIndexedFragmentInsertLoop( - state, - targetClass, - init, - static_cast(messages.size()), - [&](Value index) -> FailureOr { - Value channelId = createIndexedChannelId(state, targetClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, messages, index, loc); - return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - }, - [&](Value index) -> FailureOr { - return scaleIndexByDim0SizeInClass(state, targetClass, index, fragmentType.getDimSize(0), loc); - }, - loc); -} - -bool valueMayEvaluateToCore(Value value, int64_t coreId) { - if (std::optional constant = getConstantIndexValue(value)) - return *constant == coreId; - - auto affineApply = value.getDefiningOp(); - if (!affineApply) - return false; - - AffineMap map = affineApply.getAffineMap(); - if (map.getNumResults() != 1 || map.getNumDims() != 1 || map.getNumSymbols() != 0 - || affineApply.getMapOperands().size() != 1) - return false; - - auto iv = dyn_cast(affineApply.getMapOperands().front()); - if (!iv) - return false; - - auto loop = dyn_cast_or_null(iv.getOwner()->getParentOp()); - if (!loop || loop.getInductionVar() != iv) - return false; - - std::optional lower = getConstantIndexValue(loop.getLowerBound()); - std::optional upper = getConstantIndexValue(loop.getUpperBound()); - std::optional step = getConstantIndexValue(loop.getStep()); - if (!lower || !upper || !step || *step <= 0) - return false; - - for (int64_t iteration = *lower; iteration < *upper; iteration += *step) { - FailureOr evaluated = evaluateSingleResultAffineMap(map, ArrayRef{iteration}); - if (succeeded(evaluated) && *evaluated == coreId) - return true; - } - - return false; -} - -bool operationContainsReceiveFromPeer(Operation& op, int64_t localCore, int64_t peerCore, Type payloadType) { - bool found = false; - op.walk([&](SpatChannelReceiveOp receive) { - if (receive.getOutput().getType() != payloadType) - return; - if (!valueMayEvaluateToCore(receive.getTargetCoreId(), localCore)) - return; - if (!valueMayEvaluateToCore(receive.getSourceCoreId(), peerCore)) - return; - found = true; - }); - return found; -} - -LogicalResult orderLowerCoreScalarSendsAfterMatchingReceives(MaterializerState& state) { - for (MaterializedClass& materializedClass : state.classes) { - if (materializedClass.isBatch || materializedClass.cpus.empty()) - continue; - - int64_t localCore = static_cast(materializedClass.cpus.front()); - Block* body = materializedClass.body; - if (!body) - continue; - - bool changed = true; - while (changed) { - changed = false; - for (Operation& op : llvm::make_early_inc_range(*body)) { - if (&op == body->getTerminator()) - break; - - auto send = dyn_cast(&op); - if (!send) - continue; - - std::optional sourceCore = getConstantIndexValue(send.getSourceCoreId()); - std::optional targetCore = getConstantIndexValue(send.getTargetCoreId()); - if (!sourceCore || !targetCore || *sourceCore != localCore || *sourceCore >= *targetCore) - continue; - - Operation* matchingReceiveContainer = nullptr; - for (Operation* candidate = op.getNextNode(); candidate && candidate != body->getTerminator(); - candidate = candidate->getNextNode()) { - if (operationContainsReceiveFromPeer(*candidate, localCore, *targetCore, send.getInput().getType())) { - matchingReceiveContainer = candidate; - break; - } - } - - if (!matchingReceiveContainer) - continue; - - op.moveAfter(matchingReceiveContainer); - changed = true; - break; - } - } - } - - return success(); -} - -void replaceHostUses(MaterializerState& state) { - for (const auto& [oldValue, replacement] : state.hostReplacements) - replaceLiveExternalUses(oldValue, replacement, state.oldComputeOps); -} - -LogicalResult eraseOldComputeOps(MaterializerState& state) { - DenseSet seen; - for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { - if (!seen.insert(instance.op).second) - continue; - instance.op->dropAllUses(); - instance.op->erase(); - } - return success(); -} - -} // namespace - -LogicalResult -MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) { - if (schedule.dominanceOrderCompute.empty()) - return success(); - - MaterializerState state(func, schedule, nextChannelId); - if (failed(buildMaterializationWorkStreams(state))) - return failure(); - if (failed(buildMaterializationClassesFromScheduleEquivalence(state))) - return failure(); - if (failed(verifyScheduleEquivalenceMatchesLogicalStreams(state))) - return failure(); - if (state.classes.empty()) - return success(); - - if (failed(collectHostOutputs(state))) - return failure(); - if (failed(createEmptyMaterializedOps(state))) - return failure(); - if (failed(collectProducerDestinations(state))) - return failure(); - if (failed(collectProjectedTransfers(state))) - return failure(); - - for (const ComputeInstance& instance : schedule.dominanceOrderCompute) - if (failed(materializeInstanceSlot(state, instance))) - return failure(); - - if (failed(finalizeProjectedHostOutputFragments(state))) - return failure(); - if (failed(orderLowerCoreScalarSendsAfterMatchingReceives(state))) - return failure(); - - for (MaterializedClass& materializedClass : state.classes) - if (failed(localizeAllScheduledBodyCaptures(state, materializedClass))) - return failure(); - - replaceHostUses(state); - if (failed(eraseOldComputeOps(state))) - return failure(); - - LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody()); - (void) _; - - return success(); -} - -} // namespace spatial -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.rej b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.rej deleted file mode 100644 index 3abd678..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.rej +++ /dev/null @@ -1,128 +0,0 @@ ---- src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.043731129 +0000 -+++ src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.026726895 +0000 -@@ -4112,104 +4112,8 @@ - Value originalOutput, - Location loc); - --FailureOr> rematerializeProjectionIndexListForBatchHostOutput( -- MaterializerState& state, -- MaterializedClass& sourceClass, -- ArrayRef values, -- IRMapping& mapper, -- Location loc) { -- SmallVector localized; -- localized.reserve(values.size()); -- for (OpFoldResult value : values) { -- FailureOr remapped = -- rematerializeIndexOpFoldResultInClass(state, sourceClass, value, loc, &mapper); -- if (failed(remapped)) -- return failure(); -- localized.push_back(*remapped); -- } -- return localized; --} -- --LogicalResult createProjectionAwareBatchHostInsert(MaterializerState& state, -- MaterializedClass& sourceClass, -- Value originalOutput, -- Value payload, -- Value destination, -- ArrayRef keys, -- Location loc) { -- auto originalResult = dyn_cast(originalOutput); -- if (!originalResult) -- return failure(); -- -- auto sourceBatch = dyn_cast_or_null(originalResult.getOwner()); -- if (!sourceBatch || sourceBatch.getNumResults() == 0) -- return failure(); -- -- FailureOr projection = -- getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); -- if (failed(projection)) -- return failure(); -- -- auto sourceLaneArg = sourceBatch.getLaneArgument(); -- if (!sourceLaneArg) -- return failure(); -- -- auto materializedBatch = dyn_cast(sourceClass.op); -- if (!materializedBatch) -- return failure(); -- -- auto materializedLaneArg = materializedBatch.getLaneArgument(); -- if (!materializedLaneArg) -- return failure(); -- -- if (keys.size() != sourceClass.cpus.size()) -- return failure(); -- -- SmallVector logicalLanes; -- logicalLanes.reserve(keys.size()); -- for (ProducerKey key : keys) { -- if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != originalResult.getResultNumber()) -- return failure(); -- logicalLanes.push_back(key.instance.laneStart); -- } -- -- IRMapping mapper; -- Value logicalLane = createIndexedIndexValue(state, -- sourceClass.op, -- ArrayRef(logicalLanes), -- *materializedLaneArg, -- loc, -- static_cast(sourceClass.cpus.size()), -- /*allowExhaustiveTiledSearch=*/false); -- mapper.map(*sourceLaneArg, logicalLane); -- -- FailureOr> offsets = -- rematerializeProjectionIndexListForBatchHostOutput( -- state, sourceClass, projection->getMixedOffsets(), mapper, loc); -- if (failed(offsets)) -- return failure(); -- FailureOr> sizes = -- rematerializeProjectionIndexListForBatchHostOutput( -- state, sourceClass, projection->getMixedSizes(), mapper, loc); -- if (failed(sizes)) -- return failure(); -- FailureOr> strides = -- rematerializeProjectionIndexListForBatchHostOutput( -- state, sourceClass, projection->getMixedStrides(), mapper, loc); -- if (failed(strides)) -- return failure(); -- -- tensor::ParallelInsertSliceOp::create( -- state.rewriter, loc, payload, destination, *offsets, *sizes, *strides); -- return success(); --} -- - LogicalResult --setHostOutputValue(MaterializerState& state, -- MaterializedClass& sourceClass, -- Value originalOutput, -- Value payload, -- ArrayRef keys = {}) { -+setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { - auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); - if (resultIt == sourceClass.hostOutputToResultIndex.end()) - return sourceClass.op->emitError("missing host result slot for materialized output") -@@ -4253,10 +4157,6 @@ - return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); - - state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); -- if (succeeded(createProjectionAwareBatchHostInsert( -- state, sourceClass, originalOutput, payload, *outputArg, keys, payload.getLoc()))) -- return success(); -- - createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); - return success(); - } -@@ -4276,7 +4176,7 @@ - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (sourceClass.id == ownerClass.id) -- return setHostOutputValue(state, ownerClass, originalOutput, payload, keys); -+ return setHostOutputValue(state, ownerClass, originalOutput, payload); - - // Keep the old deadlock-free communication discipline: only scalar-to-scalar - // host-owner forwarding is introduced here. Batch host publication remains on diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 6f718bd..6793437 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -354,13 +354,13 @@ public: void runOnOperation() override { func::FuncOp func = getOperation(); if (failed(verifyLogicalSpatialGraphInvariants(func))) { - func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes"); + func.emitOpError("logical Spatial graph verification failed at the start of MergeComputeNodes"); signalPassFailure(); return; } mergeTriviallyConnectedComputes(func); if (failed(verifyLogicalSpatialGraphInvariants(func))) { - func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification"); + func.emitOpError("logical Spatial graph verification failed after trivial merge simplification"); signalPassFailure(); return; } @@ -378,7 +378,7 @@ public: return; } if (failed(verifyScheduledSpatialInvariants(func))) { - func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization"); + func.emitOpError("scheduled Spatial verification failed after merge materialization"); signalPassFailure(); return; }