From 62dd40ee8925a0c9956f8bd25ca4afc72521d43d Mon Sep 17 00:00:00 2001 From: ilgeco Date: Wed, 24 Jun 2026 15:52:07 +0200 Subject: [PATCH] DeadLock --- .../src/lib/instruction_set/isa.rs | 21 +- backend-simulators/pim/pimsim-nn | 2 +- onnx-mlir | 2 +- src/PIM/Common/IR/AddressAnalysis.cpp | 61 + src/PIM/Compiler/PimCompilerOptions.cpp | 18 + src/PIM/Compiler/PimCompilerOptions.hpp | 3 + src/PIM/Compiler/PimCompilerUtils.cpp | 2 + .../Conversion/ONNXToSpatial/CMakeLists.txt | 2 + .../Common/ComputeRegionBuilder.cpp | 2 +- .../Common/ComputeRegionBuilder.hpp | 101 +- .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 6 +- .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 6 +- .../ONNXToSpatial/LowerSpatialPlansPass.cpp | 409 ++ .../ONNXToSpatial/ONNXToSpatialPass.cpp | 47 +- .../ONNXToSpatial/ONNXToSpatialVerifier.cpp | 216 +- .../ONNXToSpatial/ONNXToSpatialVerifier.hpp | 4 +- src/PIM/Conversion/ONNXToSpatial/Patterns.hpp | 4 +- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 887 ++- .../ONNXToSpatial/Patterns/NN/Relu.cpp | 9 +- .../ONNXToSpatial/Patterns/Post.cpp | 20 +- .../Conversion/ONNXToSpatial/PlanLowering.hpp | 21 + .../SpatialLayoutPlanningPass.cpp | 200 + .../BatchCoreLoweringPatterns.cpp | 4 +- .../SpatialToPim/ComputeLikeRegionUtils.cpp | 12 +- .../SpatialToPim/CoreLoweringPatterns.cpp | 10 +- .../SpatialToPim/Patterns/ChannelLowering.cpp | 18 +- .../Patterns/GlobalTensorMaterialization.cpp | 14 +- .../SpatialToPim/ReturnPathNormalization.cpp | 10 +- .../SpatialToPim/SpatialToPimPass.cpp | 80 +- .../SpatialToPim/SpatialToPimPass.hpp | 9 +- .../Bufferization/BufferizationUtils.cpp | 11 +- .../Bufferization/BufferizationUtils.hpp | 7 +- .../OpBufferizationInterfaces.cpp | 47 +- .../Bufferization/PimBufferizationPass.cpp | 60 + .../HostConstantFolding/Patterns/Constant.cpp | 3 +- .../Transforms/Verification/CMakeLists.txt | 1 + .../Verification/VerificationPass.cpp | 487 +- src/PIM/Dialect/Spatial/Spatial.td | 133 +- src/PIM/Dialect/Spatial/SpatialOps.cpp | 353 +- 
src/PIM/Dialect/Spatial/SpatialOps.hpp | 16 + src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 511 +- .../Spatial/SpatialOpsCanonicalization.cpp | 15 +- src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 313 +- .../MaterializeMergeSchedule.cpp | 4908 ++++++++++++++++- .../MergeComputeNodesPass.cpp | 24 +- src/PIM/Pass/PIMPasses.h | 2 + src/PIM/PimAccelerator.cpp | 2 + 47 files changed, 7993 insertions(+), 1100 deletions(-) create mode 100644 src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp diff --git a/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs b/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs index ca84b9a..2d46907 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs @@ -258,24 +258,23 @@ where let (memory, crossbars) = core.get_memory_crossbar(); let crossbar = crossbars.get_mut(group).unwrap(); - let crossbar_stored_bytes = crossbar.stored_bytes(); - let crossbar_byte_width = crossbar.width(); - - let crossbar_elem_width = crossbar_byte_width / size_of::(); - ensure!( - crossbar_byte_width % size_of::() == 0, - "M not divisor of the crosbbar size" - ); - let crossbar_height = crossbar.height(); - let crossbar_byte_size = crossbar_byte_width * crossbar_height; + let crossbar_stored_bytes = crossbar.stored_bytes(); + let bytes_per_column = crossbar_height * size_of::(); + ensure!(bytes_per_column != 0, "crossbar height can not be zero"); + ensure!( + crossbar_stored_bytes % bytes_per_column == 0, + "Stored crossbar bytes do not describe an integral number of columns" + ); + let crossbar_elem_width = crossbar_stored_bytes / bytes_per_column; + ensure!(crossbar_elem_width != 0, "Crossbar contains no stored columns"); let loads = memory .reserve_load(r1_val, crossbar_height 
* size_of::())? .execute_load::()?; let load = loads[0]; let vec: Cow<[M]> = load.up(); - let matrix = crossbar.load::(crossbar_byte_size)?[0]; + let matrix = crossbar.load::(crossbar_stored_bytes)?[0]; // --- FAER IMPLEMENTATION --- diff --git a/backend-simulators/pim/pimsim-nn b/backend-simulators/pim/pimsim-nn index 6d3b898..3e3442b 160000 --- a/backend-simulators/pim/pimsim-nn +++ b/backend-simulators/pim/pimsim-nn @@ -1 +1 @@ -Subproject commit 6d3b898e6b191c4446dfcc8c085ba1e50125e942 +Subproject commit 3e3442b66354282e600c5c45990af0e92aecf0f9 diff --git a/onnx-mlir b/onnx-mlir index eb54c2a..82018d7 160000 --- a/onnx-mlir +++ b/onnx-mlir @@ -1 +1 @@ -Subproject commit eb54c2afc46d00c6b196d1f275b6bfee17e12f69 +Subproject commit 82018d7ce59c94bfbe9479b16538224969fa45a0 diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index a46e361..6e6add0 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -56,6 +56,22 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge); } + if (auto forOp = mlir::dyn_cast(definingOp)) { + auto result = mlir::dyn_cast(value); + if (result) { + auto yieldOp = mlir::dyn_cast(forOp.getBody()->getTerminator()); + if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) { + mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); + if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { + if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 + && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) + return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); + } + return yieldedValue; + } + } + } + if (auto castOp = mlir::dyn_cast(definingOp)) return resolveLoopCarriedAliasImpl(castOp.getSource(), 
knowledge); if (auto collapseOp = mlir::dyn_cast(definingOp)) @@ -512,6 +528,24 @@ llvm::FailureOr resolveContiguousAddressImpl(mlir::Va continue; } + if (auto ifOp = mlir::dyn_cast(definingOp)) { + auto result = mlir::dyn_cast(value); + if (!result) + return mlir::failure(); + + auto condition = resolveIndexValueImpl(ifOp.getCondition(), knowledge); + if (failed(condition)) + return mlir::failure(); + + mlir::Region& selectedRegion = *condition != 0 ? ifOp.getThenRegion() : ifOp.getElseRegion(); + auto yieldOp = mlir::dyn_cast(selectedRegion.front().getTerminator()); + if (!yieldOp || result.getResultNumber() >= yieldOp.getNumOperands()) + return mlir::failure(); + + value = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); + continue; + } + if (auto subviewOp = mlir::dyn_cast(definingOp)) { auto sourceType = mlir::dyn_cast(subviewOp.getSource().getType()); auto subviewType = mlir::dyn_cast(subviewOp.getType()); @@ -622,6 +656,33 @@ llvm::FailureOr compileContiguousAddressExprImpl(mlir::Valu continue; } + if (auto ifOp = mlir::dyn_cast(definingOp)) { + auto result = mlir::dyn_cast(value); + if (!result) + return mlir::failure(); + + auto thenYield = mlir::dyn_cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = mlir::dyn_cast(ifOp.getElseRegion().front().getTerminator()); + if (!thenYield || !elseYield || result.getResultNumber() >= thenYield.getNumOperands() + || result.getResultNumber() >= elseYield.getNumOperands()) { + return mlir::failure(); + } + + auto thenAddress = compileContiguousAddressExprImpl(thenYield.getOperand(result.getResultNumber())); + auto elseAddress = compileContiguousAddressExprImpl(elseYield.getOperand(result.getResultNumber())); + if (failed(thenAddress) || failed(elseAddress) || thenAddress->base != elseAddress->base) + return mlir::failure(); + + auto condition = compileIndexValueImpl(ifOp.getCondition()); + if (failed(condition)) + return mlir::failure(); + + CompiledIndexExprNode 
selectExpr; + selectExpr.kind = CompiledIndexExprNode::Kind::Select; + selectExpr.operands = {*condition, thenAddress->byteOffset, elseAddress->byteOffset}; + return CompiledAddressExpr {thenAddress->base, makeCompiledIndexExpr(std::move(selectExpr))}; + } + if (auto subviewOp = mlir::dyn_cast(definingOp)) { auto sourceType = mlir::dyn_cast(subviewOp.getSource().getType()); auto subviewType = mlir::dyn_cast(subviewOp.getType()); diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index f2a4d60..4fed7cb 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -96,6 +96,24 @@ llvm::cl::opt pimEmitJson("pim-emit-json", llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt pimDetectCommunicationDeadlock( + "pim-detect-communication-deadlock", + llvm::cl::desc("Expensively simulate the statically expanded PIM send/receive order at verification time and fail if a blocking communication deadlock is found"), + llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); + +llvm::cl::opt pimMaterializeScalarFanoutGlobalOrder( + "pim-materialize-scalar-fanout-global-order", + llvm::cl::desc("Experimental expensive materializer mode: emit scalar-source fanout as globally ordered communication events instead of all-send fanout loops"), + llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); + +llvm::cl::opt pimTraceCommunicationMaterialization( + "pim-trace-communication-materialization", + llvm::cl::desc("Emit verbose materializer-time diagnostics and provenance attributes for every Spatial communication op"), + llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); + llvm::cl::opt crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2)); diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index 5fc77fb..51aa469 100644 --- a/src/PIM/Compiler/PimCompilerOptions.hpp +++ 
b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -53,6 +53,9 @@ extern llvm::cl::opt pimDisableMemoryCoalescing; extern llvm::cl::opt useExperimentalConvImpl; extern llvm::cl::opt pimEmitJson; extern llvm::cl::opt pimReportConvLowering; +extern llvm::cl::opt pimDetectCommunicationDeadlock; +extern llvm::cl::opt pimMaterializeScalarFanoutGlobalOrder; +extern llvm::cl::opt pimTraceCommunicationMaterialization; extern llvm::cl::opt crossbarSize; extern llvm::cl::opt crossbarCountInCore; diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index e9bc397..852912f 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -29,6 +29,8 @@ void addPassesPim(OwningOpRef& module, if (pimEmissionTarget >= EmitSpatial) { pm.addPass(createONNXToSpatialPass()); + pm.addPass(createSpatialLayoutPlanningPass()); + pm.addPass(createLowerSpatialPlansPass()); pm.addPass(createMergeComputeNodesPass()); pm.addPass(createMessagePass("Onnx lowered to Spatial")); } diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 32964aa..1ddb1fc 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -26,6 +26,8 @@ add_pim_library(OMONNXToSpatial Patterns/Tensor/Split.cpp Patterns/Tensor/Transpose.cpp ONNXToSpatialPass.cpp + SpatialLayoutPlanningPass.cpp + LowerSpatialPlansPass.cpp Common/AttributeUtils.cpp Common/ComputeRegionBuilder.cpp Common/IndexingUtils.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.cpp index f39998a..62823a2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.cpp @@ -9,7 +9,7 @@ using namespace mlir; namespace onnx_mlir { -Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter) { 
+Value sumTensors(ArrayRef tensors, PatternRewriter& rewriter) { if (tensors.size() == 1) return tensors[0]; diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp index 27d6471..78f7328 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp @@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput(); } -/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if +/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if /// the body callback reports failure. template -auto createSpatCompute(RewriterT& rewriter, - mlir::Location loc, - mlir::TypeRange resultTypes, - mlir::ValueRange weights, - mlir::ValueRange inputs, - BodyFn&& body) { +auto createSpatGraphCompute(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); - auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); + auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value weight : weights) @@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter, if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); - return mlir::FailureOr(mlir::failure()); + return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); - return mlir::FailureOr(computeOp); + return mlir::FailureOr(computeOp); } } -/// Builds a `spat.compute` whose body consumes the 
block arguments as a single +/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single /// `ValueRange`, which is convenient for variadic reductions/concats. template -auto createSpatCompute(RewriterT& rewriter, - mlir::Location loc, - mlir::TypeRange resultTypes, - mlir::ValueRange weights, - mlir::ValueRange inputs, - BodyFn&& body) { - auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); +auto createSpatGraphCompute(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { + auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value weight : weights) @@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter, if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); - return mlir::FailureOr(mlir::failure()); + return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); - return mlir::FailureOr(computeOp); + return mlir::FailureOr(computeOp); } } template -auto createSpatComputeBatch(RewriterT& rewriter, - mlir::Location loc, - mlir::TypeRange resultTypes, - int64_t laneCount, - mlir::ValueRange weights, - mlir::ValueRange inputs, - BodyFn&& body) { +auto createSpatGraphComputeBatch(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + int64_t laneCount, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { if (laneCount <= 0 || laneCount > std::numeric_limits::max()) - return mlir::FailureOr(mlir::failure()); + return mlir::FailureOr(mlir::failure()); auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count"); if (mlir::failed(laneCountAttr)) - return mlir::FailureOr(mlir::failure()); + return mlir::FailureOr(mlir::failure()); - auto batchOp = 
spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs); + auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs); mlir::SmallVector blockArgTypes {rewriter.getIndexType()}; mlir::SmallVector blockArgLocs {loc}; @@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter, if constexpr (std::is_same_v) { std::forward(body)(args); rewriter.setInsertionPointAfter(batchOp); - return mlir::FailureOr(batchOp); + return mlir::FailureOr(batchOp); } else { auto bodyResult = std::forward(body)(args); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(batchOp); rewriter.eraseOp(batchOp); - return mlir::FailureOr(mlir::failure()); + return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(batchOp); - return mlir::FailureOr(batchOp); + return mlir::FailureOr(batchOp); } } +template +auto createSpatCompute(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { + return createSpatGraphCompute( + rewriter, loc, resultTypes, weights, inputs, std::forward(body)); +} + +template +auto createSpatCompute(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { + return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward(body)); +} + +template +auto createSpatComputeBatch(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + int64_t laneCount, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { + return createSpatGraphComputeBatch( + rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward(body)); +} + inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, @@ -262,6 +295,6 @@ mlir::Value 
materializeOrComputeUnary(mlir::Value input, return computeOp.getResult(0); } -mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::ConversionPatternRewriter& rewriter); +mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::PatternRewriter& rewriter); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 771ca03..d97da5a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -83,7 +83,7 @@ SmallVector getStaticSizes(PatternRewriter& rewriter, ArrayRef sliceTensor( - const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { + const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) { ArrayRef shape = getTensorShape(tensorToSlice); assert("Invalid axis" && axis < shape.size()); @@ -129,7 +129,7 @@ SmallVector sliceTensor( } SmallVector -sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { +sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) { ArrayRef shape = getTensorShape(vectorToSlice); assert("Not a vector" && isVectorShape(shape)); size_t axis = shape[0] != 1 ? 
0 : 1; @@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr } DenseMap> -sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) { +sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) { SmallVector slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc); DenseMap> slicesPerCore; for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) { diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index 4c265c1..b803969 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -89,18 +89,18 @@ llvm::SmallVector getStaticSizes(mlir::PatternRewriter& rewr llvm::SmallVector sliceTensor(const mlir::Value& tensorToSlice, size_t axis, int64_t sliceSize, - mlir::ConversionPatternRewriter& rewriter, + mlir::PatternRewriter& rewriter, mlir::Location loc); llvm::SmallVector sliceVector(const mlir::Value& vectorToSlice, int64_t sliceSize, - mlir::ConversionPatternRewriter& rewriter, + mlir::PatternRewriter& rewriter, mlir::Location loc); /// Partitions one logical vector into per-core crossbar-sized slices using the /// current PIM target geometry. 
llvm::DenseMap> sliceVectorPerCrossbarPerCore( - const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); + const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc); mlir::Value extractAxisSlice( mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size); diff --git a/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp b/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp new file mode 100644 index 0000000..be74eb2 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp @@ -0,0 +1,409 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" + +#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PIMPasses.h" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static constexpr StringLiteral kDenseLayout = "dense_nchw"; +static constexpr StringLiteral kRowStripLayout = "nchw_row_strip"; + +struct RowStripPhysicalValue { + Value physicalValue; + RankedTensorType logicalType; + SmallVector fragmentOffsets; + SmallVector fragmentSizes; + std::string indexMap; +}; + 
+static FailureOr getRowStripValue(llvm::DenseMap& rowStripValues, + Value value) { + auto it = rowStripValues.find(value); + if (it == rowStripValues.end()) + return failure(); + return it->second; +} + +static FailureOr buildRowStripValue(spatial::SpatReconciliatorOp reconciliator, + Value physicalValue) { + auto logicalType = dyn_cast(reconciliator.getOutput().getType()); + if (!logicalType) + return reconciliator.emitOpError("requires ranked logical output type"), failure(); + RowStripPhysicalValue value; + value.physicalValue = physicalValue; + value.logicalType = logicalType; + value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end()); + value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end()); + value.indexMap = reconciliator.getIndexMap().str(); + return value; +} + +static FailureOr +lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) { + auto packedType = cast(input.physicalValue.getType()); + auto computeOp = + createSpatCompute<1>(rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) { + auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x); + spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult()); + }); + return computeOp.getResult(0); +} + +static FailureOr +materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) { + auto packedType = dyn_cast(rowStripValue.physicalValue.getType()); + if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape()) + return failure(); + if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape()) + return failure(); + if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw") + return failure(); + + const int64_t rank = rowStripValue.logicalType.getRank(); + const int64_t 
fragmentCount = rowStripValue.fragmentOffsets.size() / rank; + const int64_t packedWidth = packedType.getDimSize(1); + const int64_t packedChannels = packedType.getDimSize(2); + if (fragmentCount != packedType.getDimSize(0)) + return failure(); + for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) { + if (rowStripValue.fragmentOffsets[fragmentIndex * rank + 0] != 0 + || rowStripValue.fragmentOffsets[fragmentIndex * rank + 1] != 0 + || rowStripValue.fragmentOffsets[fragmentIndex * rank + 2] != fragmentIndex + || rowStripValue.fragmentOffsets[fragmentIndex * rank + 3] != 0) + return failure(); + if (rowStripValue.fragmentSizes[fragmentIndex * rank + 0] != 1 + || rowStripValue.fragmentSizes[fragmentIndex * rank + 1] != packedChannels + || rowStripValue.fragmentSizes[fragmentIndex * rank + 2] != 1 + || rowStripValue.fragmentSizes[fragmentIndex * rank + 3] != packedWidth) + return failure(); + } + + auto packedSliceType = + RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding()); + auto expandedType = + RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding()); + auto logicalFragmentType = + RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding()); + auto batchOp = createSpatComputeBatch( + rewriter, + loc, + TypeRange {rowStripValue.logicalType}, + fragmentCount, + {}, + ValueRange {rowStripValue.physicalValue}, + [&](detail::SpatComputeBatchBodyArgs args) { + SmallVector packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector packedSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)}; + Value packedSlice = tensor::ExtractSliceOp::create( + rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3)); + + Value expanded = 
tensor::ExpandShapeOp::create(rewriter, + loc, + expandedType, + packedSlice, + SmallVector { + {0, 1}, + {2}, + {3} + }); + Value transposeInit = + tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType()); + Value logicalFragment = + linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector {0, 3, 1, 2}) + .getResult()[0]; + + SmallVector logicalOffsets { + rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)}; + SmallVector logicalSizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(packedChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(packedWidth)}; + createParallelInsertSliceIntoBatchOutput(rewriter, + loc, + logicalFragment, + args.outputs.front(), + logicalOffsets, + logicalSizes, + getUnitStrides(rewriter, 4)); + return success(); + }); + if (failed(batchOp)) + return failure(); + return batchOp->getResult(0); +} + +struct LowerSpatialPlansPass final : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass) + + StringRef getArgument() const override { return "lower-spatial-plans"; } + StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + MLIRContext* ctx = moduleOp.getContext(); + auto entryFunc = getPimEntryFunc(moduleOp); + if (failed(entryFunc)) { + moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans"); + signalPassFailure(); + return; + } + func::FuncOp funcOp = *entryFunc; + PatternRewriter rewriter(ctx); + llvm::DenseMap rowStripValues; + llvm::SmallPtrSet eraseAfterLowering; + auto verifyLogicalPhase = [&](StringRef stage) -> bool { + if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc))) + return true; + moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage; + 
signalPassFailure(); + return false; + }; + + if (!verifyLogicalPhase("at the start of LowerSpatialPlans")) + return; + for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) { + if (auto planOp = dyn_cast(&op)) { + FailureOr rowStripInput = getRowStripValue(rowStripValues, planOp.getInput()); + auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) { + auto reconciliator = dyn_cast(user); + return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout; + }); + if (rowStripReconciliator != planOp.getResult().getUsers().end()) { + rewriter.setInsertionPoint(planOp); + FailureOr lowered = lowerSelectedConv2DPlan( + planOp, + succeeded(rowStripInput) ? std::optional {rowStripInput->physicalValue} : std::nullopt, + /*emitRowStripLayout=*/true, + rewriter); + if (failed(lowered)) { + planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan"); + signalPassFailure(); + return; + } + auto reconciliator = cast(*rowStripReconciliator); + FailureOr rowStripValue = buildRowStripValue(reconciliator, *lowered); + if (failed(rowStripValue)) { + signalPassFailure(); + return; + } + rowStripValues[reconciliator.getResult()] = *rowStripValue; + eraseAfterLowering.insert(planOp); + eraseAfterLowering.insert(reconciliator); + continue; + } + rewriter.setInsertionPoint(planOp); + FailureOr lowered = + lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter); + if (failed(lowered)) { + planOp.emitOpError("failed to lower selected Spatial Conv plan"); + signalPassFailure(); + return; + } + rewriter.replaceOp(planOp, *lowered); + continue; + } + + if (auto planOp = dyn_cast(&op)) { + if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) { + auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) { + auto reconciliator = dyn_cast(user); + return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout; + }); + 
if (outputReconciliator == planOp.getResult().getUsers().end()) { + planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result"); + signalPassFailure(); + return; + } + + FailureOr input = getRowStripValue(rowStripValues, planOp.getInput()); + rewriter.setInsertionPoint(planOp); + FailureOr lowered = lowerRowStripRelu(*input, planOp, rewriter); + if (failed(lowered)) { + planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan"); + signalPassFailure(); + return; + } + auto reconciliator = cast(*outputReconciliator); + FailureOr output = buildRowStripValue(reconciliator, *lowered); + if (failed(output)) { + signalPassFailure(); + return; + } + rowStripValues[reconciliator.getResult()] = *output; + eraseAfterLowering.insert(planOp); + eraseAfterLowering.insert(reconciliator); + continue; + } + + rewriter.setInsertionPoint(planOp); + auto computeOp = createSpatCompute<1>( + rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) { + auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x); + spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult()); + }); + rewriter.replaceOp(planOp, computeOp.getResults()); + continue; + } + if (auto materializeOp = dyn_cast(&op)) { + if (materializeOp.getSourcePhysicalLayout() == kDenseLayout + && materializeOp.getTargetPhysicalLayout() == kDenseLayout) { + rewriter.replaceOp(materializeOp, materializeOp.getInput()); + continue; + } + if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout + || materializeOp.getTargetPhysicalLayout() != kDenseLayout) { + materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet"); + signalPassFailure(); + return; + } + FailureOr rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput()); + if (failed(rowStripValue)) { + materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip 
materialization"); + signalPassFailure(); + return; + } + rewriter.setInsertionPoint(materializeOp); + FailureOr dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter); + if (failed(dense)) { + materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW"); + signalPassFailure(); + return; + } + rewriter.replaceOp(materializeOp, *dense); + continue; + } + if (auto reconciliatorOp = dyn_cast(&op)) { + if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) { + rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput()); + continue; + } + if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) { + reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet"); + signalPassFailure(); + return; + } + if (!eraseAfterLowering.contains(reconciliatorOp)) { + reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans"); + signalPassFailure(); + return; + } + } + } + bool erasedAny = true; + while (erasedAny) { + erasedAny = false; + for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) { + if (!eraseAfterLowering.contains(&op)) + continue; + if (!op.use_empty()) + continue; + eraseAfterLowering.erase(&op); + rewriter.eraseOp(&op); + erasedAny = true; + } + } + if (!eraseAfterLowering.empty()) { + for (Operation& op : funcOp.getBody().front()) + if (eraseAfterLowering.contains(&op)) + op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans"); + signalPassFailure(); + return; + } + ConversionTarget helperTarget(*ctx); + helperTarget.addLegalDialect(); + helperTarget.addLegalOp(); + helperTarget.addIllegalOp(); + helperTarget.markOpRecursivelyLegal(); + + RewritePatternSet helperPatterns(ctx); + populateGemmPatterns(helperPatterns, ctx); + populateTransposePatterns(helperPatterns, ctx); + if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) { + 
moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering"); + signalPassFailure(); + return; + } + FrozenRewritePatternSet nestedHelperPatterns([&] { + RewritePatternSet patterns(ctx); + populateGemmPatterns(patterns, ctx); + populateTransposePatterns(patterns, ctx); + return patterns; + }()); + ConversionTarget nestedHelperTarget(*ctx); + nestedHelperTarget.addLegalDialect(); + nestedHelperTarget.addIllegalOp(); + SmallVector computeLikeOps; + funcOp.walk([&](Operation* op) { + if (isa(op)) + computeLikeOps.push_back(op); + }); + for (Operation* op : computeLikeOps) { + if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) { + op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering"); + signalPassFailure(); + return; + } + } + if (!verifyLogicalPhase("after nested helper conversions")) + return; + bool hasIllegalOps = false; + moduleOp.walk([&](Operation* op) { + if (isa(op)) + return; + if (isa(op) + || op->getDialect()->getNamespace() == "onnx") { + op->emitOpError("operation must not remain after LowerSpatialPlans"); + hasIllegalOps = true; + } + }); + if (hasIllegalOps) + signalPassFailure(); + else + dumpModule(moduleOp, "spatial1_premerge"); + + if (!verifyLogicalPhase("at the end of LowerSpatialPlans")) + return; + } +}; + +} // namespace + +std::unique_ptr createLowerSpatialPlansPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 531fe34..de87200 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -18,6 +18,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" +#include "ONNXToSpatialVerifier.hpp" using 
namespace mlir; @@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper computes(funcOp.getOps()); - SmallVector computeBatches(funcOp.getOps()); - if (!computes.empty() || !computeBatches.empty()) + SmallVector computes(funcOp.getOps()); + SmallVector computeBatches(funcOp.getOps()); + SmallVector convPlans(funcOp.getOps()); + SmallVector reluPlans(funcOp.getOps()); + SmallVector reconciliators(funcOp.getOps()); + SmallVector materializers(funcOp.getOps()); + if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty() + || !materializers.empty()) { return; + } auto returnOp = cast(funcOp.getFunctionBody().front().getTerminator()); rewriter.setInsertionPoint(returnOp); @@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) { sourceLocs.push_back(source.getLoc()); } - auto newCompute = spatial::SpatCompute::create( + auto newCompute = spatial::SpatGraphCompute::create( rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {}); auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs); for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands())) @@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) { rewriter.setInsertionPointToEnd(newBlock); for (Operation& op : funcOp.getOps()) - if (!isa(&op)) + if (!isa(&op)) rewriter.clone(op, mapper); auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands()); @@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) { yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i))); for (Operation& op : llvm::make_early_inc_range(funcOp.getOps())) - if (!isa(&op)) { + if (!isa(&op)) { op.dropAllUses(); rewriter.eraseOp(&op); } @@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() { return; } + if 
(failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { + moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion"); + signalPassFailure(); + return; + } ConversionTarget earlyPostTarget(*ctx); earlyPostTarget.addLegalDialect(); - postTarget.addDynamicallyLegalOp( - [](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); }); - postTarget.addDynamicallyLegalOp( - [](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); }); + postTarget.addDynamicallyLegalOp( + [](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); }); + postTarget.addDynamicallyLegalOp( + [](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); }); + if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { + moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites"); + signalPassFailure(); + return; + } RewritePatternSet postPatterns(ctx); populatePostPatterns(postPatterns, ctx); if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) { @@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() { populateEmptyFunction(*entryFunc); + if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { + moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial"); + signalPassFailure(); + return; + } dumpModule(moduleOp, "spatial0"); if (failed(verifyONNXToSpatial(*entryFunc))) { diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp index e191093..b6b5162 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp @@ -1,4 +1,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/Support/LLVM.h" 
#include "Common/IR/WeightUtils.hpp" @@ -13,6 +15,8 @@ namespace onnx_mlir { namespace { +constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK"; + void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) { func.walk([&](Operation* op) { if (!hasWeightAlways(op)) @@ -23,134 +27,174 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag continue; diagnostics.report(op, [&](Operation* illegalOp) { - illegalOp->emitOpError( - "weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights"); + illegalOp->emitOpError() + << kPhaseMarker + << " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights"; }); return; } }); } -Region* getParentRegion(Value value) { - if (auto blockArg = dyn_cast(value)) - return blockArg.getOwner()->getParent(); - if (Operation* definingOp = value.getDefiningOp()) - return definingOp->getParentRegion(); - return nullptr; +bool isRegionOrAncestorOf(Region& region, Region* candidate) { + return candidate && (®ion == candidate || region.isAncestor(candidate)); } -bool isDefinedInsideRegion(Value value, Region& region) { - Region* parentRegion = getParentRegion(value); - return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); +bool isValueDefinedInsideRegion(Value value, Region& region) { + if (auto blockArg = dyn_cast(value)) + return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent()); + if (Operation* definingOp = value.getDefiningOp()) + return isRegionOrAncestorOf(region, definingOp->getParentRegion()); + return false; +} + +bool isLegalExternalCapture(Value value, Region& region) { + if (isValueDefinedInsideRegion(value, region)) + return true; + + Operation* definingOp = value.getDefiningOp(); + return definingOp && definingOp->hasTrait(); +} + +template +void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& 
diagnostics) { + Region& body = compute.getBody(); + body.walk([&](Operation* nestedOp) { + for (OpOperand& operand : nestedOp->getOpOperands()) { + Value value = operand.get(); + if (isLegalExternalCapture(value, body)) + continue; + + Operation* definingOp = value.getDefiningOp(); + diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) { + InFlightDiagnostic diag = + illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #" + << operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef(); + diag << " (type " << value.getType() << ")"; + if (definingOp) + diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef(); + else if (auto blockArg = dyn_cast(value)) { + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diag.attachNote(owner->getLoc()) + << "external block argument belongs to " << owner->getName().getStringRef(); + } + }); + } + }); } bool isLegalHostBackedValue(Value value) { Operation* definingOp = value.getDefiningOp(); if (!definingOp) return isa(value); - - if (isa(definingOp)) - return false; - return definingOp->getDialect()->getNamespace() != "spat"; } -LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp, - ValueRange inputs, - bool allowChannelReceiveInputs, - StringRef kind, - pim::CappedDiagnosticReporter& diagnostics) { - for (auto [inputIndex, input] : llvm::enumerate(inputs)) { - unsigned currentInputIndex = inputIndex; +template +void verifyScheduledInputs(ComputeOpTy compute, + bool allowChannelReceiveInputs, + StringRef kind, + pim::CappedDiagnosticReporter& diagnostics) { + for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) { Operation* definingOp = input.getDefiningOp(); if (allowChannelReceiveInputs && isa_and_nonnull(definingOp)) continue; if (isLegalHostBackedValue(input)) continue; - diagnostics.report(computeLikeOp, [&](Operation* illegalOp) { - InFlightDiagnostic 
diagnostic = illegalOp->emitOpError() - << kind << " input #" << currentInputIndex - << (allowChannelReceiveInputs ? " must come from the host or an explicit " - "spat.channel_receive" - : " must come from the host"); + diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) { + InFlightDiagnostic diag = illegalOp->emitOpError() + << kPhaseMarker << " " << kind << " input #" << inputIndex + << (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive" + : " must come from the host"); if (definingOp) - diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName(); + diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef(); }); - return failure(); } - return success(); } -void verifyNoExternalTensorCaptures(Operation* ownerOp, - Region& region, - StringRef kind, - pim::CappedDiagnosticReporter& diagnostics) { - region.walk([&](Operation* op) { - for (OpOperand& operand : op->getOpOperands()) { - Value value = operand.get(); - if (!isa(value.getType())) - continue; - if (isDefinedInsideRegion(value, region) || isa(value)) - continue; +void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) { + for (Operation& op : funcOp.getOps()) { + if (isa(&op)) { + continue; + } + if (isa(&op)) { + diagnostics.report(&op, [&](Operation* illegalOp) { + illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase"; + }); + continue; + } + if (isa(&op)) { + diagnostics.report(&op, [&](Operation* illegalOp) { + illegalOp->emitOpError() << kPhaseMarker + << " explicit channel communication is not expected before merge materialization"; + }); + continue; + } + if (isCompileTimeOp(&op)) + continue; - Operation* definingOp = value.getDefiningOp(); - if (definingOp && definingOp->hasTrait()) - continue; + diagnostics.report(&op, [&](Operation* illegalOp) { + 
illegalOp->emitOpError() + << kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute"; + }); + } +} - diagnostics.report(ownerOp, [&](Operation* illegalOp) { - InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor " - << "values"; - diagnostic.attachNote(op->getLoc()) - << "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by " - << (definingOp ? definingOp->getName().getStringRef() : StringRef("")); +void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) { + for (Operation& op : funcOp.getOps()) { + if (isa(&op)) { + diagnostics.report(&op, [&](Operation* illegalOp) { + illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization"; }); } - }); + } } } // namespace -LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { +LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) { pim::CappedDiagnosticReporter diagnostics; - - for (Operation& op : funcOp.getOps()) { - if (isa(&op)) - continue; - if (isCompileTimeOp(&op)) - continue; - - diagnostics.report(&op, [](Operation* illegalOp) { - illegalOp->emitOpError( - "non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute"); - }); - } - checkWeightUseChains(funcOp, diagnostics); - diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed"); - + for (auto compute : funcOp.getOps()) + verifyComputeBodyCaptures(compute, "graph_compute", diagnostics); + for (auto batch : funcOp.getOps()) + verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics); + for (auto compute : funcOp.getOps()) + verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics); + for (auto batch : funcOp.getOps()) + verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics); + 
diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed"); return success(!diagnostics.hasFailure()); } -LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) { +LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); } + +LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) { pim::CappedDiagnosticReporter diagnostics; + verifyLogicalTopLevelOps(funcOp, diagnostics); + checkWeightUseChains(funcOp, diagnostics); + if (failed(verifyNoComputeBodyCaptures(funcOp))) + return failure(); + diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed"); + return success(!diagnostics.hasFailure()); +} - for (auto computeOp : funcOp.getOps()) { - (void) verifyComputeLikeInputs( - computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics); - verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics); - } - - for (auto computeBatchOp : funcOp.getOps()) { - (void) verifyComputeLikeInputs(computeBatchOp.getOperation(), - computeBatchOp.getInputs(), - /*allowChannelReceiveInputs=*/false, - "spat.compute_batch", - diagnostics); - verifyNoExternalTensorCaptures( - computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics); - } - - diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed"); +LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) { + pim::CappedDiagnosticReporter diagnostics; + verifyScheduledTopLevelOps(funcOp, diagnostics); + for (auto compute : funcOp.getOps()) + verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics); + for (auto batch : funcOp.getOps()) + verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics); + if 
(failed(verifyNoComputeBodyCaptures(funcOp))) + return failure(); + diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed"); return success(!diagnostics.hasFailure()); } diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp index 3ac5b9c..716f5d2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp @@ -6,6 +6,8 @@ namespace onnx_mlir { mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp); -mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp); +mlir::LogicalResult verifyNoComputeBodyCaptures(mlir::func::FuncOp funcOp); +mlir::LogicalResult verifyLogicalSpatialGraphInvariants(mlir::func::FuncOp funcOp); +mlir::LogicalResult verifyScheduledSpatialInvariants(mlir::func::FuncOp funcOp); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index c040536..e687e3d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -33,8 +33,8 @@ void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -bool requiresPostRewrite(spatial::SpatCompute computeOp); -bool requiresPostRewrite(spatial::SpatComputeBatch computeOp); +bool requiresPostRewrite(spatial::SpatGraphCompute computeOp); +bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp); void annotateWeightsConstants(mlir::func::FuncOp funcOp); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 569f404..14c2b80 100644 --- 
a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinAttributes.h" @@ -24,6 +25,7 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -106,6 +108,19 @@ struct PreparedConvInput { RankedTensorType type; }; +struct RowInterval { + int64_t begin = 0; + int64_t end = 0; +}; + +struct ConvRowDemand { + RowInterval outputRows; + RowInterval neededInputRows; + RowInterval acquiredInputRows; + int64_t topHaloRows = 0; + int64_t bottomHaloRows = 0; +}; + struct ConvStrategyEstimate { uint64_t estimatedMvmCount = 0; uint64_t estimatedReductionVAddCount = 0; @@ -270,6 +285,12 @@ struct DistributedConvReportTotals { SmallVector chains; }; +static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter); +static FailureOr createRowStripPackedRows(Value rows, + const ConvLoweringState& state, + PatternRewriter& rewriter, + Location loc); + static bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup); static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor); @@ -542,7 +563,37 @@ classifyDistributedBinaryConsumer(Operation* user, return std::nullopt; } -[[maybe_unused]] static bool canConsumeDistributedConvInput(const ConvLoweringState& state, StringRef& failureReason) { +static RowInterval 
computeConvInputRowsForOutputRows(RowInterval outputRows, + int64_t inputHeight, + int64_t kernelH, + int64_t strideH, + int64_t dilationH, + int64_t padTop) { + const int64_t rawBegin = outputRows.begin * strideH - padTop; + const int64_t rawEnd = (outputRows.end - 1) * strideH - padTop + dilationH * (kernelH - 1) + 1; + return {std::max(0, rawBegin), std::min(inputHeight, rawEnd)}; +} + +static bool covers(RowInterval acquired, RowInterval needed) { + return acquired.begin <= needed.begin && acquired.end >= needed.end; +} + +static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) { + const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin; + const int64_t rawEnd = + (outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1; + RowInterval neededInputRows = computeConvInputRowsForOutputRows( + outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin); + ConvRowDemand demand; + demand.outputRows = outputRows; + demand.neededInputRows = neededInputRows; + demand.acquiredInputRows = neededInputRows; + demand.topHaloRows = std::max(0, -rawBegin); + demand.bottomHaloRows = std::max(0, rawEnd - state.xHeight); + return demand; +} + +static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) { if (state.batchSize != 1) { failureReason = "unsupported_batch"; return false; @@ -551,6 +602,10 @@ classifyDistributedBinaryConsumer(Operation* user, failureReason = "unsupported_groups"; return false; } + if (isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup)) { + failureReason = "unsupported_depthwise"; + return false; + } if (state.strideHeight != 1 || state.strideWidth != 1) { failureReason = "unsupported_stride"; return false; @@ -600,7 +655,7 @@ static std::string 
stringifyDistributedTensorOpKind(DistributedTensorOpKind kind llvm_unreachable("unknown distributed tensor op kind"); } -static DistributedConvAnalysis analyzeDistributedConvConsumers(ONNXConvOp convOp) { +[[maybe_unused]] static DistributedConvAnalysis analyzeDistributedConvConsumers(ONNXConvOp convOp) { DistributedConvAnalysis analysis; analysis.replacementOp = convOp; @@ -740,7 +795,7 @@ static void rewriteDistributedConvReport(const DistributedConvReportTotals& tota } } -static void recordDistributedConvOutcome(const DistributedConvAnalysis& analysis) { +[[maybe_unused]] static void recordDistributedConvOutcome(const DistributedConvAnalysis& analysis) { static std::mutex reportMutex; static DistributedConvReportTotals totals; @@ -1030,7 +1085,7 @@ static void rewriteConvLoweringReport(ArrayRef entries) { writeConvReportTable(reportFile, entries); } -static FailureOr resolveRequestedConvLoweringStrategy(ONNXConvOp convOp) { +[[maybe_unused]] static FailureOr resolveRequestedConvLoweringStrategy(ONNXConvOp convOp) { if (!useExperimentalConvImpl) return pimConvLowering.getValue(); @@ -1092,7 +1147,7 @@ static ConvLoweringDecision chooseConvLoweringStrategy(const ConvGeometry& geo, return {PimConvLoweringTiled2D, "both reduction K and output channels exceed one crossbar", /*isAuto=*/true, "", ""}; } -static LogicalResult verifyForcedConvLoweringStrategy(ONNXConvOp convOp, +[[maybe_unused]] static LogicalResult verifyForcedConvLoweringStrategy(ONNXConvOp convOp, const ConvGeometry& geo, PimConvLoweringType strategy) { switch (strategy) { @@ -1196,12 +1251,11 @@ static void reportConvLoweringDecision(ONNXConvOp convOp, } static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) { - if (pimConvStreamChunkPositions.getNumOccurrences() != 0) - return std::max(1, pimConvStreamChunkPositions); - const uint64_t patchElements = static_cast(std::max(1, geo.k)); uint64_t chunkPositions = std::max(1, pimConvIm2colMaxElements / patchElements); 
chunkPositions = std::min(chunkPositions, static_cast(std::max(1, geo.p))); + chunkPositions = std::min(chunkPositions, std::max(1, pimConvStreamChunkPositions)); + if (packFactor > 1 && chunkPositions > static_cast(packFactor)) { chunkPositions -= chunkPositions % static_cast(packFactor); chunkPositions = std::max(chunkPositions, static_cast(packFactor)); @@ -1209,7 +1263,7 @@ static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t pack return std::max(1, chunkPositions); } -static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) { +static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) { auto biasType = cast(bias.getType()); if (biasType.getRank() != 1) return bias; @@ -1242,7 +1296,7 @@ static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, ArrayRef lowPadValues, ArrayRef highPadValues, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { auto valueType = cast(value.getType()); if (valueType == resultType) @@ -1271,7 +1325,7 @@ static Value createZeroPaddedTensor(Value value, } static Value affineAddConst( - ConversionPatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) { + PatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) { if (offset == 0) return value; @@ -1288,7 +1342,7 @@ static Value createConvInputPatch(Value input, Value inputWidthOffset, int64_t dilationHeight, int64_t dilationWidth, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { const int64_t patchChannels = patchType.getDimSize(1); const int64_t kernelHeight = patchType.getDimSize(2); @@ -1335,7 +1389,7 @@ static Value createCollectedConvOutput(ValueRange gemmRows, int64_t numChannelsOut, int64_t packFactor, ArrayRef distributedConsumers, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc); static 
FailureOr analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b); @@ -1390,7 +1444,7 @@ static std::optional computeTiling(int64_t batchSize, static Value buildPackedWeights(DenseElementsAttr wDenseAttr, RankedTensorType wType, const Tiling& tiling, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { auto packedWeightType = RankedTensorType::get( {tiling.numChannelTiles, tiling.tileInputRows, tiling.tileOutputChannels}, wType.getElementType()); @@ -1429,7 +1483,7 @@ static Value createPaddedInput(Value input, int64_t padHeightEnd, int64_t padWidthBegin, int64_t padWidthEnd, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if (padHeightBegin == 0 && padHeightEnd == 0 && padWidthBegin == 0 && padWidthEnd == 0) return input; @@ -1461,7 +1515,7 @@ static Value createInputTile(Value input, int64_t dilationHeight, int64_t dilationWidth, int64_t outWidth, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value batchIndex = affineFloorDivConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp); @@ -1500,7 +1554,7 @@ static Value createWeightTile(Value packedWeights, Value channelTileIndex, RankedTensorType packedWeightType, const Tiling& tiling, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { SmallVector offsets {channelTileIndex, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector sizes {rewriter.getIndexAttr(1), @@ -1523,7 +1577,7 @@ static Value createWeightTile(Value packedWeights, } static Value createBiasTile( - Value bias, Value channelTileIndex, const Tiling& tiling, ConversionPatternRewriter& rewriter, Location loc) { + Value bias, Value channelTileIndex, const Tiling& tiling, PatternRewriter& rewriter, Location loc) { auto biasType = cast(bias.getType()); auto biasTileType = RankedTensorType::get({1, 
tiling.tileOutputChannels}, biasType.getElementType()); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); @@ -1539,7 +1593,7 @@ static Value insertOutputTile(Value rowTile, Value rowAccumulator, Value channelTileIndex, const Tiling& tiling, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value channelOffset = tiling.tileOutputChannels == 1 @@ -1555,7 +1609,7 @@ static FailureOr reconstructDepthwiseGemmRows(Value pieces, RankedTensorType piecesType, RankedTensorType gemmOutType, const Tiling& tiling, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { auto collectedOp = createSpatCompute<1>(rewriter, loc, TypeRange {gemmOutType}, {}, pieces, [&](Value piecesArg) { auto rowType = RankedTensorType::get({1, gemmOutType.getDimSize(1)}, gemmOutType.getElementType()); @@ -1651,7 +1705,7 @@ static bool canUseStructuredRewrite(const ConvLoweringState& state) { if (tiling->numChannelTiles > static_cast(crossbarCountInCore.getValue())) return false; - if (isa(state.b.getDefiningOp())) + if (!state.hasBias) return true; auto biasType = dyn_cast(state.b.getType()); @@ -1665,10 +1719,10 @@ static bool canUseStructuredRewrite(const ConvLoweringState& state) { } static FailureOr -rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter, Location loc) { +rewriteConv(Operation* convOp, const ConvLoweringState& state, PatternRewriter& rewriter, Location loc) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); if (!wDenseAttr) { - convOp.emitOpError("requires constant-derived weights for structured depthwise Spatial lowering"); + convOp->emitOpError("requires constant-derived weights for structured depthwise Spatial lowering"); return failure(); } @@ -1680,7 +1734,7 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern state.outType.getDimSize(2), 
state.outType.getDimSize(3)); if (!tiling) { - convOp.emitOpError("failed to derive a structured depthwise tiling that fits Spatial weighted VMM lowering"); + convOp->emitOpError("failed to derive a structured depthwise tiling that fits Spatial weighted VMM lowering"); return failure(); } @@ -1696,12 +1750,12 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern Value expandedBias; SmallVector batchInputs {paddedInput}; - if (!isa(state.b.getDefiningOp())) { + if (state.hasBias) { expandedBias = expandBiasIfNeeded(state.b, rewriter, loc); auto biasType = dyn_cast(expandedBias.getType()); if (!biasType || biasType.getRank() != 2 || biasType.getDimSize(0) != 1 || biasType.getDimSize(1) != state.outType.getDimSize(1)) { - convOp.emitOpError("requires bias sliceable as tensor<1xCout> for structured depthwise Spatial lowering"); + convOp->emitOpError("requires bias sliceable as tensor<1xCout> for structured depthwise Spatial lowering"); return failure(); } batchInputs.push_back(expandedBias); @@ -1755,7 +1809,7 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern : affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); Value paddedInputArg = pickInputByRank(/*rank=*/4); if (!paddedInputArg) { - convOp.emitOpError("structured depthwise batch body requires a rank-4 padded input block argument"); + convOp->emitOpError("structured depthwise batch body requires a rank-4 padded input block argument"); return failure(); } @@ -1784,7 +1838,7 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern if (args.inputs.size() > 1) { Value biasArg = pickInputByRank(/*rank=*/2); if (!biasArg) { - convOp.emitOpError("structured depthwise batch body requires a rank-2 bias block argument when bias is present"); + convOp->emitOpError("structured depthwise batch body requires a rank-2 bias block argument when bias is present"); return failure(); } Value biasTile = 
tiling->numChannelTiles == 1 ? biasArg : createBiasTile(biasArg, channelTileIndex, *tiling, rewriter, loc); @@ -1858,7 +1912,7 @@ buildConvGemmPlan(const ConvLoweringState& state, std::optional forcedPackFactor = std::nullopt); static PreparedConvInput prepareInputForIm2Col(const ConvLoweringState& state, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if (state.padHeightBegin == 0 && state.padHeightEnd == 0 && state.padWidthBegin == 0 && state.padWidthEnd == 0) return {state.x, state.xType}; @@ -1884,7 +1938,7 @@ static PreparedConvInput prepareInputForIm2Col(const ConvLoweringState& state, static Value createPaddedRows(Value rows, RankedTensorType rowsType, int64_t paddedRows, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if (rowsType.getDimSize(0) == paddedRows) return rows; @@ -1896,7 +1950,7 @@ static Value createPaddedRows(Value rows, } static Value packRowsForParallelGemm( - Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) { + Value rows, RankedTensorType rowsType, int64_t packFactor, PatternRewriter& rewriter, Location loc) { if (packFactor == 1) return rows; @@ -1932,7 +1986,7 @@ static Value unpackRowsFromParallelGemm(Value packedRows, int64_t unpackedRows, int64_t rowWidth, int64_t packFactor, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if (packFactor == 1) return packedRows; @@ -1971,7 +2025,7 @@ static Value unpackRowsFromParallelGemm(Value packedRows, } static Value createWeightMatrix( - Value weights, const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { + Value weights, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) { auto buildWeightMatrix = [&](Value weight) -> Value { Value flattened = tensor::CollapseShapeOp::create(rewriter, loc, @@ -1998,7 +2052,7 @@ static Value createWeightMatrix( static Value createPaddedConvMatrix(Value 
matrix, RankedTensorType sourceType, RankedTensorType paddedType, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if (sourceType == paddedType) return matrix; @@ -2014,7 +2068,7 @@ static Value createPaddedConvMatrix(Value matrix, static Value createPaddedConstantMatrix(DenseElementsAttr sourceAttr, RankedTensorType sourceType, RankedTensorType paddedType, - ConversionPatternRewriter& rewriter) { + PatternRewriter& rewriter) { SmallVector paddedValues( paddedType.getNumElements(), cast(rewriter.getZeroAttr(paddedType.getElementType()))); SmallVector sourceValues(sourceAttr.getValues()); @@ -2032,7 +2086,7 @@ static Value createPaddedInputKTiledWeightConstant(DenseElementsAttr sourceAttr, const ConvLoweringState& state, int64_t paddedK, int64_t paddedC, - ConversionPatternRewriter& rewriter) { + PatternRewriter& rewriter) { auto paddedType = RankedTensorType::get({paddedK, paddedC}, state.wType.getElementType()); SmallVector sourceValues(sourceAttr.getValues()); SmallVector paddedValues( @@ -2055,7 +2109,7 @@ static Value createPaddedInputKTiledWeightConstant(DenseElementsAttr sourceAttr, static FailureOr rewriteInputKTiledConv(const ConvLoweringState& state, ArrayRef distributedConsumers, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); ConvGeometry geo = buildConvGeometry(state); @@ -2248,7 +2302,7 @@ static Value buildPackedWeights(DenseElementsAttr wDenseAttr, Value wTrans, const ConvLoweringState& state, const ConvGemmPlan& plan, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if (plan.effectiveMaxParallelPixels == 1) return wTrans; @@ -2286,7 +2340,7 @@ static Value buildPackedBias(Value gemmBias, DenseElementsAttr biasDenseAttr, const ConvLoweringState& state, const ConvGemmPlan& plan, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if 
(!state.hasBias) return gemmBias; @@ -2346,7 +2400,7 @@ buildConvGemmPlan(const ConvLoweringState& state, static Value createIm2colRows(const ConvLoweringState& state, const PreparedConvInput& preparedInput, const ConvGemmPlan& plan, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { constexpr size_t numInputs = 1; auto im2colComputeOp = @@ -2434,7 +2488,7 @@ static Value createIm2colRows(const ConvLoweringState& state, static Value maybeUnpackChunkRows(Value gemmRows, const ConvGemmPlan& plan, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { if (plan.effectiveMaxParallelPixels == 1) return gemmRows; @@ -2456,13 +2510,12 @@ static Value maybeUnpackChunkRows(Value gemmRows, static Value createChunkedConvRows(const ConvLoweringState& state, const PreparedConvInput& preparedInput, Value weightMatrix, - Value gemmBias, Value biasMatrix, DenseElementsAttr wDenseAttr, DenseElementsAttr biasDenseAttr, int64_t forcedPackFactor, uint64_t chunkPositions, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { SmallVector chunkRows; const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth; @@ -2476,6 +2529,9 @@ static Value createChunkedConvRows(const ConvLoweringState& state, forcedPackFactor); Value chunkInputRows = createIm2colRows(state, preparedInput, chunkPlan, rewriter, loc); Value chunkB = buildPackedWeights(wDenseAttr, weightMatrix, state, chunkPlan, rewriter, loc); + Value gemmBias = createZeroGemmBias(chunkPlan.gemmOutputRowsType, rewriter); + if (state.hasBias) + gemmBias = state.b; Value chunkC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, chunkPlan, rewriter, loc); Value chunkGemmRows = ONNXGemmOp::create(rewriter, loc, @@ -2483,10 +2539,10 @@ static Value createChunkedConvRows(const ConvLoweringState& state, chunkInputRows, chunkB, chunkC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - 
rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) + APFloat(1.0f), + APFloat(1.0f), + /*transA=*/0, + /*transB=*/0) .getY(); chunkRows.push_back(maybeUnpackChunkRows(chunkGemmRows, chunkPlan, rewriter, loc)); } @@ -2503,15 +2559,13 @@ static Value createChunkedConvRows(const ConvLoweringState& state, static Value rewritePackedIm2ColConv(const ConvLoweringState& state, ArrayRef distributedConsumers, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); - Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value biasMatrix; DenseElementsAttr biasDenseAttr; if (state.hasBias) { - gemmBias = state.b; biasDenseAttr = getHostConstDenseElementsAttr(state.b); biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); } @@ -2524,6 +2578,9 @@ static Value rewritePackedIm2ColConv(const ConvLoweringState& state, Value weightMatrix = createWeightMatrix(state.w, plan, rewriter, loc); Value gemmInputRows = createIm2colRows(state, preparedInput, plan, rewriter, loc); Value gemmB = buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); + Value gemmBias = createZeroGemmBias(plan.gemmOutputRowsType, rewriter); + if (state.hasBias) + gemmBias = state.b; Value gemmC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); Value gemmRows = ONNXGemmOp::create(rewriter, @@ -2532,10 +2589,10 @@ static Value rewritePackedIm2ColConv(const ConvLoweringState& state, gemmInputRows, gemmB, gemmC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) + APFloat(1.0f), + APFloat(1.0f), + /*transA=*/0, + /*transB=*/0) .getY(); return createCollectedConvOutput(ValueRange {gemmRows}, @@ -2553,16 +2610,14 @@ static Value rewritePackedIm2ColConv(const ConvLoweringState& state, static 
Value rewriteStreamedConv(const ConvLoweringState& state, ArrayRef distributedConsumers, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc, int64_t forcedPackFactor) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); - Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value biasMatrix; DenseElementsAttr biasDenseAttr; if (state.hasBias) { - gemmBias = state.b; biasDenseAttr = getHostConstDenseElementsAttr(state.b); biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); } @@ -2575,7 +2630,6 @@ static Value rewriteStreamedConv(const ConvLoweringState& state, Value collectedRows = createChunkedConvRows(state, preparedInput, weightMatrix, - gemmBias, biasMatrix, wDenseAttr, biasDenseAttr, @@ -2629,7 +2683,7 @@ static DistributedTensorInfo makeDistributedTensorInfo(Value storage, RankedTens static Value createPerChannelConstantFragment(DenseElementsAttr denseAttr, RankedTensorType fragmentType, - ConversionPatternRewriter& rewriter) { + PatternRewriter& rewriter) { auto denseType = cast(denseAttr.getType()); SmallVector channelValues; channelValues.reserve(fragmentType.getDimSize(1)); @@ -2657,9 +2711,272 @@ static Value createPerChannelConstantFragment(DenseElementsAttr denseAttr, return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), attr, fragmentType); } +static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter) { + auto zeroAttr = DenseElementsAttr::get(gemmResultType, rewriter.getZeroAttr(gemmResultType.getElementType())); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), zeroAttr, gemmResultType); +} + +static bool canDirectLowerRowStripConv(const ConvLoweringState& state, StringRef& failureReason) { + if (!canConsumeRowStripHwcInput(state, failureReason)) + return false; + + ConvGeometry geometry = buildConvGeometry(state); 
+ if (state.numChannelsOut > geometry.xbarSize) { + failureReason = "unsupported_output_channels"; + return false; + } + + failureReason = ""; + return true; +} + +static FailureOr createRowStripPackedRows(Value rows, + const ConvLoweringState& state, + PatternRewriter& rewriter, + Location loc) { + auto rowsType = dyn_cast(rows.getType()); + if (!rowsType || !rowsType.hasStaticShape() || rowsType.getRank() != 2) + return failure(); + + if (state.batchSize != 1) + return failure(); + if (state.outType.getRank() != 4 || !state.outType.hasStaticShape()) + return failure(); + + const int64_t outHeight = state.outType.getDimSize(2); + const int64_t outWidth = state.outType.getDimSize(3); + const int64_t outChannels = state.outType.getDimSize(1); + if (rowsType.getDimSize(0) != outHeight * outWidth || rowsType.getDimSize(1) != outChannels) + return failure(); + + auto packedType = RankedTensorType::get({outHeight, outWidth, outChannels}, rowsType.getElementType(), rowsType.getEncoding()); + auto packedRows = + createSpatCompute<1>(rewriter, loc, TypeRange {packedType}, {}, rows, [&](Value rowValues) { + Value packed = tensor::ExpandShapeOp::create( + rewriter, loc, packedType, rowValues, SmallVector {{0, 1}, {2}}); + spatial::SpatYieldOp::create(rewriter, loc, packed); + }); + return packedRows.getResult(0); +} + +static FailureOr createConvOutputFromRowStripHwc(Value inputHwc, + const ConvLoweringState& state, + PatternRewriter& rewriter, + Location loc) { + auto inputType = dyn_cast(inputHwc.getType()); + if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 3) + return failure(); + if (inputType.getDimSize(0) != state.xHeight || inputType.getDimSize(1) != state.xWidth + || inputType.getDimSize(2) != state.numChannelsIn) + return failure(); + + StringRef failureReason; + if (!canDirectLowerRowStripConv(state, failureReason)) + return failure(); + + ConvRowDemand demand = buildConvRowDemand(RowInterval {0, state.outHeight}, state); + if 
(!covers(demand.acquiredInputRows, demand.neededInputRows)) + return failure(); + + ConvGeometry geometry = buildConvGeometry(state); + const int64_t xbarDim = geometry.xbarSize; + const int64_t patchSize = state.numChannelsIn * state.wHeight * state.wWidth; + const int64_t numKSlices = ceilIntegerDivide(patchSize, xbarDim); + const int64_t paddedK = numKSlices * xbarDim; + auto elementType = inputType.getElementType(); + auto paddedInputType = RankedTensorType::get({state.xHeight + state.padHeightBegin + state.padHeightEnd, + state.xWidth + state.padWidthBegin + state.padWidthEnd, + state.numChannelsIn}, + elementType, + inputType.getEncoding()); + auto paddedPatchType = + RankedTensorType::get({state.wHeight, state.wWidth, 1}, elementType, inputType.getEncoding()); + auto flatPatchType = RankedTensorType::get({state.wHeight * state.wWidth}, elementType, inputType.getEncoding()); + auto rowChunkType = RankedTensorType::get({1, state.wHeight * state.wWidth}, elementType, inputType.getEncoding()); + auto rowType = RankedTensorType::get({1, state.numChannelsOut}, state.outType.getElementType()); + auto packedOutputType = + RankedTensorType::get({state.outHeight, state.outWidth, state.numChannelsOut}, state.outType.getElementType()); + auto packedOutputSliceType = + RankedTensorType::get({1, 1, state.numChannelsOut}, state.outType.getElementType()); + auto paddedRowType = RankedTensorType::get({1, xbarDim}, state.outType.getElementType()); + auto paddedPatchRowType = RankedTensorType::get({1, paddedK}, elementType, inputType.getEncoding()); + auto paddedWeightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType()); + auto weightDenseAttr = getHostConstDenseElementsAttr(state.w); + if (!weightDenseAttr) + return failure(); + Value paddedWeights = standard::createPaddedInputKTiledWeightConstant(weightDenseAttr, state, paddedK, xbarDim, rewriter); + + Value paddedBias; + if (state.hasBias) { + Value biasMatrix = expandBiasIfNeeded(state.b, 
rewriter, loc); + auto biasMatrixType = cast(biasMatrix.getType()); + auto paddedBiasType = RankedTensorType::get({1, xbarDim}, state.outType.getElementType()); + if (auto biasDenseAttr = getHostConstDenseElementsAttr(state.b)) + paddedBias = standard::createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter); + else + paddedBias = materializeOrComputeUnary( + biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) { + return standard::createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc); + }); + } + + auto paddedInputOp = + createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, inputHwc, [&](Value hwcInputArg) { + Value paddedInput = createZeroPaddedTensor(hwcInputArg, + paddedInputType, + {state.padHeightBegin, state.padWidthBegin, 0}, + {state.padHeightEnd, state.padWidthEnd, 0}, + rewriter, + loc); + spatial::SpatYieldOp::create(rewriter, loc, paddedInput); + }); + + SmallVector batchInputs {paddedInputOp.getResult(0)}; + if (state.hasBias) + batchInputs.push_back(paddedBias); + auto batchOp = createSpatComputeBatch( + rewriter, + loc, + TypeRange {packedOutputType}, + state.outHeight, + ValueRange {paddedWeights}, + batchInputs, + [&](detail::SpatComputeBatchBodyArgs args) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices); + Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim); + Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); + Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn); + Value cPatchWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.wHeight * state.wWidth); + Value localHeightOffset = arith::MulIOp::create(rewriter, loc, args.lane, c1); + Value packedRowInit = + 
tensor::EmptyOp::create(rewriter, loc, ArrayRef {1, state.outWidth, state.numChannelsOut}, elementType); + auto widthLoop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cOutWidth, + c1, + ValueRange {packedRowInit}, + [&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl& widthYielded) { + Value localWidthOffset = arith::MulIOp::create(rewriter, widthLoc, widthIndex, c1); + Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef {1, patchSize}, elementType); + auto rowLoop = buildNormalizedScfFor( + rewriter, + widthLoc, + c0, + cNumChannels, + c1, + ValueRange {rowInit}, + [&](OpBuilder&, Location rowLoc, Value channel, ValueRange rowIterArgs, SmallVectorImpl& rowYielded) { + SmallVector patchOffsets {localHeightOffset, localWidthOffset, channel}; + SmallVector patchSizes { + rewriter.getIndexAttr(state.wHeight), rewriter.getIndexAttr(state.wWidth), rewriter.getIndexAttr(1)}; + Value channelPatch = tensor::ExtractSliceOp::create( + rewriter, rowLoc, paddedPatchType, args.inputs.front(), patchOffsets, patchSizes, getUnitStrides(rewriter, 3)); + Value flatPatch = tensor::CollapseShapeOp::create( + rewriter, rowLoc, flatPatchType, channelPatch, SmallVector {{0, 1, 2}}); + Value rowChunk = tensor::ExpandShapeOp::create( + rewriter, rowLoc, rowChunkType, flatPatch, SmallVector {{0, 1}}); + Value flatOffset = arith::MulIOp::create(rewriter, rowLoc, channel, cPatchWidth); + SmallVector rowOffsets {rewriter.getIndexAttr(0), flatOffset}; + SmallVector rowSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)}; + Value nextRow = tensor::InsertSliceOp::create( + rewriter, rowLoc, rowChunk, rowIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); + rowYielded.push_back(nextRow); + return success(); + }); + if (failed(rowLoop)) + return failure(); + + Value paddedRow = rowLoop->results.front(); + if (patchSize != paddedK) + paddedRow = createZeroPaddedTensor( + 
paddedRow, paddedPatchRowType, {0, 0}, {0, paddedK - patchSize}, rewriter, widthLoc); + + auto zeroAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(state.outType.getElementType())); + Value zeroRow = getOrCreateConstant(rewriter, anchorOp, zeroAttr, paddedRowType); + auto kLoop = buildNormalizedScfFor( + rewriter, + widthLoc, + c0, + cNumKSlices, + c1, + ValueRange {zeroRow}, + [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl& reduceYielded) { + Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar); + SmallVector aOffsets {rewriter.getIndexAttr(0), kOffset}; + SmallVector aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)}; + Value aTile = tensor::ExtractSliceOp::create( + rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2)); + SmallVector bOffsets {kOffset, rewriter.getIndexAttr(0)}; + SmallVector bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)}; + Value bTile = tensor::ExtractSliceOp::create( + rewriter, reduceLoc, paddedWeightTileType, args.weights.front(), bOffsets, bSizes, getUnitStrides(rewriter, 2)); + Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult(); + reduceYielded.push_back( + spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult()); + return success(); + }); + if (failed(kLoop)) + return failure(); + + Value rowResult = kLoop->results.front(); + if (state.hasBias) + rowResult = + spatial::SpatVAddOp::create(rewriter, widthLoc, paddedRowType, rowResult, args.inputs[1]).getResult(); + + Value outputRow = rowResult; + if (state.numChannelsOut != xbarDim) { + SmallVector outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector outputSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; + outputRow = tensor::ExtractSliceOp::create( + rewriter, 
widthLoc, rowType, rowResult, outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); + } + + Value outputFragment = tensor::ExpandShapeOp::create(rewriter, + widthLoc, + packedOutputSliceType, + outputRow, + SmallVector {{0}, {1, 2}}); + SmallVector rowOffsets {rewriter.getIndexAttr(0), widthIndex, rewriter.getIndexAttr(0)}; + SmallVector rowSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; + Value nextPackedRow = tensor::InsertSliceOp::create( + rewriter, widthLoc, outputFragment, widthIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 3)); + widthYielded.push_back(nextPackedRow); + return success(); + }); + if (failed(widthLoop)) + return failure(); + + SmallVector batchOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector batchSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.outWidth), rewriter.getIndexAttr(state.numChannelsOut)}; + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, widthLoop->results.front(), args.outputs.front(), batchOffsets, batchSizes, getUnitStrides(rewriter, 3)); + return success(); + }); + if (failed(batchOp)) + return failure(); + return batchOp->getResult(0); +} + +static FailureOr createConvRowsFromRowStripInput(const ConvLoweringState& state, + [[maybe_unused]] const ConvLoweringDecision& decision, + Value rowStripInput, + PatternRewriter& rewriter, + Location loc) { + return createConvOutputFromRowStripHwc(rowStripInput, state, rewriter, loc); +} + static Value createFragmentConstant(const DistributedTensorStep& step, RankedTensorType fragmentType, - ConversionPatternRewriter& rewriter) { + PatternRewriter& rewriter) { if (step.constantKind == DistributedTensorConstantKind::PerChannel) return createPerChannelConstantFragment(step.constantAttr, fragmentType, rewriter); @@ -2672,7 +2989,7 @@ static Value createFragmentConstant(const DistributedTensorStep& step, static Value 
createFragmentReciprocalConstant(const DistributedTensorStep& step, RankedTensorType fragmentType, - ConversionPatternRewriter& rewriter) { + PatternRewriter& rewriter) { SmallVector values; if (step.constantKind == DistributedTensorConstantKind::PerChannel) { auto denseType = cast(step.constantAttr.getType()); @@ -2707,15 +3024,13 @@ static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, [[maybe_unused]] static FailureOr createConvRowsForStrategy(const ConvLoweringState& state, const ConvLoweringDecision& decision, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); PreparedConvInput preparedInput = standard::prepareInputForIm2Col(state, rewriter, loc); - Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value biasMatrix; DenseElementsAttr biasDenseAttr; if (state.hasBias) { - gemmBias = state.b; biasDenseAttr = getHostConstDenseElementsAttr(state.b); biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); } @@ -2729,6 +3044,9 @@ static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, Value weightMatrix = standard::createWeightMatrix(state.w, plan, rewriter, loc); Value gemmInputRows = standard::createIm2colRows(state, preparedInput, plan, rewriter, loc); Value gemmB = standard::buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); + Value gemmBias = createZeroGemmBias(plan.gemmOutputRowsType, rewriter); + if (state.hasBias) + gemmBias = state.b; Value gemmC = standard::buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); Value gemmRows = ONNXGemmOp::create(rewriter, loc, @@ -2736,10 +3054,10 @@ static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, gemmInputRows, gemmB, gemmC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) + APFloat(1.0f), + 
APFloat(1.0f), + /*transA=*/0, + /*transB=*/0) .getY(); return standard::maybeUnpackChunkRows(gemmRows, plan, rewriter, loc); } @@ -2757,7 +3075,6 @@ static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, return standard::createChunkedConvRows(state, preparedInput, weightMatrix, - gemmBias, biasMatrix, wDenseAttr, biasDenseAttr, @@ -2773,7 +3090,7 @@ static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, [[maybe_unused]] static FailureOr createDistributedTensorFromRows(Value rows, RankedTensorType logicalType, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { const int64_t width = logicalType.getDimSize(3); const int64_t height = logicalType.getDimSize(2); @@ -2813,7 +3130,7 @@ static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, [[maybe_unused]] static FailureOr applyDistributedPreservingStep(const DistributedTensorInfo& inputInfo, const DistributedTensorStep& step, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { auto logicalType = inputInfo.logicalType; const int64_t width = logicalType.getDimSize(3); @@ -2885,7 +3202,7 @@ static Value createCollectedConvOutput(ValueRange gemmRows, int64_t numChannelsOut, int64_t packFactor, ArrayRef distributedConsumers, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, Location loc) { auto materializeSplatTensor = [&](DenseElementsAttr denseAttr, RankedTensorType targetType) { Attribute splatValue = denseAttr.getSplatValue(); @@ -3026,7 +3343,7 @@ static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, state.wWidth = state.wType.getDimSize(3); state.outHeight = state.outType.getDimSize(2); state.outWidth = state.outType.getDimSize(3); - state.hasBias = !isa(state.b.getDefiningOp()); + state.hasBias = state.b && !isa(state.b.getDefiningOp()); if (state.numChannelsIn % state.group != 0) { convOp.emitOpError() << "requires input channels " << 
state.numChannelsIn << " to be divisible by group " @@ -3123,34 +3440,221 @@ static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, return analyzeConvLoweringState(convOp, convOpAdaptor.getX(), convOpAdaptor.getW(), convOpAdaptor.getB()); } +static FailureOr analyzeConvLoweringState(spatial::SpatConv2DPlanOp planOp) { + ConvLoweringState state; + state.x = planOp.getInput(); + state.w = planOp.getWeight(); + state.b = planOp.getBias() ? planOp.getBias() : Value(); + state.xType = dyn_cast(state.x.getType()); + state.wType = dyn_cast(state.w.getType()); + state.outType = dyn_cast(planOp.getOutput().getType()); + + if (!state.xType || !state.wType || !state.outType) + return planOp.emitOpError("requires ranked tensor input, weight, and output"), failure(); + if (!state.xType.hasStaticShape() || !state.wType.hasStaticShape() || !state.outType.hasStaticShape()) + return planOp.emitOpError("requires static input, weight, and output shapes"), failure(); + if (state.xType.getRank() != 4 || state.wType.getRank() != 4 || state.outType.getRank() != 4) + return planOp.emitOpError("requires rank-4 input, weight, and output tensors"), failure(); + + state.group = planOp.getGroup(); + if (state.group < 1) + return planOp.emitOpError("requires group >= 1"), failure(); + + state.batchSize = state.xType.getDimSize(0); + state.numChannelsIn = state.xType.getDimSize(1); + state.xHeight = state.xType.getDimSize(2); + state.xWidth = state.xType.getDimSize(3); + state.numChannelsOut = state.wType.getDimSize(0); + state.wHeight = state.wType.getDimSize(2); + state.wWidth = state.wType.getDimSize(3); + state.outHeight = state.outType.getDimSize(2); + state.outWidth = state.outType.getDimSize(3); + state.hasBias = static_cast(planOp.getBias()); + + if (state.numChannelsIn % state.group != 0 || state.numChannelsOut % state.group != 0) + return planOp.emitOpError("requires input and output channels divisible by group"), failure(); + + state.numChannelsInPerGroup = state.numChannelsIn / 
state.group; + state.numChannelsOutPerGroup = state.numChannelsOut / state.group; + if (state.wType.getDimSize(1) != state.numChannelsInPerGroup) + return planOp.emitOpError("requires grouped conv weight channels to match input channels per group"), failure(); + + auto pads = planOp.getPads(); + auto strides = planOp.getStrides(); + auto dilations = planOp.getDilations(); + if (pads.size() != 4 || strides.size() != 2 || dilations.size() != 2) + return planOp.emitOpError("requires 4 pads, 2 strides, and 2 dilations"), failure(); + + state.padHeightBegin = pads[0]; + state.padWidthBegin = pads[1]; + state.padHeightEnd = pads[2]; + state.padWidthEnd = pads[3]; + state.strideHeight = strides[0]; + state.strideWidth = strides[1]; + state.dilationHeight = dilations[0]; + state.dilationWidth = dilations[1]; + return state; +} + +static FailureOr resolveRequestedConvLoweringStrategy(Operation* op) { + if (!useExperimentalConvImpl) + return pimConvLowering.getValue(); + + if (pimConvLowering != PimConvLoweringAuto && pimConvLowering != PimConvLoweringPackedIm2Col) { + op->emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering=" + << stringifyConvLoweringStrategy(pimConvLowering); + return failure(); + } + return PimConvLoweringPackedIm2Col; +} + +static LogicalResult verifyForcedConvLoweringStrategy(Operation* op, + const ConvGeometry& geo, + PimConvLoweringType strategy) { + switch (strategy) { + case PimConvLoweringAuto: + case PimConvLoweringLegacy: + return success(); + case PimConvLoweringDepthwise: + if (geo.isDepthwise) + return success(); + return op->emitOpError("forced depthwise Conv lowering requires a depthwise convolution"); + case PimConvLoweringPackedIm2Col: + if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements) + return success(); + return op->emitOpError("forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget"); + case 
PimConvLoweringStreamedPatch: + if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize) + return success(); + return op->emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar"); + case PimConvLoweringStreamedPacked: + if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2) + return success(); + return op->emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2"); + case PimConvLoweringOutputChannelTiled: + if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize) + return success(); + return op->emitOpError("forced output-channel-tiled Conv lowering requires K <= X and C > X"); + case PimConvLoweringInputKTiled: + if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize) + return success(); + return op->emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X"); + case PimConvLoweringTiled2D: + if (geo.k > geo.xbarSize && geo.c > geo.xbarSize) + return success(); + return op->emitOpError("forced tiled-2d Conv lowering requires K > X and C > X"); + } + llvm_unreachable("unknown conv lowering strategy"); +} + +static FailureOr lowerDenseSelectedConvPlan(Operation* op, + const ConvLoweringState& state, + PimConvLoweringType strategy, + PatternRewriter& rewriter, + Location loc); + +static ConvLoweringState makeGroupedConvLoweringState(const ConvLoweringState& parent, + Value groupX, + Value groupW, + Value groupB, + RankedTensorType groupOutType); + +static FailureOr buildConvValueForStrategy(Operation* op, + Location loc, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + const DistributedConvAnalysis& analysis, + ArrayRef distributedConsumers, + PatternRewriter& rewriter); + +static FailureOr buildGroupedConvValue(Operation* op, + Location loc, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + PatternRewriter& rewriter); + +static FailureOr lowerGroupedSelectedConvPlan(Operation* op, + const ConvLoweringState& state, + PimConvLoweringType 
strategy, + PatternRewriter& rewriter, + Location loc) { + ConvLoweringDecision decision {strategy, "", false, "", ""}; + return buildGroupedConvValue(op, loc, state, decision, rewriter); +} + +static FailureOr lowerDenseSelectedConvPlan(Operation* op, + const ConvLoweringState& state, + PimConvLoweringType strategy, + PatternRewriter& rewriter, + Location loc) { + DistributedConvAnalysis analysis; + analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer; + analysis.barrierDetail = "selected dense layout"; + ConvLoweringDecision decision {strategy, "", false, "", ""}; + return buildConvValueForStrategy(op, loc, state, decision, analysis, {}, rewriter); +} + +static FailureOr buildConvValueForStrategy(Operation* op, + Location loc, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + const DistributedConvAnalysis& analysis, + ArrayRef distributedConsumers, + PatternRewriter& rewriter) { + (void)analysis; + const ConvGeometry geo = buildConvGeometry(state); + switch (decision.strategy) { + case PimConvLoweringDepthwise: { + return depthwise::rewriteConv(op, state, rewriter, loc); + } + case PimConvLoweringLegacy: + case PimConvLoweringPackedIm2Col: { + return standard::rewritePackedIm2ColConv(state, distributedConsumers, rewriter, loc); + } + case PimConvLoweringStreamedPatch: + case PimConvLoweringOutputChannelTiled: + case PimConvLoweringTiled2D: { + return standard::rewriteStreamedConv(state, distributedConsumers, rewriter, loc, /*forcedPackFactor=*/1); + } + case PimConvLoweringInputKTiled: { + return standard::rewriteInputKTiledConv(state, distributedConsumers, rewriter, loc); + } + case PimConvLoweringStreamedPacked: { + return standard::rewriteStreamedConv(state, distributedConsumers, rewriter, loc, geo.pack); + } + case PimConvLoweringAuto: + break; + } + op->emitOpError("unexpected auto strategy at Conv lowering dispatch"); + return failure(); +} + static LogicalResult createConvValueForStrategy(ONNXConvOp convOp, const 
ConvLoweringState& state, const ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, ArrayRef distributedConsumers, - ConversionPatternRewriter& rewriter, + PatternRewriter& rewriter, FailureOr& result) { + result = buildConvValueForStrategy(convOp, convOp.getLoc(), state, decision, analysis, distributedConsumers, rewriter); + if (failed(result)) + return failure(); + const ConvGeometry geo = buildConvGeometry(state); const ConvStrategyEstimate estimate = estimateConvStrategy(geo, decision.strategy, analysis); switch (decision.strategy) { - case PimConvLoweringDepthwise: { - result = depthwise::rewriteConv(convOp, state, rewriter, convOp.getLoc()); - if (failed(result)) - return failure(); + case PimConvLoweringDepthwise: reportConvLoweringDecision( convOp, geo, decision, estimate, /*batchSize=*/geo.p, /*numberOfBatches=*/1, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, std::nullopt); return success(); - } case PimConvLoweringLegacy: - case PimConvLoweringPackedIm2Col: { + case PimConvLoweringPackedIm2Col: reportConvLoweringDecision( convOp, geo, decision, estimate, /*batchSize=*/geo.pack, /*numberOfBatches=*/1, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, std::nullopt); - result = standard::rewritePackedIm2ColConv(state, distributedConsumers, rewriter, convOp.getLoc()); return success(); - } case PimConvLoweringStreamedPatch: case PimConvLoweringOutputChannelTiled: case PimConvLoweringTiled2D: { @@ -3165,7 +3669,6 @@ createConvValueForStrategy(ONNXConvOp convOp, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, chunkPositions); - result = standard::rewriteStreamedConv(state, distributedConsumers, rewriter, convOp.getLoc(), /*forcedPackFactor=*/1); return success(); } case PimConvLoweringInputKTiled: { @@ -3190,7 +3693,6 @@ createConvValueForStrategy(ONNXConvOp convOp, /*usesComputeBatch=*/false, /*usesBatchedInstructionEmission=*/false, rowChunkWidth); - result = 
standard::rewriteInputKTiledConv(state, distributedConsumers, rewriter, convOp.getLoc()); return success(); } case PimConvLoweringStreamedPacked: { @@ -3205,7 +3707,6 @@ createConvValueForStrategy(ONNXConvOp convOp, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, chunkPositions); - result = standard::rewriteStreamedConv(state, distributedConsumers, rewriter, convOp.getLoc(), geo.pack); return success(); } case PimConvLoweringAuto: @@ -3219,7 +3720,7 @@ rewriteSelectedConv(ONNXConvOp convOp, const ConvLoweringState& state, const ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, - ConversionPatternRewriter& rewriter) { + PatternRewriter& rewriter) { FailureOr result = failure(); if (failed(createConvValueForStrategy(convOp, state, decision, analysis, analysis.steps, rewriter, result))) return failure(); @@ -3238,12 +3739,12 @@ rewriteSelectedConv(ONNXConvOp convOp, return success(); } -static LogicalResult +[[maybe_unused]] static LogicalResult rewriteUngroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, const ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, - ConversionPatternRewriter& rewriter) { + PatternRewriter& rewriter) { return rewriteSelectedConv(convOp, state, decision, analysis, rewriter); } @@ -3251,7 +3752,13 @@ static LogicalResult rewriteGroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, const ConvLoweringDecision& decision, - ConversionPatternRewriter& rewriter); + PatternRewriter& rewriter); + +static ConvLoweringState makeGroupedConvLoweringState(const ConvLoweringState& parent, + Value groupX, + Value groupW, + Value groupB, + RankedTensorType groupOutType); static ConvLoweringState makeGroupedConvLoweringState( const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType) { @@ -3274,18 +3781,17 @@ static ConvLoweringState makeGroupedConvLoweringState( state.group = 1; state.numChannelsInPerGroup = 
state.numChannelsIn; state.numChannelsOutPerGroup = state.numChannelsOut; - state.hasBias = !isa(groupB.getDefiningOp()); + state.hasBias = static_cast(groupB); return state; } -static LogicalResult -rewriteGroupedConv(ONNXConvOp convOp, - const ConvLoweringState& state, - const ConvLoweringDecision& decision, - ConversionPatternRewriter& rewriter) { - SmallVector xSlices = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, convOp.getLoc()); - SmallVector wSlices = - sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, convOp.getLoc()); +static FailureOr buildGroupedConvValue(Operation* op, + Location loc, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + PatternRewriter& rewriter) { + SmallVector xSlices = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, loc); + SmallVector wSlices = sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, loc); SmallVector bSlices; if (state.hasBias) { auto biasType = cast(state.b.getType()); @@ -3295,16 +3801,16 @@ rewriteGroupedConv(ONNXConvOp convOp, else if (biasType.getRank() == 2) biasAxis = biasType.getDimSize(0) != 1 ? 
0 : 1; else { - convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank " - << biasType.getRank(); + op->emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank " + << biasType.getRank(); return failure(); } - bSlices = sliceTensor(state.b, biasAxis, state.numChannelsOutPerGroup, rewriter, convOp.getLoc()); + bSlices = sliceTensor(state.b, biasAxis, state.numChannelsOutPerGroup, rewriter, loc); } if (xSlices.size() != static_cast(state.group) || wSlices.size() != static_cast(state.group) || (state.hasBias && bSlices.size() != static_cast(state.group))) { - convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering"); + op->emitOpError("failed to partition grouped convolution operands for Spatial lowering"); return failure(); } @@ -3312,35 +3818,39 @@ rewriteGroupedConv(ONNXConvOp convOp, groupResults.reserve(state.group); auto groupOutType = RankedTensorType::get( {state.batchSize, state.numChannelsOutPerGroup, state.outHeight, state.outWidth}, state.outType.getElementType()); - Value noBias = ONNXNoneOp::create(rewriter, convOp.getLoc(), rewriter.getNoneType()); for (int64_t groupId = 0; groupId < state.group; groupId++) { Value groupX = xSlices[groupId]; Value groupW = wSlices[groupId]; - Value groupB = state.hasBias ? bSlices[groupId] : noBias; + Value groupB = state.hasBias ? 
bSlices[groupId] : Value(); ConvLoweringState groupState = makeGroupedConvLoweringState(state, groupX, groupW, groupB, groupOutType); - FailureOr groupResult = failure(); DistributedConvAnalysis groupAnalysis; groupAnalysis.barrierKind = DistributedConvBarrierKind::GroupedConv; groupAnalysis.barrierDetail = "grouped convolution still materializes densely"; - if (failed(createConvValueForStrategy(convOp, groupState, decision, groupAnalysis, {}, rewriter, groupResult))) + FailureOr groupResult = + buildConvValueForStrategy(op, loc, groupState, decision, groupAnalysis, {}, rewriter); + if (failed(groupResult)) return failure(); groupResults.push_back(*groupResult); } - Value result; - if (llvm::all_of(groupResults, isCompileTimeComputable)) { - result = createSpatConcat(rewriter, convOp.getLoc(), /*axis=*/1, groupResults); - } - else { - auto concatCompute = - createSpatCompute(rewriter, convOp.getLoc(), TypeRange {state.outType}, {}, groupResults, [&](ValueRange args) { - spatial::SpatYieldOp::create( - rewriter, convOp.getLoc(), createSpatConcat(rewriter, convOp.getLoc(), /*axis=*/1, args)); - }); - result = concatCompute.getResult(0); - } + if (llvm::all_of(groupResults, isCompileTimeComputable)) + return createSpatConcat(rewriter, loc, /*axis=*/1, groupResults); - rewriter.replaceOp(convOp, result); + auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {state.outType}, {}, groupResults, [&](ValueRange args) { + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args)); + }); + return concatCompute.getResult(0); +} + +[[maybe_unused]] static LogicalResult +rewriteGroupedConv(ONNXConvOp convOp, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + PatternRewriter& rewriter) { + FailureOr result = buildGroupedConvValue(convOp.getOperation(), convOp.getLoc(), state, decision, rewriter); + if (failed(result)) + return failure(); + rewriter.replaceOp(convOp, *result); return success(); } @@ -3352,21 
+3862,47 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, FailureOr state = analyzeConvLoweringState(convOp, convOpAdaptor); if (failed(state)) return failure(); + SmallVector pads { + state->padHeightBegin, state->padWidthBegin, state->padHeightEnd, state->padWidthEnd}; + SmallVector strides {state->strideHeight, state->strideWidth}; + SmallVector dilations {state->dilationHeight, state->dilationWidth}; + Value bias = state->hasBias ? convOpAdaptor.getB() : Value(); + auto convPlan = spatial::SpatConv2DPlanOp::create(rewriter, + convOp.getLoc(), + convOp.getY().getType(), + convOpAdaptor.getX(), + convOpAdaptor.getW(), + bias, + rewriter.getDenseI64ArrayAttr(pads), + rewriter.getDenseI64ArrayAttr(strides), + rewriter.getDenseI64ArrayAttr(dilations), + rewriter.getI64IntegerAttr(state->group), + rewriter.getStringAttr("nchw")); + rewriter.replaceOp(convOp, convPlan.getResult()); + return success(); +} - FailureOr requestedStrategy = resolveRequestedConvLoweringStrategy(convOp); +void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } + +LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp) { + FailureOr state = analyzeConvLoweringState(planOp); + if (failed(state)) + return failure(); + + if (state->group != 1 || state->batchSize != 1) + return failure(); + if (state->outType.getRank() != 4 || !state->outType.hasStaticShape()) + return failure(); + + FailureOr requestedStrategy = resolveRequestedConvLoweringStrategy(planOp.getOperation()); if (failed(requestedStrategy)) return failure(); - DistributedConvAnalysis distributedAnalysis; - if (state->group == 1) - distributedAnalysis = analyzeDistributedConvConsumers(convOp); - else { - distributedAnalysis.barrierKind = DistributedConvBarrierKind::GroupedConv; - distributedAnalysis.barrierDetail = "grouped convolution still materializes densely"; - } - + DistributedConvAnalysis analysis; + analysis.barrierKind = 
DistributedConvBarrierKind::UnsupportedConsumer; + analysis.barrierDetail = "selected row-strip layout"; ConvGeometry geometry = buildConvGeometry(*state); - ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requestedStrategy, distributedAnalysis); + ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requestedStrategy, analysis); if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(*state) && *requestedStrategy == PimConvLoweringAuto) { decision = {PimConvLoweringLegacy, @@ -3375,28 +3911,85 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, "", ""}; } - if (failed(verifyForcedConvLoweringStrategy(convOp, geometry, decision.strategy))) + if (failed(verifyForcedConvLoweringStrategy(planOp.getOperation(), geometry, decision.strategy))) return failure(); - if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(*state)) - return convOp.emitOpError("selected depthwise Conv lowering requires constant-derived weights and sliceable bias"); - - if (decision.strategy == PimConvLoweringDepthwise) { - distributedAnalysis.barrierKind = DistributedConvBarrierKind::Depthwise; - distributedAnalysis.barrierDetail = "depthwise lowering still materializes densely"; - recordDistributedConvOutcome(distributedAnalysis); - return rewriteSelectedConv(convOp, *state, decision, distributedAnalysis, rewriter); + switch (decision.strategy) { + case PimConvLoweringLegacy: + case PimConvLoweringPackedIm2Col: + case PimConvLoweringStreamedPatch: + case PimConvLoweringOutputChannelTiled: + case PimConvLoweringTiled2D: + case PimConvLoweringStreamedPacked: + return success(); + case PimConvLoweringAuto: + case PimConvLoweringDepthwise: + case PimConvLoweringInputKTiled: + return failure(); } - - if (state->group == 1) { - recordDistributedConvOutcome(distributedAnalysis); - return rewriteUngroupedConv(convOp, *state, decision, distributedAnalysis, rewriter); - } - - 
recordDistributedConvOutcome(distributedAnalysis); - return rewriteGroupedConv(convOp, *state, decision, rewriter); + llvm_unreachable("unknown conv lowering strategy"); } -void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } +LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp) { + FailureOr state = analyzeConvLoweringState(planOp); + if (failed(state)) + return failure(); + + StringRef failureReason; + return canDirectLowerRowStripConv(*state, failureReason) ? success() : failure(); +} + +FailureOr +lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp, + std::optional rowStripInput, + bool emitRowStripLayout, + PatternRewriter& rewriter) { + FailureOr state = analyzeConvLoweringState(planOp); + if (failed(state)) + return failure(); + + FailureOr requestedStrategy = resolveRequestedConvLoweringStrategy(planOp.getOperation()); + if (failed(requestedStrategy)) + return failure(); + + DistributedConvAnalysis analysis; + analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer; + analysis.barrierDetail = emitRowStripLayout ? 
"selected row-strip layout" : "selected dense layout"; + ConvGeometry geometry = buildConvGeometry(*state); + ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requestedStrategy, analysis); + if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(*state) + && *requestedStrategy == PimConvLoweringAuto) { + decision = {PimConvLoweringLegacy, + "depthwise auto fallback when structured depthwise lowering is not representable", + /*isAuto=*/true, + "", + ""}; + } + if (failed(verifyForcedConvLoweringStrategy(planOp.getOperation(), geometry, decision.strategy))) + return failure(); + + if (emitRowStripLayout) { + if (rowStripInput) { + if (failed(canConsumeAndProduceRowStrip(planOp))) + return planOp.emitOpError("selected row-strip input/output layout is not supported for this Conv plan"), failure(); + return createConvRowsFromRowStripInput(*state, decision, *rowStripInput, rewriter, planOp.getLoc()); + } + if (failed(canLowerConvPlanToRowStrip(planOp))) + return planOp.emitOpError("selected row-strip layout is not supported for this Conv plan"), failure(); + FailureOr rows = createConvRowsForStrategy(*state, decision, rewriter, planOp.getLoc()); + if (failed(rows)) + return failure(); + FailureOr packedRows = createRowStripPackedRows(*rows, *state, rewriter, planOp.getLoc()); + if (failed(packedRows)) + return planOp.emitOpError("failed to pack Conv rows into the selected row-strip physical layout"), failure(); + return *packedRows; + } + + if (decision.strategy == PimConvLoweringDepthwise) + return lowerDenseSelectedConvPlan(planOp.getOperation(), *state, decision.strategy, rewriter, planOp.getLoc()); + if (state->group != 1) + return lowerGroupedSelectedConvPlan(planOp.getOperation(), *state, decision.strategy, rewriter, planOp.getLoc()); + return lowerDenseSelectedConvPlan(planOp.getOperation(), *state, decision.strategy, rewriter, planOp.getLoc()); +} } // namespace onnx_mlir diff --git 
a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp index 9f256b7..008865e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp @@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern { matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { Location loc = reluOp.getLoc(); Type resultType = reluOp.getResult().getType(); - constexpr size_t numInputs = 1; - auto computeOp = createSpatCompute(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) { - auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x); - spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult()); - }); - rewriter.replaceOp(reluOp, computeOp); + auto reluPlan = spatial::SpatReluPlanOp::create( + rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw")); + rewriter.replaceOp(reluOp, reluPlan.getResult()); return success(); } }; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp index d180ce5..a66b456 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp @@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute, } // Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs. 
-struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override { auto promoted = computePromotedOperands(compute); if (failed(promoted)) return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote"); Block& oldBlock = compute.getBody().front(); rewriter.setInsertionPointAfter(compute); - auto newCompute = spatial::SpatCompute::create( + auto newCompute = spatial::SpatGraphCompute::create( rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs); SmallVector newBlockArgTypes; SmallVector newBlockArgLocs; @@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override { auto promoted = computePromotedOperands(compute); if (failed(promoted)) return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote"); @@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(compute.getLaneCount()), "promoted compute_batch lane count"); if (failed(laneCountAttr)) return failure(); - auto newCompute = spatial::SpatComputeBatch::create( + auto newCompute = spatial::SpatGraphComputeBatch::create( rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, 
promoted->newInputs); auto laneArg = compute.getLaneArgument(); if (!laneArg) @@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) { }); } -bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); } +bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); } -bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); } +bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp b/src/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp new file mode 100644 index 0000000..dd5a722 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +namespace onnx_mlir { + +mlir::FailureOr +lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp, + std::optional rowStripInput, + bool emitRowStripLayout, + mlir::PatternRewriter& rewriter); + +mlir::LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp); +mlir::LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp new file mode 100644 index 0000000..d4fd626 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp @@ -0,0 +1,200 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseMap.h" + +#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" +#include 
"src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PIMPasses.h" +#include "src/Accelerators/PIM/Pass/PIMPasses.h" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static constexpr StringLiteral kLogicalLayout = "nchw"; +static constexpr StringLiteral kDenseLayout = "dense_nchw"; +static constexpr StringLiteral kRowStripLayout = "nchw_row_strip"; +static constexpr StringLiteral kRowStripIndexMap = "packed_hwc_rows_to_nchw"; + +enum class SelectedLayout { + DenseNchw, + NchwRowStrip, +}; + +static SelectedLayout getSelectedLayout(llvm::DenseMap& layouts, Value value) { + auto it = layouts.find(value); + return it == layouts.end() ? SelectedLayout::DenseNchw : it->second; +} + +static bool usesSelectedRowStrip(Operation* user, llvm::DenseMap& layouts) { + if (auto reluPlan = dyn_cast(user)) + return getSelectedLayout(layouts, reluPlan.getResult()) == SelectedLayout::NchwRowStrip; + if (auto convPlan = dyn_cast(user)) + return getSelectedLayout(layouts, convPlan.getResult()) == SelectedLayout::NchwRowStrip; + return false; +} + +static bool allUsersCanHandleRowStrip(Value value, llvm::DenseMap& layouts) { + for (Operation* user : value.getUsers()) { + if (usesSelectedRowStrip(user, layouts)) + continue; + // Dense-only users must be materialized explicitly. 
+ continue; + } + return true; +} + +static std::pair, SmallVector> buildRowStripMetadata(RankedTensorType type) { + SmallVector offsets; + SmallVector sizes; + const int64_t channels = type.getDimSize(1); + const int64_t height = type.getDimSize(2); + const int64_t width = type.getDimSize(3); + offsets.reserve(height * 4); + sizes.reserve(height * 4); + for (int64_t row = 0; row < height; ++row) { + offsets.append({0, 0, row, 0}); + sizes.append({1, channels, 1, width}); + } + return {offsets, sizes}; +} + +static bool canSelectConvRowStrip(spatial::SpatConv2DPlanOp convPlan, + llvm::DenseMap& layouts) { + SelectedLayout inputLayout = getSelectedLayout(layouts, convPlan.getInput()); + if (inputLayout == SelectedLayout::NchwRowStrip) + return succeeded(canConsumeAndProduceRowStrip(convPlan)); + return succeeded(canLowerConvPlanToRowStrip(convPlan)); +} + +static SelectedLayout chooseConvLayout(spatial::SpatConv2DPlanOp convPlan, + llvm::DenseMap& layouts) { + if (!canSelectConvRowStrip(convPlan, layouts)) + return SelectedLayout::DenseNchw; + if (!allUsersCanHandleRowStrip(convPlan.getResult(), layouts)) + return SelectedLayout::DenseNchw; + return SelectedLayout::NchwRowStrip; +} + +static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan, + llvm::DenseMap& layouts) { + if (getSelectedLayout(layouts, reluPlan.getInput()) != SelectedLayout::NchwRowStrip) + return SelectedLayout::DenseNchw; + if (!allUsersCanHandleRowStrip(reluPlan.getResult(), layouts)) + return SelectedLayout::DenseNchw; + return SelectedLayout::NchwRowStrip; +} + +static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) { + auto outputType = cast(value.getType()); + auto [offsets, sizes] = buildRowStripMetadata(outputType); + return spatial::SpatReconciliatorOp::create(rewriter, + value.getLoc(), + outputType, + value, + rewriter.getStringAttr(kLogicalLayout), + rewriter.getStringAttr(kRowStripLayout), + 
rewriter.getDenseI64ArrayAttr(offsets), + rewriter.getDenseI64ArrayAttr(sizes), + rewriter.getStringAttr(kRowStripIndexMap)); +} + +static void materializeDenseUses(IRRewriter& rewriter, + Value layoutValue, + llvm::DenseMap& layouts) { + SmallVector denseUses; + for (OpOperand& use : layoutValue.getUses()) { + if (usesSelectedRowStrip(use.getOwner(), layouts)) + continue; + denseUses.push_back(&use); + } + + for (OpOperand* use : denseUses) { + Operation* owner = use->getOwner(); + rewriter.setInsertionPoint(owner); + auto materialized = spatial::SpatMaterializeLayoutOp::create(rewriter, + owner->getLoc(), + use->get().getType(), + use->get(), + rewriter.getStringAttr(kLogicalLayout), + rewriter.getStringAttr(kRowStripLayout), + rewriter.getStringAttr(kDenseLayout)); + use->set(materialized.getResult()); + } +} + +struct SpatialLayoutPlanningPass final : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialLayoutPlanningPass) + + StringRef getArgument() const override { return "spatial-layout-planning"; } + StringRef getDescription() const override { return "Select conservative Spatial layouts and insert reconciliation barriers."; } + + void runOnOperation() override { + auto entryFunc = getPimEntryFunc(getOperation()); + if (failed(entryFunc)) { + getOperation().emitError("failed to locate the PIM entry function during Spatial layout planning"); + signalPassFailure(); + return; + } + + func::FuncOp funcOp = *entryFunc; + IRRewriter rewriter(&getContext()); + llvm::DenseMap layouts; + + bool changed = true; + while (changed) { + changed = false; + for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) { + if (auto convPlan = dyn_cast(&op)) { + SelectedLayout selected = chooseConvLayout(convPlan, layouts); + if (layouts[convPlan.getResult()] != selected) { + layouts[convPlan.getResult()] = selected; + changed = true; + } + continue; + } + if (auto reluPlan = dyn_cast(&op)) { + SelectedLayout selected = chooseReluLayout(reluPlan, 
layouts); + if (layouts[reluPlan.getResult()] != selected) { + layouts[reluPlan.getResult()] = selected; + changed = true; + } + continue; + } + } + } + + for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) { + Value producedValue; + if (auto convPlan = dyn_cast(&op)) + producedValue = convPlan.getResult(); + else if (auto reluPlan = dyn_cast(&op)) + producedValue = reluPlan.getResult(); + else + continue; + + if (getSelectedLayout(layouts, producedValue) != SelectedLayout::NchwRowStrip) + continue; + + rewriter.setInsertionPointAfter(&op); + auto reconciliator = insertRowStripReconciliator(rewriter, producedValue); + rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator); + materializeDenseUses(rewriter, reconciliator.getResult(), layouts); + } + if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) { + getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning"); + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr createSpatialLayoutPlanningPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index daa3943..d262ea5 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -102,7 +102,7 @@ static FailureOr materializeExternalTensorValue(IRRewriter& rewriter, return mapper.lookup(value); } -static FailureOr> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, +static FailureOr> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, size_t& fallbackCoreId) { if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); @@ -171,7 +171,7 @@ static Value 
createHostTargetOffset(IRRewriter& rewriter, } // namespace -LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, +LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, IRRewriter& rewriter) { Location loc = computeBatchOp.getLoc(); Block& oldBlock = computeBatchOp.getBody().front(); diff --git a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp index 1860a9d..9efc635 100644 --- a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp +++ b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp @@ -17,10 +17,10 @@ std::optional getDirectComputeLikeInputIndex(Operation* owner, unsigne return operandNumber - inputBegin; }; - if (auto compute = dyn_cast(owner)) + if (auto compute = dyn_cast(owner)) return getInputIndex(owner, compute.getInputs().size()); - if (auto computeBatch = dyn_cast(owner)) + if (auto computeBatch = dyn_cast(owner)) return getInputIndex(owner, computeBatch.getInputs().size()); return std::nullopt; @@ -32,13 +32,13 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, Value replacement) { Block& body = owner->getRegion(0).front(); BlockArgument bodyArgument; - if (auto compute = dyn_cast(owner)) { + if (auto compute = dyn_cast(owner)) { auto computeArg = compute.getInputArgument(inputIndex); assert(computeArg && "expected compute input block argument"); bodyArgument = *computeArg; } else { - auto batchArg = cast(owner).getInputArgument(inputIndex); + auto batchArg = cast(owner).getInputArgument(inputIndex); assert(batchArg && "expected compute_batch input block argument"); bodyArgument = *batchArg; } @@ -46,10 +46,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, rewriter.startOpModification(owner); bodyArgument.replaceAllUsesWith(replacement); - if (auto compute = dyn_cast(owner)) + if (auto compute = dyn_cast(owner)) 
compute.getInputsMutable().erase(inputIndex); else - cast(owner).getInputsMutable().erase(inputIndex); + cast(owner).getInputsMutable().erase(inputIndex); body.eraseArgument(bodyArgIndex); rewriter.finalizeOpModification(owner); } diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 142bc9a..696b0bc 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -55,7 +55,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite } } -static FailureOr getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { +static FailureOr getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) { if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id"); auto checkedCoreId = @@ -66,7 +66,7 @@ static FailureOr getPimCoreIdForComputeOp(spatial::SpatCompute computeO return *checkedCoreId; } -static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, +static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp, SmallVectorImpl& helperChain, bool requireReturnUse = true) { if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) @@ -104,13 +104,13 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, return success(); } -static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, +static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatScheduledCompute computeOp, IRRewriter& rewriter, OperationFolder& constantFolder) { if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) return false; if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { - return isa(user); + 
return isa(user); })) return false; @@ -145,7 +145,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute } // namespace -LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp, +LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCompute computeOp, IRRewriter& rewriter, OperationFolder& constantFolder) { Location loc = computeOp->getLoc(); diff --git a/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp b/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp index 31d8823..d347cee 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp @@ -10,6 +10,14 @@ using namespace mlir; namespace onnx_mlir { namespace { +static void copyRaptorDebugAttrs(Operation* source, Operation* target) { + for (NamedAttribute attr : source->getAttrs()) { + StringRef name = attr.getName().strref(); + if (name.starts_with("raptor.")) + target->setAttr(attr.getName(), attr.getValue()); + } +} + struct ChannelSendLowering : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -17,7 +25,8 @@ struct ChannelSendLowering : OpRewritePattern { auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput()); if (failed(sizeAttr)) return failure(); - pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId()); + auto send = pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId()); + copyRaptorDebugAttrs(op.getOperation(), send.getOperation()); rewriter.eraseOp(op); return success(); } @@ -37,9 +46,10 @@ struct ChannelReceiveLowering : OpRewritePattern auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult()); if (failed(sizeAttr)) return failure(); - Value received = pim::PimReceiveOp::create( - rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, 
op.getSourceCoreId()) - .getOutput(); + auto receive = pim::PimReceiveOp::create( + rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId()); + copyRaptorDebugAttrs(op.getOperation(), receive.getOperation()); + Value received = receive.getOutput(); rewriter.replaceOp(op, received); return success(); } diff --git a/src/PIM/Conversion/SpatialToPim/Patterns/GlobalTensorMaterialization.cpp b/src/PIM/Conversion/SpatialToPim/Patterns/GlobalTensorMaterialization.cpp index 3ac2289..1bd2179 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns/GlobalTensorMaterialization.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns/GlobalTensorMaterialization.cpp @@ -59,7 +59,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetUses()) { - if (isa(uses.getOwner())) { + if (isa(uses.getOwner())) { if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber())) return failure(); } @@ -72,7 +72,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetUses())) { - if (auto spatCompute = dyn_cast(uses.getOwner())) { + if (auto spatCompute = dyn_cast(uses.getOwner())) { auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber()); if (!inputIndex) return failure(); @@ -92,7 +92,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern(uses.getOwner())) { + else if (auto spatComputeBatch = dyn_cast(uses.getOwner())) { auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber()); if (!inputIndex) return failure(); @@ -114,7 +114,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetParentOfType()) { + if (auto spatCompute = uses.getOwner()->getParentOfType()) { rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); if (!mapSpatToExtract.contains(spatCompute.getOperation())) { auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); @@ -125,7 +125,7 @@ struct MoveExtractSliceIntoCompute final : 
OpRewritePatterngetParentOfType()) { + else if (auto spatComputeBatch = uses.getOwner()->getParentOfType()) { rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) { auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); @@ -179,7 +179,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(argUser)) { + if (auto spatCompute = dyn_cast(argUser)) { auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber()); if (!inputIndex) return failure(); @@ -191,7 +191,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(argUser)) { + else if (auto spatComputeBatch = dyn_cast(argUser)) { auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber()); if (!inputIndex) return failure(); diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index c61ffb1..ef23565 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -86,7 +86,7 @@ getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* ancho return pim::checkedCast(*byteOffset, anchor, fieldName); } -static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, +static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp, SmallVectorImpl& helperChain) { if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) return failure(); @@ -212,7 +212,7 @@ static std::optional analyzeConcatReturnUse(Value value) { } SmallVector helperChain; - if (auto helperCompute = dyn_cast(currentUser)) { + if (auto helperCompute = dyn_cast(currentUser)) { if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue) return std::nullopt; @@ -643,7 +643,7 @@ 
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low } raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath( - spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) { + spatial::SpatScheduledCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) { return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter); } @@ -656,7 +656,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { Operation* onlyUser = *op->getUsers().begin(); isExclusivelyOwnedByReturnChain = - isa(onlyUser) + isa(onlyUser) || isReturnHelperChainOp(onlyUser); } if (!isExclusivelyOwnedByReturnChain) @@ -669,7 +669,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret return; } - if (auto computeOp = dyn_cast(op)) { + if (auto computeOp = dyn_cast(op)) { markOpToRemove(computeOp); if (!computeOp.getInputs().empty()) for (Value input : computeOp.getInputs()) diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index f057626..2f1102e 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -25,9 +25,11 @@ #include #include +#include "Common/IR/ShapeUtils.hpp" #include "Common/IR/ConstantUtils.hpp" #include "Common/PimCommon.hpp" #include "Common/Support/CheckedArithmetic.hpp" +#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Patterns.hpp" @@ -97,6 +99,64 @@ static FailureOr createZeroedDeviceHVector(IRRewriter& rewriter, .getOutput(); } +static bool isHostBackedMemRefValue(Value value) { + while (Operation* definingOp = 
value.getDefiningOp()) { + if (auto subviewOp = dyn_cast(definingOp)) { + value = subviewOp.getSource(); + continue; + } + if (auto castOp = dyn_cast(definingOp)) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = dyn_cast(definingOp)) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = dyn_cast(definingOp)) { + value = expandOp.getSrc(); + continue; + } + return isa(definingOp); + } + return false; +} + +static bool isHostBackedTensorValue(Value value) { + while (Operation* definingOp = value.getDefiningOp()) { + if (auto extractSliceOp = dyn_cast(definingOp)) { + auto sourceType = dyn_cast(extractSliceOp.getSource().getType()); + auto resultType = dyn_cast(extractSliceOp.getResult().getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) + return false; + if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(), + extractSliceOp.getMixedOffsets(), + extractSliceOp.getStaticSizes(), + extractSliceOp.getStaticStrides())) { + return false; + } + value = extractSliceOp.getSource(); + continue; + } + if (auto collapseOp = dyn_cast(definingOp)) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = dyn_cast(definingOp)) { + value = expandOp.getSrc(); + continue; + } + if (auto castOp = dyn_cast(definingOp)) { + value = castOp.getSource(); + continue; + } + if (auto toTensorOp = dyn_cast(definingOp)) + return isHostBackedMemRefValue(toTensorOp.getBuffer()); + return false; + } + return false; +} + static FailureOr padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) { auto vectorType = cast(vector.getType()); @@ -120,6 +180,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size"); if (failed(sizeAttr)) return failure(); + if 
(isHostBackedTensorValue(vector)) { + return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr) + .getOutput(); + } return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput(); } @@ -137,6 +201,12 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { return; } func::FuncOp funcOp = *entryFunc; + if (failed(verifyScheduledSpatialInvariants(funcOp))) { + funcOp.emitOpError( + "RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim"); + signalPassFailure(); + return; + } IRRewriter rewriter(&getContext()); OperationFolder constantFolder(&getContext()); @@ -176,19 +246,19 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { return; } - for (auto computeOp : funcOp.getOps()) { + for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) { - computeOp.emitOpError("failed to lower spat.compute to pim.core"); + computeOp.emitOpError("failed to lower spat.scheduled_compute to pim.core"); signalPassFailure(); return; } } - for (auto computeBatchOp : funcOp.getOps()) { + for (auto computeBatchOp : funcOp.getOps()) { markOpToRemove(computeBatchOp); if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) { - computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch"); + computeBatchOp.emitOpError("failed to lower spat.scheduled_compute_batch to pim.core_batch"); signalPassFailure(); return; } @@ -374,7 +444,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables( }; for (auto& op : funcOp.getBody().getOps()) - if (auto computeOp = dyn_cast(op)) { + if (auto computeOp = dyn_cast(op)) { if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0) continue; for (auto getGlobal : computeOp.getOps()) { diff --git 
a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp index 7508f3c..3c82c60 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp @@ -41,8 +41,11 @@ private: mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); mlir::LogicalResult - lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder); - mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter); + lowerComputeOp(spatial::SpatScheduledCompute computeOp, + mlir::IRRewriter& rewriter, + mlir::OperationFolder& constantFolder); + mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, + mlir::IRRewriter& rewriter); enum class ReturnPathLoweringResult { Handled, @@ -51,7 +54,7 @@ private: }; void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter); - ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp, + ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp, mlir::OpResult result, mlir::Value yieldValue, mlir::IRRewriter& rewriter); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index b553766..58fd8d6 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -13,10 +13,13 @@ using namespace bufferization; namespace onnx_mlir::pim { -FailureOr materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { +FailureOr materializeContiguousInputMemRef(Value memrefValue, + Location loc, + RewriterBase& rewriter, + const StaticValueKnowledge& knowledge) { bool 
isContiguous = - succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)); - if (isContiguous && isDeviceLocalPimAddress(memrefValue)) + succeeded(resolveContiguousAddress(memrefValue, knowledge)) || succeeded(compileContiguousAddressExpr(memrefValue)); + if (isContiguous && isDeviceLocalPimAddress(memrefValue, knowledge)) return memrefValue; auto shapedType = cast(memrefValue.getType()); @@ -32,7 +35,7 @@ FailureOr materializeContiguousInputMemRef(Value memrefValue, Location lo if (failed(sizeAttr)) return failure(); - if (isHostBackedPimAddress(memrefValue)) { + if (isHostBackedPimAddress(memrefValue, knowledge)) { return PimMemCopyHostToDevOp::create( rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr) .getOutput(); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp index 590afec..72aa41a 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp @@ -3,10 +3,15 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/PatternMatch.h" +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" + namespace onnx_mlir::pim { llvm::FailureOr -materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); +materializeContiguousInputMemRef(mlir::Value memrefValue, + mlir::Location loc, + mlir::RewriterBase& rewriter, + const onnx_mlir::StaticValueKnowledge& knowledge = {}); mlir::Value allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index db77dfa..ceed7c1 100644 --- 
a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -15,6 +15,26 @@ using namespace bufferization; namespace onnx_mlir { namespace pim { +static StaticValueKnowledge getEnclosingBufferizationKnowledge(Operation* op) { + StaticValueKnowledge knowledge; + + if (auto coreBatchOp = op->getParentOfType()) { + knowledge.indexValues[coreBatchOp.getLaneArgument()] = 0; + for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights())) + knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight; + for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs())) + knowledge.aliases[coreBatchOp.getInputArgument(index)] = input; + return knowledge; + } + + if (auto coreOp = op->getParentOfType()) { + for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) + knowledge.aliases[coreOp.getWeightArgument(index)] = weight; + } + + return knowledge; +} + struct MemCopyHostToDevOpInterface : DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation* op, @@ -148,7 +168,8 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguous = + materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguous)) return failure(); inputs.push_back(*contiguous); @@ -182,7 +203,8 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModelgetLoc(), rewriter); + auto contiguousInput = + materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousInput)) return failure(); @@ -410,7 +432,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousInput = + materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousInput)) return failure(); Value 
contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); @@ -456,7 +479,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousInput = + materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousInput)) return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); @@ -497,10 +521,12 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousLhs = + materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousLhs)) return failure(); - auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + auto contiguousRhs = + materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousRhs)) return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); @@ -534,10 +560,12 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousLhs = + materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousLhs)) return failure(); - auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + auto contiguousRhs = + materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousRhs)) return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); @@ -574,7 +602,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousInput = + materializeContiguousInputMemRef(*inputOpt, op->getLoc(), 
rewriter, getEnclosingBufferizationKnowledge(op)); if (failed(contiguousInput)) return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index bb8e35b..b64a8f7 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -116,6 +116,36 @@ lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const return success(); } +static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyHostToDevOp copyOp, const StaticValueKnowledge& knowledge) { + bool sourceIsHost = isHostBackedPimAddress(copyOp.getHostSource(), knowledge); + bool targetIsHost = isHostBackedPimAddress(copyOp.getDeviceTarget(), knowledge); + bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getHostSource(), knowledge); + bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceTarget(), knowledge); + if (!sourceIsHost || !targetIsDevice || targetIsHost || sourceIsDevice) + return copyOp.emitOpError("pim.memcp_hd requires a host-backed source and a device-local target"); + return success(); +} + +static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyDevToHostOp copyOp, const StaticValueKnowledge& knowledge) { + bool sourceIsHost = isHostBackedPimAddress(copyOp.getDeviceSource(), knowledge); + bool targetIsHost = isHostBackedPimAddress(copyOp.getHostTarget(), knowledge); + bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceSource(), knowledge); + bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getHostTarget(), knowledge); + if (!targetIsHost || !sourceIsDevice || sourceIsHost || targetIsDevice) + return copyOp.emitOpError("pim.memcp_dh requires a device-local source and a host-backed target"); + return success(); +} + +static LogicalResult 
verifyLoweredPimCopy(pim::PimMemCopyOp copyOp, const StaticValueKnowledge& knowledge) { + bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge); + bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge); + bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge); + bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge); + if (!sourceIsDevice || !targetIsDevice || sourceIsHost || targetIsHost) + return copyOp.emitOpError("pim.memcp requires device-local source and target operands"); + return success(); +} + struct PimBufferizationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) StringRef getArgument() const override { return "bufferize-pim"; } @@ -129,6 +159,7 @@ struct PimBufferizationPass : PassWrapper(&op); copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge))) + hasFailure = true; + if (auto copyOp = dyn_cast(&op); + copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge))) + hasFailure = true; + if (auto copyOp = dyn_cast(&op); + copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge))) + hasFailure = true; + return success(); + }); + }; + + moduleOp.walk([&](pim::PimCoreOp coreOp) { verifyWithKnowledge(coreOp, seedCoreKnowledge(coreOp)); }); + moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { + StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, 0); + verifyWithKnowledge(coreBatchOp, knowledge); + }); + return success(!hasFailure); +} + std::unique_ptr createPimBufferizationPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp index fd870b8..b964097 100644 --- a/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp @@ -96,8 +96,7 @@ struct 
FoldConstantCoreMapPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override { - auto coreOp = mapOp->getParentOfType(); - if (!coreOp) + if (!mapOp->getParentOfType() && !mapOp->getParentOfType()) return failure(); auto initType = dyn_cast(mapOp.getInit().getType()); diff --git a/src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt index 67009ac..3afe2bf 100644 --- a/src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt @@ -5,6 +5,7 @@ add_pim_library(OMPimVerification LINK_LIBS PUBLIC OMPimCommon + OMPimCompilerOptions OMPimBufferization PimOps SpatialOps diff --git a/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp index eb07f1c..9612d25 100644 --- a/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp @@ -5,12 +5,17 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" + +#include #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp" @@ -143,6 +148,479 @@ static bool isHostAddressableValue(Value value, const StaticValueKnowledge& know return 
isa_and_nonnull(base.getDefiningOp()); } + +enum class CommunicationEventKind { Send, Receive }; + +struct CommunicationEvent { + CommunicationEventKind kind = CommunicationEventKind::Send; + int64_t coreId = 0; + int64_t peerCoreId = 0; + int64_t size = 0; + uint64_t ordinal = 0; + std::optional minChannelId; + std::string materializer; + std::optional traceId; + std::optional commOrder; + std::optional traceClassId; + std::optional traceBlockOrdinal; + std::string traceKind; + std::string tracePhase; + std::string traceClassKind; + std::string tracePayload; + std::string traceMessages; + std::string tracePrevOp; + std::string traceNextOp; + Operation* op = nullptr; +}; + +using CommunicationEventVector = SmallVector; + +static StringRef getCommunicationEventKindName(CommunicationEventKind kind) { + return kind == CommunicationEventKind::Send ? "send" : "receive"; +} + +constexpr StringLiteral kRaptorMinChannelIdAttr = "raptor.min_channel_id"; +constexpr StringLiteral kRaptorMaterializerAttr = "raptor.materializer"; +constexpr StringLiteral kRaptorCommOrderAttr = "raptor.comm_order"; +constexpr StringLiteral kRaptorCommTraceIdAttr = "raptor.comm_trace_id"; +constexpr StringLiteral kRaptorCommTraceKindAttr = "raptor.comm_trace_kind"; +constexpr StringLiteral kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase"; +constexpr StringLiteral kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id"; +constexpr StringLiteral kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind"; +constexpr StringLiteral kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal"; +constexpr StringLiteral kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload"; +constexpr StringLiteral kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages"; +constexpr StringLiteral kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op"; +constexpr StringLiteral kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op"; + +static std::optional getNearestIntegerAttr(Operation* op, 
StringRef name) { + for (Operation* current = op; current; current = current->getParentOp()) + if (auto attr = current->getAttrOfType(name)) + return attr.getInt(); + return std::nullopt; +} + +static std::string getNearestStringAttr(Operation* op, StringRef name) { + for (Operation* current = op; current; current = current->getParentOp()) + if (auto attr = current->getAttrOfType(name)) + return attr.getValue().str(); + return ""; +} + +static std::string formatLocation(Location loc) { + std::string text; + llvm::raw_string_ostream os(text); + loc.print(os); + return os.str(); +} + +static std::string formatOperationSummary(Operation* op) { + std::string text; + llvm::raw_string_ostream os(text); + OpPrintingFlags flags; + flags.skipRegions(); + flags.elideLargeElementsAttrs(); + op->print(os, flags); + return os.str(); +} + +static std::string formatCommunicationEvent(const CommunicationEvent& event) { + std::string text; + llvm::raw_string_ostream os(text); + os << "core " << event.coreId << " " << getCommunicationEventKindName(event.kind) << " " + << (event.kind == CommunicationEventKind::Send ? 
"to" : "from") << " " << event.peerCoreId + << " size " << event.size << "B ordinal " << event.ordinal; + if (event.minChannelId) + os << " min_channel " << *event.minChannelId; + if (event.commOrder) + os << " comm_order " << *event.commOrder; + if (!event.materializer.empty()) + os << " materializer " << event.materializer; + if (event.traceId) + os << " trace#" << *event.traceId; + if (!event.tracePhase.empty()) + os << " phase " << event.tracePhase; + if (event.traceClassId) + os << " class " << event.traceClassKind << "#" << *event.traceClassId; + if (event.traceBlockOrdinal) + os << " block_ordinal " << *event.traceBlockOrdinal; + if (!event.tracePayload.empty()) + os << " payload " << event.tracePayload; + if (!event.traceMessages.empty()) + os << " messages {" << event.traceMessages << "}"; + if (!event.tracePrevOp.empty() || !event.traceNextOp.empty()) + os << " inserted_between [" << event.tracePrevOp << " | " << event.traceNextOp << "]"; + return os.str(); +} + +static bool areMatchedCommunicationEvents(const CommunicationEvent& lhs, const CommunicationEvent& rhs) { + if (lhs.coreId != rhs.peerCoreId || lhs.peerCoreId != rhs.coreId || lhs.size != rhs.size) + return false; + + return (lhs.kind == CommunicationEventKind::Send && rhs.kind == CommunicationEventKind::Receive) + || (lhs.kind == CommunicationEventKind::Receive && rhs.kind == CommunicationEventKind::Send); +} + + +static std::optional findMatchingCounterpartIndex(const CommunicationEventVector& events, + const CommunicationEvent& event, + size_t begin) { + for (size_t index = begin; index < events.size(); ++index) + if (areMatchedCommunicationEvents(event, events[index])) + return index; + return std::nullopt; +} + +static void printCounterpartProbe(llvm::raw_ostream& os, + const DenseMap& coreEvents, + const DenseMap& programCounters, + const CommunicationEvent& blockedEvent) { + auto peerEventsIt = coreEvents.find(blockedEvent.peerCoreId); + if (peerEventsIt == coreEvents.end()) { + os << " no 
local stream was collected for peer core " << blockedEvent.peerCoreId << "\n"; + return; + } + + const CommunicationEventVector& peerEvents = peerEventsIt->second; + size_t peerPc = 0; + auto peerPcIt = programCounters.find(blockedEvent.peerCoreId); + if (peerPcIt != programCounters.end()) + peerPc = peerPcIt->second; + + os << " counterpart probe for " << formatCommunicationEvent(blockedEvent) << "\n"; + os << " peer core " << blockedEvent.peerCoreId << " current pc " << peerPc << " of " << peerEvents.size() + << "\n"; + + std::optional nextMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, peerPc); + std::optional anyMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, 0); + + if (!nextMatch && !anyMatch) { + os << " no matching counterpart exists anywhere in the peer stream\n"; + return; + } + + if (!nextMatch && anyMatch) { + os << " matching counterpart exists only before the peer pc at ordinal " << *anyMatch + << "; this usually means the static stream expansion or ordering metadata is inconsistent\n"; + os << " " << formatCommunicationEvent(peerEvents[*anyMatch]) << "\n"; + os << " op: " << formatOperationSummary(peerEvents[*anyMatch].op) << "\n"; + return; + } + + const CommunicationEvent& match = peerEvents[*nextMatch]; + os << " next matching counterpart is at peer ordinal " << *nextMatch << " (distance +" + << (*nextMatch >= peerPc ? *nextMatch - peerPc : 0) << ")\n"; + os << " " << formatCommunicationEvent(match) << "\n"; + os << " op: " << formatOperationSummary(match.op) << "\n"; + + if (*nextMatch == peerPc) + return; + + os << " peer operations blocking before that counterpart:\n"; + size_t end = std::min(peerEvents.size(), std::min(*nextMatch + static_cast(1), peerPc + static_cast(12))); + for (size_t index = peerPc; index < end; ++index) { + os << (index == peerPc ? 
" pc => " : " ") << "#" << index << " " + << formatCommunicationEvent(peerEvents[index]) << "\n"; + os << " op: " << formatOperationSummary(peerEvents[index].op) << "\n"; + } + if (end <= *nextMatch) + os << " ... " << (*nextMatch - end + 1) << " more peer communication event(s) before the counterpart\n"; +} + +static CommunicationEvent makeCommunicationEvent(CommunicationEventKind kind, + int64_t coreId, + int64_t peerCoreId, + int64_t size, + uint64_t ordinal, + Operation* op) { + return CommunicationEvent {kind, + coreId, + peerCoreId, + size, + ordinal, + getNearestIntegerAttr(op, kRaptorMinChannelIdAttr), + getNearestStringAttr(op, kRaptorMaterializerAttr), + getNearestIntegerAttr(op, kRaptorCommTraceIdAttr), + getNearestIntegerAttr(op, kRaptorCommOrderAttr), + getNearestIntegerAttr(op, kRaptorCommTraceClassIdAttr), + getNearestIntegerAttr(op, kRaptorCommTraceBlockOrdinalAttr), + getNearestStringAttr(op, kRaptorCommTraceKindAttr), + getNearestStringAttr(op, kRaptorCommTracePhaseAttr), + getNearestStringAttr(op, kRaptorCommTraceClassKindAttr), + getNearestStringAttr(op, kRaptorCommTracePayloadAttr), + getNearestStringAttr(op, kRaptorCommTraceMessagesAttr), + getNearestStringAttr(op, kRaptorCommTracePrevOpAttr), + getNearestStringAttr(op, kRaptorCommTraceNextOpAttr), + op}; +} + +static LogicalResult appendCoreCommunicationEvents(Block& block, + int64_t coreId, + const StaticValueKnowledge& initialKnowledge, + SmallVectorImpl& events, + pim::CappedDiagnosticReporter& diagnostics) { + return walkPimCoreBlock(block, initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) { + if (auto sendOp = dyn_cast(&op)) { + auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge); + if (failed(targetCoreId)) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("cannot statically resolve send target core for PIM communication deadlock check"); + }); + return failure(); + } + + 
events.push_back(makeCommunicationEvent(CommunicationEventKind::Send, + coreId, + *targetCoreId, + sendOp.getSize(), + static_cast(events.size()), + &op)); + return success(); + } + + if (auto receiveOp = dyn_cast(&op)) { + auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge); + if (failed(sourceCoreId)) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("cannot statically resolve receive source core for PIM communication deadlock check"); + }); + return failure(); + } + + events.push_back(makeCommunicationEvent(CommunicationEventKind::Receive, + coreId, + *sourceCoreId, + receiveOp.getSize(), + static_cast(events.size()), + &op)); + return success(); + } + + return success(); + }); +} + +static void printCommunicationWindow(llvm::raw_ostream& os, + const DenseMap& coreEvents, + int64_t coreId, + size_t pc, + unsigned radius = 4) { + auto eventsIt = coreEvents.find(coreId); + if (eventsIt == coreEvents.end()) + return; + + const CommunicationEventVector& events = eventsIt->second; + size_t begin = pc > radius ? pc - radius : 0; + size_t end = std::min(events.size(), pc + static_cast(radius) + 1); + os << " local stream for core " << coreId << " around pc " << pc << " of " << events.size() << ":\n"; + for (size_t index = begin; index < end; ++index) { + os << (index == pc ? 
" => " : " ") << formatCommunicationEvent(events[index]) << "\n"; + os << " op: " << formatOperationSummary(events[index].op) << "\n"; + } +} + +static void printCommunicationDeadlockReport(const DenseMap& coreEvents, + const DenseMap& programCounters, + ArrayRef cycle) { + llvm::errs() << "\n=== PIM static communication deadlock report ===\n"; + llvm::errs() << "wait cycle:"; + for (int64_t coreId : cycle) + llvm::errs() << " " << coreId; + if (!cycle.empty()) + llvm::errs() << " -> " << cycle.front(); + llvm::errs() << "\n\nblocked heads:\n"; + + for (int64_t coreId : cycle) { + auto eventsIt = coreEvents.find(coreId); + auto pcIt = programCounters.find(coreId); + if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size()) + continue; + const CommunicationEvent& event = eventsIt->second[pcIt->second]; + llvm::errs() << " " << formatCommunicationEvent(event) << "\n"; + llvm::errs() << " loc: " << formatLocation(event.op->getLoc()) << "\n"; + llvm::errs() << " op : " << formatOperationSummary(event.op) << "\n"; + } + + llvm::errs() << "\npeer counterpart probes:\n"; + for (int64_t coreId : cycle) { + auto eventsIt = coreEvents.find(coreId); + auto pcIt = programCounters.find(coreId); + if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size()) + continue; + printCounterpartProbe(llvm::errs(), coreEvents, programCounters, eventsIt->second[pcIt->second]); + } + + llvm::errs() << "\nlocal communication streams:\n"; + for (int64_t coreId : cycle) { + auto pcIt = programCounters.find(coreId); + if (pcIt == programCounters.end()) + continue; + printCommunicationWindow(llvm::errs(), coreEvents, coreId, pcIt->second); + } + llvm::errs() << "=== end PIM static communication deadlock report ===\n\n"; +} + +static void emitCommunicationDeadlockCycle(ModuleOp moduleOp, + const DenseMap& coreEvents, + const DenseMap& programCounters, + ArrayRef cycle) { + 
printCommunicationDeadlockReport(coreEvents, programCounters, cycle); + + auto diagnostic = moduleOp.emitError() + << "PIM communication deadlock check found a blocking send/receive cycle while statically simulating the " + "expanded per-core communication streams; see the PIM static communication deadlock report above"; + + for (int64_t coreId : cycle) { + auto eventsIt = coreEvents.find(coreId); + auto pcIt = programCounters.find(coreId); + if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size()) + continue; + + const CommunicationEvent& event = eventsIt->second[pcIt->second]; + Diagnostic& note = diagnostic.attachNote(event.op->getLoc()); + note << formatCommunicationEvent(event); + if (!event.materializer.empty()) + note << " emitted by " << event.materializer; + } +} + +static FailureOr> findCommunicationWaitCycle( + const DenseMap& coreEvents, + const DenseMap& programCounters) { + for (const auto& [startCoreId, events] : coreEvents) { + auto startPcIt = programCounters.find(startCoreId); + if (startPcIt == programCounters.end() || startPcIt->second >= events.size()) + continue; + + DenseMap positionInPath; + SmallVector path; + int64_t currentCoreId = startCoreId; + while (true) { + auto eventsIt = coreEvents.find(currentCoreId); + auto pcIt = programCounters.find(currentCoreId); + if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size()) + break; + + auto positionIt = positionInPath.find(currentCoreId); + if (positionIt != positionInPath.end()) { + SmallVector cycle; + for (size_t index = positionIt->second; index < path.size(); ++index) + cycle.push_back(path[index]); + return cycle; + } + + positionInPath[currentCoreId] = path.size(); + path.push_back(currentCoreId); + currentCoreId = eventsIt->second[pcIt->second].peerCoreId; + } + } + + return failure(); +} + +static LogicalResult verifyNoStaticCommunicationDeadlock(ModuleOp moduleOp, + 
pim::CappedDiagnosticReporter& diagnostics) { + DenseMap coreEvents; + bool hasFailure = false; + + for (func::FuncOp funcOp : moduleOp.getOps()) { + if (funcOp.isExternal()) + continue; + + for (Operation& op : funcOp.getBody().front().getOperations()) { + if (auto coreOp = dyn_cast(&op)) { + int64_t coreId = coreOp.getCoreId(); + if (failed(appendCoreCommunicationEvents( + coreOp.getBody().front(), coreId, StaticValueKnowledge {}, coreEvents[coreId], diagnostics))) + hasFailure = true; + continue; + } + + if (auto coreBatchOp = dyn_cast(&op)) { + SmallVector coreIds = getBatchCoreIds(coreBatchOp); + size_t laneCount = static_cast(coreBatchOp.getLaneCount()); + for (size_t lane = 0; lane < laneCount; ++lane) { + StaticValueKnowledge laneKnowledge; + laneKnowledge.indexValues[coreBatchOp.getLaneArgument()] = static_cast(lane); + for (unsigned inputIndex = 0; inputIndex < coreBatchOp.getInputs().size(); ++inputIndex) + laneKnowledge.aliases[coreBatchOp.getInputArgument(inputIndex)] = coreBatchOp.getInputs()[inputIndex]; + + SmallVector laneCoreIds = getLaneChunkCoreIds(coreIds, laneCount, static_cast(lane)); + for (int32_t coreId : laneCoreIds) { + if (failed(appendCoreCommunicationEvents( + coreBatchOp.getBody().front(), coreId, laneKnowledge, coreEvents[coreId], diagnostics))) + hasFailure = true; + } + } + } + } + } + + if (hasFailure) + return failure(); + + DenseMap programCounters; + for (const auto& [coreId, events] : coreEvents) + programCounters[coreId] = 0; + + while (true) { + bool madeProgress = false; + for (const auto& [coreId, events] : coreEvents) { + size_t pc = programCounters[coreId]; + if (pc >= events.size()) + continue; + + const CommunicationEvent& event = events[pc]; + auto peerEventsIt = coreEvents.find(event.peerCoreId); + if (peerEventsIt == coreEvents.end()) + continue; + + size_t peerPc = programCounters[event.peerCoreId]; + if (peerPc >= peerEventsIt->second.size()) + continue; + + const CommunicationEvent& peerEvent = 
peerEventsIt->second[peerPc]; + if (!areMatchedCommunicationEvents(event, peerEvent)) + continue; + + ++programCounters[coreId]; + ++programCounters[event.peerCoreId]; + madeProgress = true; + } + + if (madeProgress) + continue; + + bool allDone = true; + for (const auto& [coreId, events] : coreEvents) { + if (programCounters[coreId] < events.size()) { + allDone = false; + break; + } + } + if (allDone) + return success(); + + auto cycle = findCommunicationWaitCycle(coreEvents, programCounters); + if (succeeded(cycle)) { + emitCommunicationDeadlockCycle(moduleOp, coreEvents, programCounters, *cycle); + return failure(); + } + + auto diagnostic = moduleOp.emitError() + << "PIM communication deadlock check stalled without finding a closed wait cycle; this usually means a " + "send/receive peer is missing or ordered after a finished core"; + for (const auto& [coreId, events] : coreEvents) { + size_t pc = programCounters[coreId]; + if (pc >= events.size()) + continue; + const CommunicationEvent& event = events[pc]; + diagnostic.attachNote(event.op->getLoc()) << formatCommunicationEvent(event); + } + return failure(); + } +} + struct VerificationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) @@ -212,11 +690,18 @@ struct VerificationPass : PassWrapper> } } + bool hasFailure = false; + if (pimDetectCommunicationDeadlock && failed(verifyNoStaticCommunicationDeadlock(moduleOp, diagnostics))) + hasFailure = true; + if (diagnostics.hasFailure()) { diagnostics.emitSuppressedSummary(moduleOp, "verification failures"); moduleOp.emitError("PIM codegen verification failed; see diagnostics above"); - signalPassFailure(); + hasFailure = true; } + + if (hasFailure) + signalPassFailure(); } private: diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 56347d8..6e1d701 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -26,7 +26,7 @@ def SpatTensor : // Execution 
//===----------------------------------------------------------------------===// -def SpatCompute : SpatOp<"compute", +class SpatComputeLikeBase : SpatOp]> { let summary = "Compute region with attached constant weights"; @@ -42,6 +42,12 @@ def SpatCompute : SpatOp<"compute", let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let hasFolder = 1; + let hasCustomAssemblyFormat = 1; +} + +def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> { let extraClassDeclaration = [{ std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx); @@ -50,16 +56,26 @@ def SpatCompute : SpatOp<"compute", std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); - ::mlir::FailureOr> + ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); }]; - - let hasVerifier = 1; - let hasFolder = 1; - let hasCustomAssemblyFormat = 1; } -def SpatComputeBatch : SpatOp<"compute_batch", +def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> { + let extraClassDeclaration = [{ + std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx); + std::optional> + insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc); + std::optional> + insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); + ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); + ::mlir::FailureOr> + insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); + }]; +} + +class SpatComputeBatchLikeBase : SpatOp]> { let summary = "Tensor-native batch of equivalent compute lanes 
with shared weights and packed inputs"; @@ -76,6 +92,11 @@ def SpatComputeBatch : SpatOp<"compute_batch", let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> { let extraClassDeclaration = [{ std::optional<::mlir::BlockArgument> getLaneArgument(); std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); @@ -86,21 +107,33 @@ def SpatComputeBatch : SpatOp<"compute_batch", std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); - ::mlir::FailureOr> + ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); }]; +} - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; +def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> { + let extraClassDeclaration = [{ + std::optional<::mlir::BlockArgument> getLaneArgument(); + std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx); + std::optional> + insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc); + std::optional> + insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); + ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); + ::mlir::FailureOr> + insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); + }]; } def SpatInParallelOp : SpatOp<"in_parallel", [ Pure, Terminator, DeclareOpInterfaceMethods, - HasParent<"SpatComputeBatch">, ] # GraphRegionNoTerminator.traits> { - let summary = "Parallel combining terminator 
for resultful spat.compute_batch"; + let summary = "Parallel combining terminator for resultful Spatial compute batches"; let regions = (region SizedRegion<1>:$region); @@ -159,6 +192,82 @@ def SpatConcatOp : SpatOp<"concat", []> { let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// Planning +//===----------------------------------------------------------------------===// + +def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> { + let summary = "Structured Conv2D planning op that preserves logical ONNX geometry"; + + let arguments = (ins + SpatTensor:$input, + SpatTensor:$weight, + Optional:$bias, + DenseI64ArrayAttr:$pads, + DenseI64ArrayAttr:$strides, + DenseI64ArrayAttr:$dilations, + I64Attr:$group, + StrAttr:$logicalLayout + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; +} + +def SpatReluPlanOp : SpatOp<"relu_plan", []> { + let summary = "Layout-aware ReLU planning op"; + + let arguments = (ins + SpatTensor:$input, + StrAttr:$logicalLayout + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; +} + +def SpatReconciliatorOp : SpatOp<"reconciliator", []> { + let summary = "Passive logical-to-physical layout selection record"; + + let arguments = (ins + SpatTensor:$input, + StrAttr:$logicalLayout, + StrAttr:$physicalLayout, + DenseI64ArrayAttr:$fragmentOffsets, + DenseI64ArrayAttr:$fragmentSizes, + StrAttr:$indexMap + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; +} + +def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> { + let summary = "Explicit layout conversion or materialization barrier"; + + let arguments = (ins + SpatTensor:$input, + StrAttr:$logicalLayout, + StrAttr:$sourcePhysicalLayout, + StrAttr:$targetPhysicalLayout + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // 
Communication //===----------------------------------------------------------------------===// diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 5c7fdf1..d083848 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -29,11 +29,19 @@ std::optional insertBlockArgument(Region& body, unsigned argIdx, } void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) { - if (auto compute = dyn_cast(op)) { + if (auto compute = dyn_cast(op)) { compute.getProperties().setOperandSegmentSizes({weightCount, inputCount}); return; } - cast(op).getProperties().setOperandSegmentSizes({weightCount, inputCount}); + if (auto compute = dyn_cast(op)) { + compute.getProperties().setOperandSegmentSizes({weightCount, inputCount}); + return; + } + if (auto batch = dyn_cast(op)) { + batch.getProperties().setOperandSegmentSizes({weightCount, inputCount}); + return; + } + cast(op).getProperties().setOperandSegmentSizes({weightCount, inputCount}); } using CrossbarWeightSet = llvm::SetVector, llvm::SmallDenseSet>; @@ -47,116 +55,205 @@ CrossbarWeightSet collectCrossbarWeights(Region& body) { return weights; } -} // namespace - -std::optional SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); } - -std::optional SpatCompute::getInputArgument(unsigned idx) { - return getBlockArgument(getBody(), getWeights().size() + idx); +template +std::optional getComputeWeightArgument(ComputeOpTy compute, unsigned idx) { + return getBlockArgument(compute.getBody(), idx); } -std::optional> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) { - if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { - auto index = std::distance(getWeights().begin(), existing); - return { - {*existing, *getWeightArgument(index)} - }; +template +std::optional getComputeInputArgument(ComputeOpTy compute, unsigned idx) { + return 
getBlockArgument(compute.getBody(), compute.getWeights().size() + idx); +} + +template +std::optional> +insertComputeWeight(ComputeOpTy compute, unsigned idx, Value weight, Location loc) { + if (auto existing = llvm::find(compute.getWeights(), weight); existing != compute.getWeights().end()) { + auto index = std::distance(compute.getWeights().begin(), existing); + return {{*existing, *getComputeWeightArgument(compute, index)}}; } - unsigned weightCount = getWeights().size(); - unsigned inputCount = getInputs().size(); - getOperation()->insertOperands(idx, ValueRange {weight}); + unsigned weightCount = compute.getWeights().size(); + unsigned inputCount = compute.getInputs().size(); + compute.getOperation()->insertOperands(idx, ValueRange {weight}); setComputeOperandSegmentSizes( - getOperation(), static_cast(weightCount + 1), static_cast(inputCount)); - auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc); + compute.getOperation(), static_cast(weightCount + 1), static_cast(inputCount)); + auto blockArg = insertBlockArgument(compute.getBody(), idx, weight.getType(), loc); if (!blockArg) return std::nullopt; - return std::make_tuple(getOperation()->getOperand(idx), *blockArg); + return std::make_tuple(compute.getOperation()->getOperand(idx), *blockArg); } -std::optional> SpatCompute::insertInput(unsigned idx, Value input, Location loc) { - unsigned weightCount = getWeights().size(); - unsigned inputCount = getInputs().size(); - getOperation()->insertOperands(weightCount + idx, ValueRange {input}); +template +std::optional> +insertComputeBatchWeight(ComputeBatchOpTy batch, unsigned idx, Value weight, Location loc) { + if (auto existing = llvm::find(batch.getWeights(), weight); existing != batch.getWeights().end()) { + auto index = std::distance(batch.getWeights().begin(), existing); + return {{*existing, *batch.getWeightArgument(index)}}; + } + + unsigned weightCount = batch.getWeights().size(); + unsigned inputCount = batch.getInputs().size(); + 
batch.getOperation()->insertOperands(idx, ValueRange {weight}); setComputeOperandSegmentSizes( - getOperation(), static_cast(weightCount), static_cast(inputCount + 1)); - auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc); + batch.getOperation(), static_cast(weightCount + 1), static_cast(inputCount)); + + auto blockArg = insertBlockArgument(batch.getBody(), 1 + idx, weight.getType(), loc); if (!blockArg) return std::nullopt; - return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); + return std::make_tuple(batch.getOperation()->getOperand(idx), *blockArg); } -CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } - -FailureOr> -SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { - if (idx > getNumResults()) - return failure(); - - rewriter.setInsertionPoint(getOperation()); - SmallVector resultTypes(getResultTypes().begin(), getResultTypes().end()); - resultTypes.insert(resultTypes.begin() + idx, type); - auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs()); - newCompute->setAttrs((*this)->getAttrs()); - setComputeOperandSegmentSizes(newCompute.getOperation(), - static_cast(newCompute.getWeights().size()), - static_cast(newCompute.getInputs().size())); - rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end()); - for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) - getResult(oldResultIdx) - .replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? 
oldResultIdx : oldResultIdx + 1)); - rewriter.eraseOp(getOperation()); - return std::make_tuple(cast(newCompute.getResult(idx)), newCompute); +template +std::optional> +insertComputeInput(ComputeOpTy compute, unsigned idx, Value input, Location loc) { + unsigned weightCount = compute.getWeights().size(); + unsigned inputCount = compute.getInputs().size(); + compute.getOperation()->insertOperands(weightCount + idx, ValueRange {input}); + setComputeOperandSegmentSizes( + compute.getOperation(), static_cast(weightCount), static_cast(inputCount + 1)); + auto blockArg = insertBlockArgument(compute.getBody(), weightCount + idx, input.getType(), loc); + if (!blockArg) + return std::nullopt; + return std::make_tuple(compute.getOperation()->getOperand(weightCount + idx), *blockArg); } -void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { +template +void setComputeAsmBlockArgumentNames(ComputeOpTy compute, Region& region, OpAsmSetValueNameFn setNameFn) { if (region.empty()) return; - for (unsigned index = 0; index < getWeights().size(); ++index) - if (auto weightArg = getWeightArgument(index)) + for (unsigned index = 0; index < compute.getWeights().size(); ++index) + if (auto weightArg = compute.getWeightArgument(index)) setNameFn(*weightArg, ("w" + std::to_string(index)).c_str()); - for (unsigned index = 0; index < getInputs().size(); ++index) - if (auto inputArg = getInputArgument(index)) + for (unsigned index = 0; index < compute.getInputs().size(); ++index) + if (auto inputArg = compute.getInputArgument(index)) setNameFn(*inputArg, ("in" + std::to_string(index)).c_str()); } -std::optional SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); } +template +FailureOr> +insertComputeOutput(ComputeOpTy compute, RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + if (idx > compute.getNumResults()) + return failure(); -std::optional SpatComputeBatch::getWeightArgument(unsigned idx) { + 
rewriter.setInsertionPoint(compute.getOperation()); + SmallVector resultTypes(compute.getResultTypes().begin(), compute.getResultTypes().end()); + resultTypes.insert(resultTypes.begin() + idx, type); + auto newCompute = + ComputeOpTy::create(rewriter, compute.getLoc(), TypeRange(resultTypes), compute.getWeights(), compute.getInputs()); + newCompute->setAttrs(compute->getAttrs()); + setComputeOperandSegmentSizes(newCompute.getOperation(), + static_cast(newCompute.getWeights().size()), + static_cast(newCompute.getInputs().size())); + rewriter.inlineRegionBefore(compute.getBody(), newCompute.getBody(), newCompute.getBody().end()); + for (unsigned oldResultIdx = 0; oldResultIdx < compute.getNumResults(); ++oldResultIdx) + compute.getResult(oldResultIdx) + .replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); + rewriter.eraseOp(compute.getOperation()); + return std::make_tuple(cast(newCompute.getResult(idx)), newCompute); +} + +template +FailureOr> +insertComputeBatchOutput(ComputeBatchOpTy batch, RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + if (idx > batch.getNumResults()) + return failure(); + + rewriter.setInsertionPoint(batch.getOperation()); + SmallVector resultTypes(batch.getResultTypes().begin(), batch.getResultTypes().end()); + resultTypes.insert(resultTypes.begin() + idx, type); + auto newBatch = + ComputeBatchOpTy::create(rewriter, batch.getLoc(), TypeRange(resultTypes), batch.getLaneCountAttr(), batch.getWeights(), batch.getInputs()); + newBatch->setAttrs(batch->getAttrs()); + setComputeOperandSegmentSizes(newBatch.getOperation(), + static_cast(newBatch.getWeights().size()), + static_cast(newBatch.getInputs().size())); + rewriter.inlineRegionBefore(batch.getBody(), newBatch.getBody(), newBatch.getBody().end()); + if (newBatch.getBody().empty()) { + rewriter.eraseOp(newBatch); + return failure(); + } + auto blockArg = newBatch.getBody().front().insertArgument( + 1 + newBatch.getWeights().size() + 
newBatch.getInputs().size() + idx, type, loc); + for (unsigned oldResultIdx = 0; oldResultIdx < batch.getNumResults(); ++oldResultIdx) + batch.getResult(oldResultIdx) + .replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); + rewriter.eraseOp(batch.getOperation()); + return std::make_tuple(cast(newBatch.getResult(idx)), blockArg, newBatch); +} + +} // namespace + +bool isGraphComputeLike(Operation* op) { return isa(op); } + +bool isGraphBatchComputeLike(Operation* op) { return isa(op); } + +bool isScheduledComputeLike(Operation* op) { return isa(op); } + +bool isScheduledBatchComputeLike(Operation* op) { return isa(op); } + +bool isAnySpatialComputeLike(Operation* op) { + return isa(op); +} + +bool isAnySpatialComputeBatchLike(Operation* op) { return isa(op); } + +std::optional SpatGraphCompute::getWeightArgument(unsigned idx) { return getComputeWeightArgument(*this, idx); } +std::optional SpatGraphCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); } +std::optional> SpatGraphCompute::insertWeight(unsigned idx, Value weight, Location loc) { + return insertComputeWeight(*this, idx, weight, loc); +} +std::optional> SpatGraphCompute::insertInput(unsigned idx, Value input, Location loc) { + return insertComputeInput(*this, idx, input, loc); +} +CrossbarWeightSet SpatGraphCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } +FailureOr> +SpatGraphCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + return insertComputeOutput(*this, rewriter, idx, type, loc); +} +void SpatGraphCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { + setComputeAsmBlockArgumentNames(*this, region, setNameFn); +} + +std::optional SpatScheduledCompute::getWeightArgument(unsigned idx) { + return getComputeWeightArgument(*this, idx); +} +std::optional SpatScheduledCompute::getInputArgument(unsigned idx) { return 
getComputeInputArgument(*this, idx); } +std::optional> +SpatScheduledCompute::insertWeight(unsigned idx, Value weight, Location loc) { + return insertComputeWeight(*this, idx, weight, loc); +} +std::optional> +SpatScheduledCompute::insertInput(unsigned idx, Value input, Location loc) { + return insertComputeInput(*this, idx, input, loc); +} +CrossbarWeightSet SpatScheduledCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } +FailureOr> +SpatScheduledCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + return insertComputeOutput(*this, rewriter, idx, type, loc); +} +void SpatScheduledCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { + setComputeAsmBlockArgumentNames(*this, region, setNameFn); +} + +std::optional SpatGraphComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); } +std::optional SpatGraphComputeBatch::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), 1 + idx); } - -std::optional SpatComputeBatch::getInputArgument(unsigned idx) { +std::optional SpatGraphComputeBatch::getInputArgument(unsigned idx) { return getBlockArgument(getBody(), 1 + getWeights().size() + idx); } - -std::optional SpatComputeBatch::getOutputArgument(unsigned idx) { +std::optional SpatGraphComputeBatch::getOutputArgument(unsigned idx) { return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); } - std::optional> -SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { - if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { - auto index = std::distance(getWeights().begin(), existing); - return { - {*existing, *getWeightArgument(index)} - }; - } - - unsigned weightCount = getWeights().size(); - unsigned inputCount = getInputs().size(); - getOperation()->insertOperands(idx, ValueRange {weight}); - setComputeOperandSegmentSizes( - getOperation(), static_cast(weightCount + 1), 
static_cast(inputCount)); - auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc); - if (!blockArg) - return std::nullopt; - return std::make_tuple(getOperation()->getOperand(idx), *blockArg); +SpatGraphComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { + return insertComputeBatchWeight(*this, idx, weight, loc); } - -std::optional> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) { +std::optional> +SpatGraphComputeBatch::insertInput(unsigned idx, Value input, Location loc) { unsigned weightCount = getWeights().size(); unsigned inputCount = getInputs().size(); getOperation()->insertOperands(weightCount + idx, ValueRange {input}); @@ -167,52 +264,68 @@ std::optional> SpatComputeBatch::insertInput(un return std::nullopt; return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); } - -CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } - -FailureOr> -SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { - if (idx > getNumResults()) - return failure(); - - rewriter.setInsertionPoint(getOperation()); - SmallVector resultTypes(getResultTypes().begin(), getResultTypes().end()); - resultTypes.insert(resultTypes.begin() + idx, type); - auto newBatch = - SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs()); - newBatch->setAttrs((*this)->getAttrs()); - setComputeOperandSegmentSizes(newBatch.getOperation(), - static_cast(newBatch.getWeights().size()), - static_cast(newBatch.getInputs().size())); - rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end()); - if (newBatch.getBody().empty()) { - rewriter.eraseOp(newBatch); - return failure(); - } - auto blockArg = newBatch.getBody().front().insertArgument( - 1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc); - for (unsigned 
oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) - getResult(oldResultIdx) - .replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); - rewriter.eraseOp(getOperation()); - return std::make_tuple(cast(newBatch.getResult(idx)), blockArg, newBatch); +CrossbarWeightSet SpatGraphComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } +FailureOr> +SpatGraphComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + return insertComputeBatchOutput(*this, rewriter, idx, type, loc); } - -void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { +void SpatGraphComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { if (region.empty()) return; if (auto laneArg = getLaneArgument()) setNameFn(*laneArg, "lane"); + setComputeAsmBlockArgumentNames(*this, region, setNameFn); + for (unsigned index = 0; index < getNumResults(); ++index) { + auto outputArg = getOutputArgument(index); + if (!outputArg) + continue; + if (index == 0) { + setNameFn(*outputArg, "out"); + continue; + } + setNameFn(*outputArg, ("out" + std::to_string(index)).c_str()); + } +} - for (unsigned index = 0; index < getWeights().size(); ++index) - if (auto weightArg = getWeightArgument(index)) - setNameFn(*weightArg, ("w" + std::to_string(index)).c_str()); - - for (unsigned index = 0; index < getInputs().size(); ++index) - if (auto inputArg = getInputArgument(index)) - setNameFn(*inputArg, ("in" + std::to_string(index)).c_str()); +std::optional SpatScheduledComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); } +std::optional SpatScheduledComputeBatch::getWeightArgument(unsigned idx) { + return getBlockArgument(getBody(), 1 + idx); +} +std::optional SpatScheduledComputeBatch::getInputArgument(unsigned idx) { + return getBlockArgument(getBody(), 1 + getWeights().size() + idx); +} +std::optional 
SpatScheduledComputeBatch::getOutputArgument(unsigned idx) { + return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); +} +std::optional> +SpatScheduledComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { + return insertComputeBatchWeight(*this, idx, weight, loc); +} +std::optional> +SpatScheduledComputeBatch::insertInput(unsigned idx, Value input, Location loc) { + unsigned weightCount = getWeights().size(); + unsigned inputCount = getInputs().size(); + getOperation()->insertOperands(weightCount + idx, ValueRange {input}); + setComputeOperandSegmentSizes( + getOperation(), static_cast(weightCount), static_cast(inputCount + 1)); + auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc); + if (!blockArg) + return std::nullopt; + return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); +} +CrossbarWeightSet SpatScheduledComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } +FailureOr> +SpatScheduledComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + return insertComputeBatchOutput(*this, rewriter, idx, type, loc); +} +void SpatScheduledComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { + if (region.empty()) + return; + if (auto laneArg = getLaneArgument()) + setNameFn(*laneArg, "lane"); + setComputeAsmBlockArgumentNames(*this, region, setNameFn); for (unsigned index = 0; index < getNumResults(); ++index) { auto outputArg = getOutputArgument(index); if (!outputArg) @@ -231,7 +344,11 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) { builder.createBlock(bodyRegion); } -OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); } +OpResult SpatInParallelOp::getParentResult(int64_t idx) { + Operation* parent = getOperation()->getParentOp(); + 
assert(isAnySpatialComputeBatchLike(parent) && "expected Spatial compute batch parent"); + return parent->getResult(idx); +} llvm::iterator_range SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); } diff --git a/src/PIM/Dialect/Spatial/SpatialOps.hpp b/src/PIM/Dialect/Spatial/SpatialOps.hpp index 7dc89fd..eff0eaa 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.hpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.hpp @@ -26,3 +26,19 @@ #define GET_OP_CLASSES #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc" + +namespace onnx_mlir { +namespace spatial { + +bool isGraphComputeLike(mlir::Operation* op); +bool isGraphBatchComputeLike(mlir::Operation* op); +bool isScheduledComputeLike(mlir::Operation* op); +bool isScheduledBatchComputeLike(mlir::Operation* op); +bool isAnySpatialComputeLike(mlir::Operation* op); +bool isAnySpatialComputeBatchLike(mlir::Operation* op); + +using SpatCompute = SpatGraphCompute; +using SpatComputeBatch = SpatGraphComputeBatch; + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 18e80ca..63c9504 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -115,6 +115,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser, return success(); } +template +void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) { + SmallVector weightArgs; + weightArgs.reserve(op.getWeights().size()); + for (unsigned index = 0; index < op.getWeights().size(); ++index) { + auto weightArg = op.getWeightArgument(index); + if (!weightArg) + return printer.printGenericOp(op.getOperation(), /*printOpName=*/false); + weightArgs.push_back(*weightArg); + } + SmallVector inputArgs; + inputArgs.reserve(op.getInputs().size()); + for (unsigned index = 0; index < op.getInputs().size(); ++index) { + auto inputArg = op.getInputArgument(index); + if (!inputArg) + return 
printer.printGenericOp(op.getOperation(), /*printOpName=*/false); + inputArgs.push_back(*inputArg); + } + + printer << " "; + printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square); + printer << " "; + printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren); + + if (auto coreIdAttr = op->template getAttrOfType(onnx_mlir::kCoreIdAttrName)) + printer << " coreId " << coreIdAttr.getInt(); + printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size(); + + printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); + + printer << " : "; + printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square); + printer << " "; + printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren); + printer << " -> "; + printCompressedTypeSequence(printer, op.getResultTypes()); + printer << " "; + printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); +} + +template +ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) { + SmallVector weightArgs; + SmallVector regionArgs; + SmallVector weights; + SmallVector inputs; + SmallVector weightTypes; + SmallVector inputTypes; + SmallVector outputTypes; + int32_t crossbarWeightCount = 0; + int32_t coreId = 0; + + if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) + return failure(); + + SmallVector inputArgs; + if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) + return failure(); + + bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); + if (hasCoreId && parser.parseInteger(coreId)) + return failure(); + + bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); + if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) + return failure(); + (void) crossbarWeightCount; + + if 
(parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (weightArgs.size() != weights.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match"); + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (inputArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) + return parser.emitError(parser.getCurrentLocation(), + "coreId cannot be specified both positionally and in attr-dict"); + + auto& builder = parser.getBuilder(); + result.addAttribute( + "operandSegmentSizes", + builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); + if (hasCoreId) + result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId)); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) + || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + Region* body = result.addRegion(); + applyArgumentTypes(weightTypes, weightArgs); + applyArgumentTypes(inputTypes, inputArgs); + llvm::append_range(regionArgs, weightArgs); + llvm::append_range(regionArgs, 
inputArgs); + return parser.parseRegion(*body, regionArgs); +} + +template +void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) { + auto laneArg = op.getLaneArgument(); + SmallVector weightArgs; + weightArgs.reserve(op.getWeights().size()); + for (unsigned index = 0; index < op.getWeights().size(); ++index) { + auto weightArg = op.getWeightArgument(index); + if (!weightArg) + return printer.printGenericOp(op.getOperation(), /*printOpName=*/false); + weightArgs.push_back(*weightArg); + } + SmallVector inputArgs; + inputArgs.reserve(op.getInputs().size()); + for (unsigned index = 0; index < op.getInputs().size(); ++index) { + auto inputArg = op.getInputArgument(index); + if (!inputArg) + return printer.printGenericOp(op.getOperation(), /*printOpName=*/false); + inputArgs.push_back(*inputArg); + } + + SmallVector outputArgs; + if (!laneArg) + return printer.printGenericOp(op.getOperation(), /*printOpName=*/false); + if (op.getNumResults() != 0) { + outputArgs.reserve(op.getNumResults()); + for (unsigned index = 0; index < op.getNumResults(); ++index) { + auto outputArg = op.getOutputArgument(index); + if (!outputArg) + return printer.printGenericOp(op.getOperation(), /*printOpName=*/false); + outputArgs.push_back(*outputArg); + } + } + + printer << " "; + printer.printOperand(*laneArg); + printer << " = 0 to " << op.getLaneCount(); + printer << " "; + printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square); + printer << " "; + printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren); + if (op.getNumResults() != 0) { + printer << " shared_outs"; + printBlockArgumentList(printer, outputArgs); + } + printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size(); + if (auto coreIdsAttr = op->template getAttrOfType(onnx_mlir::kCoreIdsAttrName)) { + printer << " coreIds "; + printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); + } + 
printer.printOptionalAttrDict( + op->getAttrs(), + {op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); + printer << " : "; + printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square); + printer << " "; + printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren); + printer << " -> "; + printCompressedTypeSequence(printer, op.getResultTypes()); + printer << " "; + printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); +} + +template +ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) { + int64_t lowerBound = 0; + int32_t laneCount = 0; + OpAsmParser::Argument laneArg; + SmallVector weightArgs; + SmallVector inputArgs; + SmallVector outputArgs; + SmallVector regionArgs; + SmallVector weights; + SmallVector inputs; + SmallVector weightTypes; + SmallVector inputTypes; + SmallVector outputTypes; + int32_t crossbarWeightCount = 0; + SmallVector coreIds; + + if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound) + || parser.parseKeyword("to") || parser.parseInteger(laneCount)) + return failure(); + if (lowerBound != 0) + return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound"); + if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) + return failure(); + if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("shared_outs"))) + if (parseBlockArgumentList(parser, outputArgs)) + return failure(); + bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); + if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) + return failure(); + bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); + if (hasCrossbarWeightCount && 
parser.parseInteger(crossbarWeightCount)) + return failure(); + (void) crossbarWeightCount; + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes) + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (weightArgs.size() != weights.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match"); + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (inputArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (outputArgs.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "number of shared output bindings and result types must match"); + if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) + return parser.emitError(parser.getCurrentLocation(), + "coreIds cannot be specified both positionally and in attr-dict"); + + auto& builder = parser.getBuilder(); + result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); + result.addAttribute( + "operandSegmentSizes", + builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); + if (hasCoreIds) + result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds)); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) + || parser.resolveOperands(inputs, inputTypes, 
parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + Region* body = result.addRegion(); + applyBatchRegionArgumentTypes( + inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder()); + return parser.parseRegion(*body, regionArgs); +} + } // namespace void SpatYieldOp::print(OpAsmPrinter& printer) { @@ -218,260 +466,21 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { return success(); } -void SpatCompute::print(OpAsmPrinter& printer) { - SmallVector weightArgs; - weightArgs.reserve(getWeights().size()); - for (unsigned index = 0; index < getWeights().size(); ++index) { - auto weightArg = getWeightArgument(index); - if (!weightArg) - return printer.printGenericOp(getOperation(), /*printOpName=*/false); - weightArgs.push_back(*weightArg); - } - SmallVector inputArgs; - inputArgs.reserve(getInputs().size()); - for (unsigned index = 0; index < getInputs().size(); ++index) { - auto inputArg = getInputArgument(index); - if (!inputArg) - return printer.printGenericOp(getOperation(), /*printOpName=*/false); - inputArgs.push_back(*inputArg); - } - - printer << " "; - printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); - printer << " "; - printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); - - if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - printer << " coreId " << coreIdAttr.getInt(); - printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size(); - - printer.printOptionalAttrDict((*this)->getAttrs(), - {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); - - printer << " : "; - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); - printer << " "; - printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); - printer << " -> "; - 
printCompressedTypeSequence(printer, getResultTypes()); - printer << " "; - printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); } +ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) { + return parseComputeLikeOp(parser, result); } - -ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { - SmallVector weightArgs; - SmallVector regionArgs; - SmallVector weights; - SmallVector inputs; - SmallVector weightTypes; - SmallVector inputTypes; - SmallVector outputTypes; - int32_t crossbarWeightCount = 0; - int32_t coreId = 0; - - if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) - return failure(); - - SmallVector inputArgs; - if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) - return failure(); - - bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); - if (hasCoreId && parser.parseInteger(coreId)) - return failure(); - - bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); - if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) - return failure(); - (void) crossbarWeightCount; - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedRepeatedList( - parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) - || parseCompressedRepeatedList( - parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) - || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) - return failure(); - - if (weights.size() != weightTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); - if (weightArgs.size() != weights.size()) - return parser.emitError(parser.getCurrentLocation(), "number of 
weight bindings and weight operands must match"); - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (inputArgs.size() != inputs.size()) - return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); - if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) - return parser.emitError(parser.getCurrentLocation(), - "coreId cannot be specified both positionally and in attr-dict"); - - auto& builder = parser.getBuilder(); - result.addAttribute( - "operandSegmentSizes", - builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); - if (hasCoreId) - result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId)); - - if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) - || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - - Region* body = result.addRegion(); - applyArgumentTypes(weightTypes, weightArgs); - applyArgumentTypes(inputTypes, inputArgs); - llvm::append_range(regionArgs, weightArgs); - llvm::append_range(regionArgs, inputArgs); - return parser.parseRegion(*body, regionArgs); +void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); } +ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) { + return parseComputeLikeOp(parser, result); } - -void SpatComputeBatch::print(OpAsmPrinter& printer) { - auto laneArg = getLaneArgument(); - SmallVector weightArgs; - weightArgs.reserve(getWeights().size()); - for (unsigned index = 0; index < getWeights().size(); ++index) { - auto weightArg = getWeightArgument(index); - if (!weightArg) - return printer.printGenericOp(getOperation(), /*printOpName=*/false); - weightArgs.push_back(*weightArg); - } - SmallVector inputArgs; 
- inputArgs.reserve(getInputs().size()); - for (unsigned index = 0; index < getInputs().size(); ++index) { - auto inputArg = getInputArgument(index); - if (!inputArg) - return printer.printGenericOp(getOperation(), /*printOpName=*/false); - inputArgs.push_back(*inputArg); - } - - SmallVector outputArgs; - if (!laneArg) - return printer.printGenericOp(getOperation(), /*printOpName=*/false); - if (getNumResults() != 0) { - outputArgs.reserve(getNumResults()); - for (unsigned index = 0; index < getNumResults(); ++index) { - auto outputArg = getOutputArgument(index); - if (!outputArg) - return printer.printGenericOp(getOperation(), /*printOpName=*/false); - outputArgs.push_back(*outputArg); - } - } - - printer << " "; - printer.printOperand(*laneArg); - printer << " = 0 to " << getLaneCount(); - - printer << " "; - printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); - printer << " "; - printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); - - if (getNumResults() != 0) { - printer << " shared_outs"; - printBlockArgumentList(printer, outputArgs); - } - - printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size(); - - if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) { - printer << " coreIds "; - printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); - } - - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); - - printer << " : "; - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); - printer << " "; - printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); - printer << " -> "; - printCompressedTypeSequence(printer, getResultTypes()); - printer << " "; - printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +void SpatGraphComputeBatch::print(OpAsmPrinter& printer) 
{ printComputeBatchLikeOp(*this, printer); } +ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) { + return parseComputeBatchLikeOp(parser, result); } - -ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { - int64_t lowerBound = 0; - int32_t laneCount = 0; - OpAsmParser::Argument laneArg; - SmallVector weightArgs; - SmallVector inputArgs; - SmallVector outputArgs; - SmallVector regionArgs; - SmallVector weights; - SmallVector inputs; - SmallVector weightTypes; - SmallVector inputTypes; - SmallVector outputTypes; - int32_t crossbarWeightCount = 0; - SmallVector coreIds; - - if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound) - || parser.parseKeyword("to") || parser.parseInteger(laneCount)) - return failure(); - if (lowerBound != 0) - return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound"); - - if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) - return failure(); - - if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("shared_outs"))) - if (parseBlockArgumentList(parser, outputArgs)) - return failure(); - - bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); - if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) - return failure(); - - bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); - if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) - return failure(); - (void) crossbarWeightCount; - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes) - || parseCompressedRepeatedList( - parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) - || 
parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) - return failure(); - - if (weights.size() != weightTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); - if (weightArgs.size() != weights.size()) - return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match"); - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (inputArgs.size() != inputs.size()) - return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); - if (outputArgs.size() != outputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), - "number of shared output bindings and result types must match"); - if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) - return parser.emitError(parser.getCurrentLocation(), - "coreIds cannot be specified both positionally and in attr-dict"); - - auto& builder = parser.getBuilder(); - result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); - result.addAttribute( - "operandSegmentSizes", - builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); - if (hasCoreIds) - result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds)); - - if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) - || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - - Region* body = result.addRegion(); - applyBatchRegionArgumentTypes( - inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder()); - return parser.parseRegion(*body, regionArgs); +void SpatScheduledComputeBatch::print(OpAsmPrinter& 
printer) { printComputeBatchLikeOp(*this, printer); } +ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) { + return parseComputeBatchLikeOp(parser, result); } void SpatInParallelOp::print(OpAsmPrinter& printer) { diff --git a/src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp b/src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp index 9abdba7..278dffb 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp @@ -10,8 +10,9 @@ using namespace mlir; namespace onnx_mlir { namespace spatial { -LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { - Block& block = getBody().front(); +template +LogicalResult foldComputeLike(ComputeOpTy compute, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { + Block& block = compute.getBody().front(); if (!llvm::hasSingleElement(block)) return failure(); @@ -22,7 +23,7 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m for (Value yieldedValue : yieldOp.getOperands()) { if (auto blockArg = dyn_cast(yieldedValue)) { if (blockArg.getOwner() == &block) { - results.push_back(getOperand(blockArg.getArgNumber())); + results.push_back(compute.getOperand(blockArg.getArgNumber())); continue; } } @@ -31,5 +32,13 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m return success(); } +LogicalResult SpatGraphCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { + return foldComputeLike(*this, results); +} + +LogicalResult SpatScheduledCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { + return foldComputeLike(*this, results); +} + } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 02bc62d..f4420c2 100644 --- 
a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -35,7 +35,8 @@ static FailureOr> getWeightShapeForWeightedOp(Value weight) { return shapedType.getShape(); } -static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) { +template +static bool isBatchOutputArgument(ComputeBatchOpTy batchOp, Value value) { if (batchOp.getNumResults() == 0) return false; auto blockArg = dyn_cast(value); @@ -58,8 +59,28 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) return success(); } +static bool isStaticIndexExpr(Value value) { + if (matchConstantIndexValue(value)) + return true; + + auto affineApply = value.getDefiningOp(); + if (affineApply) { + if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap())) + return false; + return llvm::all_of(affineApply.getMapOperands(), isStaticIndexExpr); + } + + if (auto addOp = value.getDefiningOp()) + return isStaticIndexExpr(addOp.getLhs()) && isStaticIndexExpr(addOp.getRhs()); + + if (auto mulOp = value.getDefiningOp()) + return isStaticIndexExpr(mulOp.getLhs()) && isStaticIndexExpr(mulOp.getRhs()); + + return false; +} + static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { - if (value == laneArg || matchConstantIndexValue(value)) + if (value == laneArg || isStaticIndexExpr(value)) return true; auto affineApply = value.getDefiningOp(); @@ -83,10 +104,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { } auto addOp = value.getDefiningOp(); - if (!addOp) + if (addOp) + return (isSupportedLaneOffsetExpr(addOp.getLhs(), laneArg) && isStaticIndexExpr(addOp.getRhs())) + || (isSupportedLaneOffsetExpr(addOp.getRhs(), laneArg) && isStaticIndexExpr(addOp.getLhs())); + + auto mulOp = value.getDefiningOp(); + if (!mulOp) return false; - return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs())) - || (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs())); + 
return (isSupportedLaneOffsetExpr(mulOp.getLhs(), laneArg) && isStaticIndexExpr(mulOp.getRhs())) + || (isSupportedLaneOffsetExpr(mulOp.getRhs(), laneArg) && isStaticIndexExpr(mulOp.getLhs())); } static LogicalResult @@ -158,17 +184,27 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)) continue; - InFlightDiagnostic diagnostic = ownerOp->emitOpError() - << kind << " body may only directly reference external constants"; + InFlightDiagnostic diagnostic = + ownerOp->emitOpError() << kind << " body may not capture external values"; diagnostic.attachNote(op->getLoc()) - << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName(); + << "owner='" << ownerOp->getName() << "' nestedOp='" << op->getName() << "' operand#" + << operand.getOperandNumber() << " type=" << value.getType() + << " category=" << (isa(value.getType()) ? "tensor" : (value.getType().isIndex() ? 
"index" + : "scalar")); + if (Operation* definingOp = value.getDefiningOp()) + diagnostic.attachNote(definingOp->getLoc()) << "defining op is '" << definingOp->getName() << "'"; + else if (auto blockArg = dyn_cast(value)) + diagnostic.attachNote(blockArg.getOwner()->getParentOp()->getLoc()) + << "value is block argument #" << blockArg.getArgNumber() << " of '" + << blockArg.getOwner()->getParentOp()->getName() << "'"; hasFailure = true; } }); return success(!hasFailure); } -static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) { +template +static LogicalResult verifyBatchBody(ComputeBatchOpTy batchOp, Block& block) { if (batchOp.getNumResults() == 0) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); if (!yieldOp) @@ -344,144 +380,266 @@ LogicalResult SpatConcatOp::verify() { return success(); } -LogicalResult verifyComputeResultsUses(Operation* op) { - if (!isa(op)) - return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation"); - if (!llvm::all_of(op->getResults(), [](Value result) { - return llvm::all_of(result.getUsers(), [](Operation* op) { - return !(op->getParentOfType() || op->getParentOfType()); - }); - })) { - return op->emitError("ComputeResult used directly inside another Compute"); +static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; } + +static bool isKnownPhysicalLayout(StringRef layout) { + return layout == "dense_nchw" || layout == "nchw_row_strip"; +} + +static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) { + auto inputType = dyn_cast(input.getType()); + auto outputType = dyn_cast(output.getType()); + if (!inputType || !outputType) + return op->emitOpError() << kind << " requires ranked tensor input and output types"; + if (inputType.getElementType() != outputType.getElementType()) + return op->emitOpError() << kind << " requires matching input/output element types"; + return success(); +} + +LogicalResult 
SpatConv2DPlanOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto weightType = dyn_cast(getWeight().getType()); + auto outputType = dyn_cast(getOutput().getType()); + if (!inputType || !weightType || !outputType) + return emitError("requires ranked tensor input, weight, and output"); + if (inputType.getRank() != 4 || weightType.getRank() != 4 || outputType.getRank() != 4) + return emitError("requires rank-4 input, weight, and output tensors"); + if (!isKnownLogicalLayout(getLogicalLayout())) + return emitError("requires a known logical layout"); + if (getPads().size() != 4) + return emitError("requires exactly four pad values"); + if (getStrides().size() != 2) + return emitError("requires exactly two stride values"); + if (getDilations().size() != 2) + return emitError("requires exactly two dilation values"); + if (getGroup() < 1) + return emitError("requires group >= 1"); + if (inputType.getElementType() != weightType.getElementType() + || inputType.getElementType() != outputType.getElementType()) { + return emitError("requires matching input, weight, and output element types"); + } + if (getBias()) { + auto biasType = dyn_cast(getBias().getType()); + if (!biasType) + return emitError("requires ranked tensor bias type"); + if (biasType.getElementType() != outputType.getElementType()) + return emitError("requires bias element type to match output element type"); } return success(); } -LogicalResult SpatCompute::verify() { - auto& block = getBody().front(); - unsigned expectedArgCount = getWeights().size() + getInputs().size(); - if (block.getNumArguments() != expectedArgCount) - return emitError("compute body must have weight and input block arguments"); +LogicalResult SpatReluPlanOp::verify() { + if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.relu_plan"))) + return failure(); + if (!isKnownLogicalLayout(getLogicalLayout())) + return emitError("requires a known logical layout"); + return success(); +} - for 
(auto [weightIndex, weight] : llvm::enumerate(getWeights())) { - auto blockArg = getWeightArgument(weightIndex); - if (!blockArg || blockArg->getType() != weight.getType()) - return emitError("compute weight block argument types must match weight operand types exactly"); +LogicalResult SpatReconciliatorOp::verify() { + if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator"))) + return failure(); + if (!isKnownLogicalLayout(getLogicalLayout())) + return emitError("requires a known logical layout"); + if (!isKnownPhysicalLayout(getPhysicalLayout())) + return emitError("requires a known physical layout"); + + auto logicalType = dyn_cast(getOutput().getType()); + if (!logicalType) + return emitError("requires ranked tensor output"); + + auto offsets = getFragmentOffsets(); + auto sizes = getFragmentSizes(); + if (offsets.size() != sizes.size()) + return emitError("fragment offset and size arrays must have the same length"); + if (offsets.empty()) + return success(); + + int64_t rank = logicalType.getRank(); + if (rank <= 0 || offsets.size() % rank != 0) + return emitError("fragment metadata must be a whole number of rank-sized fragments"); + + ArrayRef shape = logicalType.getShape(); + for (int64_t index = 0; index < static_cast(offsets.size()); ++index) { + int64_t dim = index % rank; + int64_t offset = offsets[index]; + int64_t size = sizes[index]; + if (offset < 0 || size < 0) + return emitError("fragment offsets and sizes must be non-negative"); + int64_t logicalDim = shape[dim]; + if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim) + return emitError("fragment bounds must stay within the logical tensor shape"); } - for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { - auto blockArg = getInputArgument(inputIndex); + + return success(); +} + +LogicalResult SpatMaterializeLayoutOp::verify() { + if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.materialize_layout"))) + 
return failure(); + if (!isKnownLogicalLayout(getLogicalLayout())) + return emitError("requires a known logical layout"); + if (!isKnownPhysicalLayout(getSourcePhysicalLayout())) + return emitError("requires a known source physical layout"); + if (!isKnownPhysicalLayout(getTargetPhysicalLayout())) + return emitError("requires a known target physical layout"); + return success(); +} + +LogicalResult verifyComputeResultsUses(Operation* op) { + if (!isAnySpatialComputeLike(op)) + return op->emitError("verifyComputeResultUses: op is not a Spatial compute-like operation"); + if (!llvm::all_of(op->getResults(), [](Value result) { + return llvm::all_of(result.getUsers(), [](Operation* op) { + return !isAnySpatialComputeLike(op->getParentOp()); + }); + })) { + return op->emitError("compute result used directly inside another Spatial compute body"); + } + return success(); +} + +template +LogicalResult verifyComputeLikeOp(ComputeOpTy compute, StringRef opName) { + auto& block = compute.getBody().front(); + unsigned expectedArgCount = compute.getWeights().size() + compute.getInputs().size(); + if (block.getNumArguments() != expectedArgCount) + return compute.emitOpError("compute body must have weight and input block arguments"); + + for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) { + auto blockArg = compute.getWeightArgument(weightIndex); + if (!blockArg || blockArg->getType() != weight.getType()) + return compute.emitOpError("compute weight block argument types must match weight operand types exactly"); + } + for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) { + auto blockArg = compute.getInputArgument(inputIndex); if (!blockArg || blockArg->getType() != input.getType()) - return emitError("compute input block argument types must match input operand types exactly"); + return compute.emitOpError("compute input block argument types must match input operand types exactly"); } if (block.mightHaveTerminator()) { auto yieldOp = 
dyn_cast_or_null(block.getTerminator()); if (!yieldOp) - return emitError("ComputeOp must have a single yield operation"); + return compute.emitOpError("ComputeOp must have a single yield operation"); - auto resultTypes = getResultTypes(); + auto resultTypes = compute.getResultTypes(); auto yieldTypes = yieldOp->getOperandTypes(); if (resultTypes.size() != yieldTypes.size()) - return emitError("ComputeOp must have same number of results as yieldOp operands"); + return compute.emitOpError("ComputeOp must have same number of results as yieldOp operands"); for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { auto resultType = std::get<0>(it); auto yieldType = std::get<1>(it); if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) - return emitError("ComputeOp output must be of the same type as yieldOp operand"); + return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand"); if (auto resultRankedType = dyn_cast(resultType)) { if (auto yieldRankedType = dyn_cast(yieldType)) { if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) - return emitError("ComputeOp output must have the same encoding as yieldOp operand"); + return compute.emitOpError("ComputeOp output must have the same encoding as yieldOp operand"); } else { - return emitError("ComputeOp output has an encoding while yieldOp operand does not have one"); + return compute.emitOpError("ComputeOp output has an encoding while yieldOp operand does not have one"); } } else if (dyn_cast(yieldType)) { - return emitError("ComputeOp output must not have an encoding if yieldOp operand has one"); + return compute.emitOpError("ComputeOp output must not have an encoding if yieldOp operand has one"); } } } - for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex) - if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty()) - return emitError("ComputeOp block argument is not used"); - if 
(failed(verifyStaticWeights(*this, "compute"))) + for (unsigned inputIndex = 0; inputIndex < compute.getInputs().size(); ++inputIndex) + if (auto inputArg = compute.getInputArgument(inputIndex); !inputArg || inputArg->use_empty()) + return compute.emitOpError("ComputeOp block argument is not used"); + if (failed(verifyStaticWeights(compute, opName))) return failure(); - if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute"))) + if (failed(verifyOnlyConstantExternalValues(compute.getOperation(), compute.getBody(), opName))) return failure(); - if (failed(verifyComputeResultsUses(this->getOperation()))) + if (failed(verifyComputeResultsUses(compute.getOperation()))) return failure(); return success(); } -LogicalResult SpatComputeBatch::verify() { - int32_t count = getLaneCount(); +LogicalResult SpatGraphCompute::verify() { return verifyComputeLikeOp(*this, "spat.graph_compute"); } + +LogicalResult SpatScheduledCompute::verify() { return verifyComputeLikeOp(*this, "spat.scheduled_compute"); } + +template +LogicalResult verifyComputeBatchLikeOp(ComputeBatchOpTy batch, StringRef opName) { + int32_t count = batch.getLaneCount(); if (count <= 0) - return emitError("laneCount must be positive"); + return batch.emitOpError("laneCount must be positive"); auto laneCountSz = static_cast(count); - if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) { + if (auto coreIdAttr = batch->getAttr(kCoreIdsAttrName)) { auto coreIdsAttr = dyn_cast(coreIdAttr); if (!coreIdsAttr) - return emitError("compute_batch coreIds attribute must be a dense i32 array"); + return batch.emitOpError("compute_batch coreIds attribute must be a dense i32 array"); if (coreIdsAttr.size() != static_cast(laneCountSz)) - return emitError("compute_batch coreIds array length must match laneCount"); + return batch.emitOpError("compute_batch coreIds array length must match laneCount"); if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; })) - 
return emitError("compute_batch coreIds values must be non-negative"); + return batch.emitOpError("compute_batch coreIds values must be non-negative"); DenseSet seenCoreIds; for (int32_t coreId : coreIdsAttr.asArrayRef()) if (!seenCoreIds.insert(coreId).second) - return emitError("compute_batch coreIds values must be unique"); + return batch.emitOpError("compute_batch coreIds values must be unique"); } - Block& block = getBody().front(); + Block& block = batch.getBody().front(); if (block.getNumArguments() == 0) - return emitError("compute_batch body must have exactly one lane block argument"); - unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults(); + return batch.emitOpError("compute_batch body must have exactly one lane block argument"); + unsigned expectedArgCount = 1 + batch.getWeights().size() + batch.getInputs().size() + batch.getNumResults(); if (block.getNumArguments() != expectedArgCount) - return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results"); - auto laneArg = getLaneArgument(); + return batch.emitOpError("compute_batch body block arguments must match lane, weight, input, and output operands/results"); + auto laneArg = batch.getLaneArgument(); if (!laneArg || !laneArg->getType().isIndex()) - return emitError("compute_batch first block argument must have index type"); + return batch.emitOpError("compute_batch first block argument must have index type"); - for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { - auto blockArg = getWeightArgument(weightIndex); + for (auto [weightIndex, weight] : llvm::enumerate(batch.getWeights())) { + auto blockArg = batch.getWeightArgument(weightIndex); if (!blockArg || blockArg->getType() != weight.getType()) - return emitError("compute_batch weight block argument types must match weight operand types exactly"); + return batch.emitOpError("compute_batch weight block argument types must match weight operand types 
exactly"); } - for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { - auto blockArg = getInputArgument(inputIndex); + for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { + auto blockArg = batch.getInputArgument(inputIndex); if (!blockArg || blockArg->getType() != input.getType()) - return emitError("compute_batch input block argument types must match input operand types exactly"); + return batch.emitOpError("compute_batch input block argument types must match input operand types exactly"); } - for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) { - auto blockArg = getOutputArgument(resultIndex); + for (auto [resultIndex, resultType] : llvm::enumerate(batch.getResultTypes())) { + auto blockArg = batch.getOutputArgument(resultIndex); if (!blockArg || blockArg->getType() != resultType) - return emitError("compute_batch output block argument types must match result types exactly"); + return batch.emitOpError("compute_batch output block argument types must match result types exactly"); } - if (failed(verifyComputeResultsUses(this->getOperation()))) + if (failed(verifyComputeResultsUses(batch.getOperation()))) return failure(); - if (failed(verifyStaticWeights(*this, "compute_batch"))) + if (failed(verifyStaticWeights(batch, opName))) return failure(); - if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch"))) + if (failed(verifyOnlyConstantExternalValues(batch.getOperation(), batch.getBody(), opName))) return failure(); - return verifyBatchBody(*this, block); + return verifyBatchBody(batch, block); +} + +LogicalResult SpatGraphComputeBatch::verify() { return verifyComputeBatchLikeOp(*this, "spat.graph_compute_batch"); } + +LogicalResult SpatScheduledComputeBatch::verify() { + return verifyComputeBatchLikeOp(*this, "spat.scheduled_compute_batch"); } LogicalResult SpatInParallelOp::verify() { - auto batchOp = getOperation()->getParentOfType(); - if (!batchOp) - return 
emitOpError("expected spat.compute_batch parent"); - if (batchOp.getNumResults() == 0) + Operation* parent = getOperation()->getParentOp(); + if (!isAnySpatialComputeBatchLike(parent)) + return emitOpError("expected spat.graph_compute_batch or spat.scheduled_compute_batch parent"); + if (parent->getNumResults() == 0) return emitOpError("requires a resultful spat.compute_batch parent"); - auto laneArg = batchOp.getLaneArgument(); + std::optional laneArg; + if (auto graphBatch = dyn_cast(parent)) + laneArg = graphBatch.getLaneArgument(); + else + laneArg = cast(parent).getLaneArgument(); if (!laneArg) return emitOpError("expected compute_batch lane block argument"); for (Operation& op : getRegion().front().getOperations()) { @@ -494,7 +652,10 @@ LogicalResult SpatInParallelOp::verify() { MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations(); for (OpOperand& destination : destinations) - if (!isBatchOutputArgument(batchOp, destination.get())) + if ((isa(parent) + && !isBatchOutputArgument(cast(parent), destination.get())) + || (isa(parent) + && !isBatchOutputArgument(cast(parent), destination.get()))) return op.emitOpError("may only insert into a compute_batch output block argument"); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 1510c1d..82dbf03 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -14,14 +14,17 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #include #include #include +#include #include #include "MaterializeMergeSchedule.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include 
"src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" @@ -323,11 +326,74 @@ struct AffineProjectedInputSliceMatch { struct MaterializerState; +struct PendingProjectedHostReceiveGroup { + Value originalOutput; + ClassId ownerClassId = 0; + RankedTensorType fragmentType; + SmallVector keys; + MessageVector messages; + Location loc; +}; + +struct PendingScalarReceiveRecord { + PendingScalarReceiveRecord(ArrayRef keys, + ClassId targetClassId, + Type receiveType, + const MessageVector& messages, + Location loc) + : targetClassId(targetClassId), + receiveType(receiveType), + messages(messages), + loc(loc) { + this->keys.append(keys.begin(), keys.end()); + } + + SmallVector keys; + ClassId targetClassId = 0; + Type receiveType; + MessageVector messages; + Location loc; + bool materialized = false; + Value value; +}; + FailureOr materializeProjectedExtractReplacement(MaterializerState& state, MaterializedClass& targetClass, tensor::ExtractSliceOp extract, const ProjectedExtractReplacement& replacement, - std::optional projectionSlotIndex); + std::optional projectionSlotIndex, + IRMapping* mapper = nullptr); +FailureOr rematerializeTensorValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + IRMapping* mapper = nullptr); +FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + std::optional producer = std::nullopt, + IRMapping* mapper = nullptr); +FailureOr localizeMaterializedClassOperand(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper = nullptr); +LogicalResult localizeCapturesInClonedOp(MaterializerState& state, + MaterializedClass& targetClass, + 
Operation& clonedOp, + IRMapping* mapper = nullptr); +bool requiresConstantProjectionSlotIndex(MaterializerState& state, + MaterializedClass& targetClass, + Operation* sourceOp); +bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane); class AvailableValueStore { public: @@ -421,6 +487,12 @@ struct MaterializerState { DenseMap> projectedExtractReplacements; AvailableValueStore availableValues; DenseMap hostReplacements; + DenseMap hostOutputOwners; + SmallVector pendingProjectedHostReceives; + SmallVector pendingScalarReceives; + DenseMap, ProducerKeyInfo> pendingScalarReceiveLookup; + DenseMap firstLateCommunicationOps; + int64_t nextCommunicationTraceId = 0; DenseSet oldComputeOps; MaterializerState(func::FuncOp func, @@ -536,6 +608,32 @@ std::optional getContiguousProducerRangeForKeys(ArrayRef getPhysicallyContiguousProducerRangeForKeys(ArrayRef keys) { + if (keys.empty()) + return std::nullopt; + + ProducerKey first = keys.front(); + auto batch = dyn_cast_or_null(first.instance.op); + if (!batch || first.instance.laneCount == 0) + return std::nullopt; + + uint32_t laneStart = first.instance.laneStart; + uint32_t nextLane = laneStart; + for (ProducerKey key : keys) { + if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) + return std::nullopt; + if (key.instance.laneStart != nextLane) + return std::nullopt; + nextLane += key.instance.laneCount; + } + + uint32_t laneCount = nextLane - laneStart; + if (laneStart + laneCount > static_cast(batch.getLaneCount())) + return std::nullopt; + + return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); +} + WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) { return {sourceOp, resultIndex, classId}; } @@ -591,11 +689,6 @@ 
collectProducerKeysForDestinations(Value value, std::optional l return keys; } - if (logicalConsumer && isa(logicalConsumer->op)) { - keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); - return keys; - } - return {}; } @@ -622,11 +715,6 @@ collectProducerKeysForDestinations(Value value, std::optional l return {}; if (batch.getNumResults() != 0) { - if (logicalConsumer && isa(logicalConsumer->op)) { - keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); - return keys; - } - for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) keys.push_back(getBatchLaneProducerKey(batch, lane, 1, result.getResultNumber())); return keys; @@ -659,9 +747,6 @@ std::optional getInputRequestProducerKey(Value value, if (std::optional lane = getConstantFirstSliceOffset(extract)) return getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber()); - if (logicalConsumer && isa(logicalConsumer->op)) - return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); - return std::nullopt; } @@ -686,11 +771,8 @@ std::optional getInputRequestProducerKey(Value value, if (!result) return std::nullopt; - if (batch.getNumResults() != 0) { - if (logicalConsumer && isa(logicalConsumer->op)) - return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); + if (batch.getNumResults() != 0) return getWholeBatchProducerKey(batch, result.getResultNumber()); - } return ProducerKey {getBatchChunkForLane(batch, result.getResultNumber()), 0}; } @@ -881,23 +963,66 @@ LogicalResult forEachLogicalConsumerInMaterializationOrder( return success(); } +bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps); + LogicalResult collectHostOutputs(MaterializerState& state) { DenseSet seenOutputs; + SmallVector orderedOutputs; + DenseMap preferredOwners; + for (const ComputeInstance& instance : 
state.schedule.dominanceOrderCompute) { auto cpuIt = state.schedule.computeToCpuMap.find(instance); if (cpuIt == state.schedule.computeToCpuMap.end()) return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); - MaterializedClass& materializedClass = state.classes[state.cpuToClass.lookup(cpuIt->second)]; + ClassId classId = state.cpuToClass.lookup(cpuIt->second); + MaterializedClass& materializedClass = state.classes[classId]; for (Value output : getComputeInstanceOutputValuesCached(state, instance)) { - if (!hasLiveExternalUseCached(state, output) || !seenOutputs.insert(output).second) + if (!hasLiveExternalUseCached(state, output)) continue; - materializedClass.hostOutputToResultIndex[output] = materializedClass.hostOutputs.size(); - materializedClass.hostOutputs.push_back(output); + if (seenOutputs.insert(output).second) { + orderedOutputs.push_back(output); + preferredOwners[output] = classId; + continue; + } + + auto batch = dyn_cast_or_null(output.getDefiningOp()); + if (!batch || batch.getNumResults() == 0) + continue; + + ClassId currentOwner = preferredOwners.lookup(output); + bool terminalHost = isTerminalHostBatchOutput(output, state.oldComputeOps); + if (terminalHost) { + // Terminal resultful batch outputs are still published through scalar + // host-output slots unless the materialized batch class owns the output + // directly. Selecting an arbitrary batch class as the host owner would + // require a projection-aware batch publication path, which the + // materializer does not currently implement. 
+ if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + preferredOwners[output] = classId; + continue; + } + + if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + preferredOwners[output] = classId; } } + for (MaterializedClass& materializedClass : state.classes) { + materializedClass.hostOutputs.clear(); + materializedClass.hostOutputToResultIndex.clear(); + } + state.hostOutputOwners.clear(); + + for (Value output : orderedOutputs) { + ClassId ownerClassId = preferredOwners.lookup(output); + MaterializedClass& ownerClass = state.classes[ownerClassId]; + ownerClass.hostOutputToResultIndex[output] = ownerClass.hostOutputs.size(); + ownerClass.hostOutputs.push_back(output); + state.hostOutputOwners[output] = ownerClassId; + } + return success(); } @@ -925,7 +1050,7 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) { resultTypes.push_back(output.getType()); if (!materializedClass.isBatch) { - auto compute = SpatCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); + auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); compute.getProperties().setOperandSegmentSizes({0, 0}); auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, materializedClass.cpus.front(), "materialized core id"); @@ -956,8 +1081,7 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) { state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count"); if (failed(batchLaneCountAttr)) return failure(); - auto batch = SpatComputeBatch::create( - state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); + auto batch = SpatScheduledComputeBatch::create(state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); batch.getProperties().setOperandSegmentSizes({0, 0}); auto coreIds = getCheckedCoreIds(state.func, 
materializedClass.cpus, "materialized batch core id"); if (failed(coreIds)) @@ -991,14 +1115,14 @@ BlockArgument appendWeight(MaterializerState& state, MaterializedClass& material unsigned weightIndex = materializedClass.weights.size(); materializedClass.weights.push_back(weight); - if (auto compute = dyn_cast(materializedClass.op)) { + if (auto compute = dyn_cast(materializedClass.op)) { auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc()); assert(arg && "expected compute body while inserting a weight"); materializedClass.weightArgs[weight] = std::get<1>(*arg); return std::get<1>(*arg); } - auto batch = cast(materializedClass.op); + auto batch = cast(materializedClass.op); auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc()); assert(arg && "expected compute_batch body while inserting a weight argument"); materializedClass.weightArgs[weight] = std::get<1>(*arg); @@ -1011,13 +1135,13 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali return it->second; materializedClass.inputs.push_back(input); - if (auto compute = dyn_cast(materializedClass.op)) { + if (auto compute = dyn_cast(materializedClass.op)) { auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); assert(arg && "expected compute body while inserting an input"); materializedClass.inputArgs[input] = std::get<1>(*arg); return std::get<1>(*arg); } - if (auto compute = dyn_cast(materializedClass.op)) { + if (auto compute = dyn_cast(materializedClass.op)) { auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); assert(arg && "expected compute_batch body while inserting an input argument"); materializedClass.inputArgs[input] = std::get<1>(*arg); @@ -1026,6 +1150,580 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali llvm_unreachable("Cannot reach here"); } +Region* getParentRegion(Value value) { + if (auto blockArg = dyn_cast(value)) + return 
blockArg.getOwner()->getParent(); + if (Operation* definingOp = value.getDefiningOp()) + return definingOp->getParentRegion(); + return nullptr; +} + +bool isDefinedInsideRegion(Value value, Region& region) { + Region* parentRegion = getParentRegion(value); + return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion)); +} + +Operation* getEnclosingSpatialComputeLikeOp(Value value) { + Block* block = nullptr; + if (auto blockArg = dyn_cast(value)) + block = blockArg.getOwner(); + else if (Operation* definingOp = value.getDefiningOp()) + block = definingOp->getBlock(); + + if (!block) + return nullptr; + + for (Operation* current = block->getParentOp(); current; current = current->getParentOp()) + if (isa(current)) + return current; + return nullptr; +} + +bool isTensorValueLocalToMaterializedClass(Value value, const MaterializedClass& targetClass) { + if (!isa(value.getType())) + return true; + if (isConstantLike(value)) + return true; + + Region& targetRegion = *targetClass.body->getParent(); + return isDefinedInsideRegion(value, targetRegion); +} + +bool isTensorValueDefinedInDifferentMaterializedClass(Value value, const MaterializedClass& targetClass) { + if (!isa(value.getType()) || isTensorValueLocalToMaterializedClass(value, targetClass)) + return false; + + Operation* owner = getEnclosingSpatialComputeLikeOp(value); + return owner && owner != targetClass.op; +} + +std::optional getRegionIndexInParentOp(Region* region) { + Operation* parent = region ? region->getParentOp() : nullptr; + if (!parent) + return std::nullopt; + + for (auto [index, candidate] : llvm::enumerate(parent->getRegions())) + if (&candidate == region) + return static_cast(index); + return std::nullopt; +} + +std::optional getBlockIndexInRegion(Block* block) { + Region* region = block ? 
block->getParent() : nullptr; + if (!region) + return std::nullopt; + + for (auto [index, candidate] : llvm::enumerate(region->getBlocks())) + if (&candidate == block) + return static_cast(index); + return std::nullopt; +} + +Block* getBlockByIndex(Region& region, unsigned blockIndex) { + unsigned index = 0; + for (Block& block : region) { + if (index == blockIndex) + return &block; + ++index; + } + return nullptr; +} + +static bool isValueLegalInMaterializedClassBody(Value value, const MaterializedClass& targetClass) { + if (isConstantLike(value)) + return true; + + Region& targetRegion = *targetClass.body->getParent(); + return isDefinedInsideRegion(value, targetRegion); +} + +std::string stringifyOperationForMaterializerDebug(Operation* op) { + if (!op) + return std::string(""); + std::string storage; + llvm::raw_string_ostream stream(storage); + op->print(stream); + return storage; +} + +std::string stringifyValueForMaterializerDebug(Value value) { + std::string storage; + llvm::raw_string_ostream stream(storage); + value.print(stream); + return storage; +} + +std::string truncateMaterializerDebugString(std::string text, size_t limit = 1200) { + for (char& ch : text) + if (ch == '\n' || ch == '\r' || ch == '\t') + ch = ' '; + + if (text.size() <= limit) + return text; + text.resize(limit); + text += "..."; + return text; +} + +std::string formatMaterializerOperandListInline(Operation* op, const MaterializedClass& targetClass) { + if (!op) + return std::string(""); + + std::string storage; + llvm::raw_string_ostream stream(storage); + for (OpOperand& operand : op->getOpOperands()) { + if (operand.getOperandNumber() != 0) + stream << " | "; + Value value = operand.get(); + stream << "operand#" << operand.getOperandNumber() << " type=" << value.getType() + << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 
1 : 0) + << " value=" << stringifyValueForMaterializerDebug(value); + if (auto blockArg = dyn_cast(value)) { + stream << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + stream << " ownerOp='" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + stream << " definingOp='" << definingOp->getName() << "'"; + } + } + return truncateMaterializerDebugString(stream.str()); +} + +std::string formatMaterializerParentChainInline(Operation* op) { + if (!op) + return std::string(""); + + std::string storage; + llvm::raw_string_ostream stream(storage); + unsigned depth = 0; + for (Operation* current = op; current; current = current->getParentOp()) { + if (depth != 0) + stream << " <- "; + stream << "[" << depth++ << "]" << current->getName(); + } + return truncateMaterializerDebugString(stream.str()); +} + +void attachMaterializerOperationPrintNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { + if (!op) + return; + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stringifyOperationForMaterializerDebug(op); +} + +void attachMaterializerParentChainNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { + if (!op) + return; + + std::string storage; + llvm::raw_string_ostream stream(storage); + unsigned depth = 0; + for (Operation* current = op; current; current = current->getParentOp()) + stream << " [" << depth++ << "] " << current->getName() << "\n"; + + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); +} + +void attachMaterializerOperandListNote(InFlightDiagnostic& diagnostic, + Operation* op, + const MaterializedClass& targetClass, + StringRef label) { + if (!op) + return; + + std::string storage; + llvm::raw_string_ostream stream(storage); + for (OpOperand& operand : op->getOpOperands()) { + Value value = operand.get(); + stream << " operand#" << operand.getOperandNumber() << " type=" << value.getType() + << " local=" 
<< (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) + << " value=" << stringifyValueForMaterializerDebug(value); + if (auto blockArg = dyn_cast(value)) { + stream << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + stream << " ownerOp='" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + stream << " definingOp='" << definingOp->getName() << "'"; + } + stream << "\n"; + } + + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); +} + +void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value value, StringRef label) { + if (auto blockArg = dyn_cast(value)) { + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic.attachNote(owner->getLoc()) + << label << " is block argument #" << blockArg.getArgNumber() << " of '" << owner->getName() + << "' with type " << blockArg.getType(); + else + diagnostic.attachNote(UnknownLoc::get(value.getContext())) + << label << " is a top-level block argument #" << blockArg.getArgNumber() + << " with type " << blockArg.getType(); + return; + } + + if (Operation* definingOp = value.getDefiningOp()) { + diagnostic.attachNote(definingOp->getLoc()) + << label << " is defined by '" << definingOp->getName() << "' with result type " << value.getType(); + return; + } + + diagnostic.attachNote(UnknownLoc::get(value.getContext())) + << label << " has no defining operation and is not a block argument, type " << value.getType(); +} + +void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) { + Block& body = *targetClass.body; + diagnostic.attachNote(targetClass.op->getLoc()) + << "RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName() + << "' body has " << body.getNumArguments() << " block arguments and " + << std::distance(body.begin(), body.end()) << " top-level operations"; +} + 
+FailureOr rematerializeIndexValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Location loc, + IRMapping* mapper = nullptr); + +FailureOr rematerializeIndexOpFoldResultInClass(MaterializerState& state, + MaterializedClass& targetClass, + OpFoldResult value, + Location loc, + IRMapping* mapper = nullptr) { + if (auto attr = dyn_cast(value)) + return OpFoldResult(attr); + + FailureOr rematerialized = rematerializeIndexValueInClass(state, targetClass, cast(value), loc, mapper); + if (failed(rematerialized)) + return failure(); + return OpFoldResult(*rematerialized); +} + +FailureOr rematerializeIndexValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Location loc, + IRMapping* mapper) { + Value originalValue = value; + bool mapperHadOriginalValue = false; + Value mappedOriginalValue; + + if (mapper && mapper->contains(value)) { + mapperHadOriginalValue = true; + Value mapped = mapper->lookup(value); + mappedOriginalValue = mapped; + if (isValueLegalInMaterializedClassBody(mapped, targetClass) || isConstantLike(mapped)) + return mapped; + value = mapped; + } + + if (isValueLegalInMaterializedClassBody(value, targetClass)) + return value; + + if (!value.getType().isIndex()) + return targetClass.op->emitError("cannot rematerialize non-index external value in materialized class body") + << " type=" << value.getType(); + + if (auto constantIndex = value.getDefiningOp()) + return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantIndex.value()); + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) { + if (!constantValue.isSignedIntN(64)) + return targetClass.op->emitError("cannot rematerialize out-of-range index constant") + << " value=" << llvm::toString(constantValue, 10, /*Signed=*/true); + return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantValue.getSExtValue()); + } + + if (auto affineApply = value.getDefiningOp()) { 
+ SmallVector remappedOperands; + remappedOperands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, operand, loc, mapper); + if (failed(remapped)) + return failure(); + remappedOperands.push_back(*remapped); + } + return createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func); + } + + if (auto addOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, addOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, addOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::AddIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto subOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, subOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, subOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::SubIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto mulOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::MulIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto divOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, divOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, divOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::DivUIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto extractOp = 
value.getDefiningOp()) { + SmallVector remappedIndices; + remappedIndices.reserve(extractOp.getIndices().size()); + for (Value index : extractOp.getIndices()) { + FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, index, loc, mapper); + if (failed(remapped)) + return failure(); + remappedIndices.push_back(*remapped); + } + + Value tensor = extractOp.getTensor(); + if (!isConstantLike(tensor) && !isValueLegalInMaterializedClassBody(tensor, targetClass)) + return targetClass.op->emitError("cannot rematerialize indexed table lookup from external non-constant tensor") + << " tensorType=" << tensor.getType(); + return tensor::ExtractOp::create(state.rewriter, loc, tensor, remappedIndices).getResult(); + } + + if (auto blockArg = dyn_cast(value)) { + InFlightDiagnostic diagnostic = targetClass.op->emitError( + "RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body"); + diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType() + << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; + if (Operation* owner = blockArg.getOwner()->getParentOp()) { + diagnostic << " ownerOp='" << owner->getName() << "'"; + diagnostic << " ownerIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(owner)) << "\""; + diagnostic << " ownerChain=\"" << formatMaterializerParentChainInline(owner) << "\""; + } + diagnostic << " targetIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(targetClass.op)) << "\""; + if (mapper) { + diagnostic << " mapperPresent=1 mapperHadOriginal=" << (mapperHadOriginalValue ? 
1 : 0); + if (mapperHadOriginalValue) + diagnostic << " mappedType=" << mappedOriginalValue.getType(); + } else { + diagnostic << " mapperPresent=0"; + } + attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); + if (value != originalValue) + attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); + if (mapperHadOriginalValue && mappedOriginalValue != value) + attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value"); + if (Operation* owner = blockArg.getOwner()->getParentOp()) { + attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op"); + attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain"); + } + attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); + } + + InFlightDiagnostic diagnostic = + targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body"); + diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='" + << targetClass.op->getName() << "'"; + attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); + if (value != originalValue) + attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); +} + +InFlightDiagnostic emitNonLocalMaterializedClassValueDiagnostic(Operation* anchor, + const MaterializedClass& targetClass, + StringRef context, + Value value, + std::optional producer = std::nullopt) { + InFlightDiagnostic diagnostic = anchor->emitError(context) << " into target class " << targetClass.id; + + if (producer) { + diagnostic << " from '" << producer->instance.op->getName() << "' 
resultIndex=" << producer->resultIndex + << " laneStart=" << producer->instance.laneStart << " laneCount=" << producer->instance.laneCount; + } else if (auto result = dyn_cast(value)) { + diagnostic << " from '" << result.getOwner()->getName() << "' resultIndex=" << result.getResultNumber(); + } else if (auto blockArg = dyn_cast(value)) { + diagnostic << " from block argument #" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic << " of '" << owner->getName() << "'"; + } + + if (Operation* definingOp = value.getDefiningOp()) + diagnostic.attachNote(definingOp->getLoc()) << "offending tensor producer is '" << definingOp->getName() << "'"; + return diagnostic; +} + +FailureOr rematerializeTensorValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + IRMapping* mapper) { + auto extractSlice = value.getDefiningOp(); + if (extractSlice) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, targetClass, extractSlice.getSource(), anchor, context, std::nullopt, mapper); + if (failed(localizedSource)) + return failure(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(extractSlice.getMixedOffsets().size()); + sizes.reserve(extractSlice.getMixedSizes().size()); + strides.reserve(extractSlice.getMixedStrides().size()); + + for (OpFoldResult offset : extractSlice.getMixedOffsets()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, offset, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + offsets.push_back(*localized); + } + for (OpFoldResult size : extractSlice.getMixedSizes()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, size, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + sizes.push_back(*localized); + } + for (OpFoldResult stride : 
extractSlice.getMixedStrides()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, stride, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + strides.push_back(*localized); + } + + return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides) + .getResult(); + } + + if (auto collapseShape = value.getDefiningOp()) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, targetClass, collapseShape.getSrc(), anchor, context, std::nullopt, mapper); + if (failed(localizedSource)) + return failure(); + return tensor::CollapseShapeOp::create( + state.rewriter, anchor->getLoc(), *localizedSource, collapseShape.getReassociationIndices()) + .getResult(); + } + + return failure(); +} + +FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + std::optional producer, + IRMapping* mapper) { + if (mapper && mapper->contains(value)) + value = mapper->lookup(value); + + if (!isa(value.getType()) || isConstantLike(value) || isTensorValueLocalToMaterializedClass(value, targetClass)) + return value; + + if (value.getDefiningOp() || value.getDefiningOp()) { + FailureOr rematerialized = rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); + if (failed(rematerialized)) + return failure(); + return *rematerialized; + } + + if (isTensorValueDefinedInDifferentMaterializedClass(value, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic(anchor, targetClass, context, value, producer); + return failure(); + } + + return appendInput(state, targetClass, value); +} + +std::optional mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass, + Operation* anchor, + BlockArgument externalArg) { + Block* sourceBlock = externalArg.getOwner(); + Region* sourceRegion = sourceBlock ? 
sourceBlock->getParent() : nullptr; + Operation* sourceParent = sourceRegion ? sourceRegion->getParentOp() : nullptr; + if (!sourceParent || !anchor) + return std::nullopt; + + std::optional sourceRegionIndex = getRegionIndexInParentOp(sourceRegion); + std::optional sourceBlockIndex = getBlockIndexInRegion(sourceBlock); + if (!sourceRegionIndex || !sourceBlockIndex) + return std::nullopt; + + for (Operation* current = anchor->getParentOp(); current && current != targetClass.op; + current = current->getParentOp()) { + if (current->getName() != sourceParent->getName()) + continue; + if (current->getNumRegions() <= *sourceRegionIndex) + continue; + + Region& localRegion = current->getRegion(*sourceRegionIndex); + Block* localBlock = getBlockByIndex(localRegion, *sourceBlockIndex); + if (!localBlock || localBlock->getNumArguments() <= externalArg.getArgNumber()) + continue; + + BlockArgument localArg = localBlock->getArgument(externalArg.getArgNumber()); + if (localArg.getType() != externalArg.getType()) + continue; + if (!isValueLegalInMaterializedClassBody(localArg, targetClass)) + continue; + return localArg; + } + + return std::nullopt; +} + +FailureOr localizeMaterializedClassOperand(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper) { + if (mapper && mapper->contains(value)) + value = mapper->lookup(value); + + if (auto blockArg = dyn_cast(value)) + if (std::optional localArg = mapExternalRegionBlockArgumentToLocalClone(targetClass, anchor, blockArg)) + return *localArg; + + if (isa(value.getType())) + return materializeTensorValueForMaterializedClassUse(state, targetClass, value, anchor, tensorContext, std::nullopt, mapper); + + if (isValueLegalInMaterializedClassBody(value, targetClass)) + return value; + + if (value.getType().isIndex()) + return rematerializeIndexValueInClass(state, targetClass, value, anchor->getLoc(), mapper); + + 
InFlightDiagnostic diagnostic = anchor->emitError(genericContext); + diagnostic << " type=" << value.getType(); + if (auto blockArg = dyn_cast(value)) { + diagnostic << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic.attachNote(owner->getLoc()) << "block argument belongs to '" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + diagnostic.attachNote(definingOp->getLoc()) << "unsupported external operand producer is '" << definingOp->getName() + << "'"; + } + return failure(); +} + // ----------------------------------------------------------------------------- // Tensor packing helpers. // ----------------------------------------------------------------------------- @@ -1064,6 +1762,27 @@ Value createDim0ExtractSlice( .getResult(); } +FailureOr createDim0ExtractSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value source, + OpFoldResult firstOffset, + int64_t firstSize) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + source, + targetClass.op, + "createDim0ExtractSliceInClass tried to reuse a tensor from another materialized class"); + if (failed(localizedSource)) + return failure(); + FailureOr localizedOffset = + rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + if (failed(localizedOffset)) + return failure(); + return createDim0ExtractSlice(state, loc, *localizedSource, *localizedOffset, firstSize); +} + Value createStaticExtractSlice(MaterializerState& state, Location loc, Value source, @@ -1088,6 +1807,33 @@ Value createStaticExtractSlice(MaterializerState& state, return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult(); } +FailureOr createStaticExtractSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value source, + ArrayRef sliceOffsets, + 
ArrayRef resultShape) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + source, + targetClass.op, + "createStaticExtractSliceInClass tried to reuse a tensor from another materialized class"); + if (failed(localizedSource)) + return failure(); + + SmallVector localizedOffsets; + localizedOffsets.reserve(sliceOffsets.size()); + for (OpFoldResult offset : sliceOffsets) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, offset, loc); + if (failed(localized)) + return failure(); + localizedOffsets.push_back(*localized); + } + return createStaticExtractSlice(state, loc, *localizedSource, localizedOffsets, resultShape); +} + Value createIndexedIndexValue(MaterializerState& state, Operation* anchor, ArrayRef values, @@ -1096,18 +1842,21 @@ Value createIndexedIndexValue(MaterializerState& state, std::optional preferredPeriod = std::nullopt, bool allowExhaustiveTiledSearch = true); -SmallVector buildProjectedFragmentOffsets(MaterializerState& state, - Operation* anchor, - const ProjectedTransferDescriptor& descriptor, - Value flatFragmentIndex, - Location loc) { +FailureOr> buildProjectedFragmentOffsetsInClass(MaterializerState& state, + MaterializedClass& targetClass, + const ProjectedTransferDescriptor& descriptor, + Value flatFragmentIndex, + Location loc) { + FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, flatFragmentIndex, loc); + if (failed(localizedIndex)) + return failure(); SmallVector fragmentOffsets; fragmentOffsets.reserve(descriptor.layout.fragmentShape.size()); for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) fragmentOffsets.push_back(createIndexedIndexValue(state, - anchor, + targetClass.op, dimOffsets, - flatFragmentIndex, + *localizedIndex, loc, static_cast(descriptor.layout.payloadFragmentCount), /*allowExhaustiveTiledSearch=*/false)); @@ -1123,6 +1872,35 @@ Value createDim0InsertSlice( .getResult(); } +FailureOr 
createDim0InsertSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value fragment, + Value destination, + OpFoldResult firstOffset) { + FailureOr localizedFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment, + targetClass.op, + "createDim0InsertSliceInClass tried to reuse a fragment tensor from another materialized class"); + if (failed(localizedFragment)) + return failure(); + FailureOr localizedDestination = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + destination, + targetClass.op, + "createDim0InsertSliceInClass tried to reuse a destination tensor from another materialized class"); + if (failed(localizedDestination)) + return failure(); + FailureOr localizedOffset = + rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + if (failed(localizedOffset)) + return failure(); + return createDim0InsertSlice(state, loc, *localizedFragment, *localizedDestination, *localizedOffset); +} + void createDim0ParallelInsertSlice( MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { auto fragmentType = cast(fragment.getType()); @@ -1139,6 +1917,21 @@ Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value in return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); } +FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value index, + int64_t dim0Size, + Location loc) { + FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, index, loc); + if (failed(localizedIndex)) + return failure(); + if (dim0Size == 1) + return *localizedIndex; + + Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size); + return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult(); +} + bool sameProducerResult(ProducerKey lhs, 
ProducerKey rhs) { return lhs.instance.op == rhs.instance.op && lhs.resultIndex == rhs.resultIndex; } @@ -1186,7 +1979,11 @@ std::optional extractPackedProducerSlice(MaterializerState& state, state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); - return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); + FailureOr slice = + createDim0ExtractSliceInClass(state, materializedClass, materializedClass.op->getLoc(), packed, firstOffset, rowCount); + if (failed(slice)) + return std::nullopt; + return *slice; } std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { @@ -1201,20 +1998,19 @@ std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId c return valueIt->second; } -Value getPackedSliceForRunIndex(MaterializerState& state, - Operation* anchor, - Value packed, - RankedTensorType fragmentType, - size_t index, - Location loc) { +FailureOr getPackedSliceForRunIndex(MaterializerState& state, + MaterializedClass& targetClass, + Value packed, + RankedTensorType fragmentType, + size_t index, + Location loc) { int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); - Value firstOffset = getOrCreateIndexConstant(state.constantFolder, anchor, rowOffset); - return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset); + return createDim0ExtractSliceInClass(state, targetClass, loc, packed, firstOffset, fragmentType.getDimSize(0)); } FailureOr createReceiveConcatLoop(MaterializerState& state, - Operation* anchor, - Operation* insertionPoint, + MaterializedClass& targetClass, RankedTensorType concatType, RankedTensorType fragmentType, const MessageVector& messages, @@ -1228,6 +2024,8 @@ FailureOr 
materializeDeferredLocalPackedScalarRunValue(MaterializerState& PackedScalarRunValue& run, Location loc); +SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run); + bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { return run.kind == PackedScalarRunKind::DeferredLocalCompute; } @@ -1278,8 +2076,7 @@ FailureOr materializePackedScalarRunValue(MaterializerState& state, if (failed(fullPackedType)) return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); - auto packed = createReceiveConcatLoop( - state, targetClass.op, targetClass.body->getTerminator(), *fullPackedType, run.fragmentType, run.messages, loc); + auto packed = createReceiveConcatLoop(state, targetClass, *fullPackedType, run.fragmentType, run.messages, loc); if (failed(packed)) return failure(); run.packed = *packed; @@ -1291,14 +2088,45 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) continue; + size_t flattenedIndexBase = 0; for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { - std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); - if (!contiguousKey || !containsProducerKey(*contiguousKey, key)) - continue; + std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); + if (contiguousKey && containsProducerKey(*contiguousKey, key)) { + FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return std::nullopt; - FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); - if (failed(slotPackedType)) - return std::nullopt; + MaterializedClass& materializedClass = state.classes[classId]; + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); + + FailureOr packed = + materializePackedScalarRunValue(state, materializedClass, run, 
materializedClass.op->getLoc()); + if (failed(packed)) + return std::nullopt; + FailureOr slotPacked = + getPackedSliceForRunIndex(state, materializedClass, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); + if (failed(slotPacked)) + return std::nullopt; + + if (*contiguousKey == key) { + record(key, classId, *slotPacked); + return *slotPacked; + } + + std::optional sliced = + extractPackedProducerSlice(state, materializedClass, *contiguousKey, *slotPacked, key); + if (!sliced) + return std::nullopt; + + record(key, classId, *sliced); + return *sliced; + } + + auto keyIt = llvm::find(slot.keys, key); + if (keyIt == slot.keys.end()) { + flattenedIndexBase += slot.keys.size(); + continue; + } MaterializedClass& materializedClass = state.classes[classId]; state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); @@ -1307,20 +2135,11 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); if (failed(packed)) return std::nullopt; - - Value slotPacked = - getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); - - if (*contiguousKey == key) { - record(key, classId, slotPacked); - return slotPacked; - } - - std::optional sliced = - extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key); - if (!sliced) + size_t flattenedIndex = flattenedIndexBase + static_cast(std::distance(slot.keys.begin(), keyIt)); + FailureOr sliced = + getPackedSliceForRunIndex(state, materializedClass, *packed, run.fragmentType, flattenedIndex, (*packed).getLoc()); + if (failed(sliced)) return std::nullopt; - record(key, classId, *sliced); return *sliced; } @@ -1360,7 +2179,6 @@ std::optional AvailableValueStore::lookup(MaterializerState& state, Produ auto valueIt = classValues.find(classId); if (valueIt == classValues.end()) continue; - std::optional slice = 
extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key); if (!slice) @@ -1515,6 +2333,17 @@ Value createIndexedIndexValue( return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, std::nullopt, true); } +OpFoldResult createIndexedOrStaticIndex(MaterializerState& state, + Operation* anchor, + ArrayRef values, + Value index, + Location loc) { + assert(!values.empty() && "expected at least one indexed value"); + if (allEqual(values)) + return state.rewriter.getIndexAttr(values.front()); + return createIndexedIndexValue(state, anchor, values, index, loc); +} + Value createIndexedChannelId( MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { return createIndexedIndexValue(state, anchor, ArrayRef(messages.channelIds), index, loc); @@ -1567,7 +2396,7 @@ Value createLaneIndexedIndexValue(MaterializerState& state, assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); - auto batch = cast(materializedClass.op); + auto batch = cast(materializedClass.op); auto laneArg = batch.getLaneArgument(); assert(laneArg && "expected compute_batch lane argument"); @@ -1589,6 +2418,72 @@ Value createLaneIndexedIndexValue(MaterializerState& state, return createLaneIndexedIndexValue(state, materializedClass, ArrayRef(widened), loc); } +FailureOr remapProjectionIndexLike(MaterializerState& state, + Operation* anchor, + OpFoldResult value, + Value sourceLaneArg, + Value mappedLaneValue, + Location loc) { + if (auto attr = dyn_cast(value)) + return value; + + Value operand = cast(value); + if (operand == sourceLaneArg) + return OpFoldResult(mappedLaneValue); + + if (matchPattern(operand, m_Constant())) + return getAsOpFoldResult(operand); + + auto affineApply = operand.getDefiningOp(); + if (!affineApply || 
affineApply.getAffineMap().getNumResults() != 1) + return failure(); + + SmallVector remappedOperands; + remappedOperands.reserve(affineApply.getMapOperands().size()); + for (Value mapOperand : affineApply.getMapOperands()) { + FailureOr remapped = + remapProjectionIndexLike(state, anchor, OpFoldResult(mapOperand), sourceLaneArg, mappedLaneValue, loc); + if (failed(remapped)) + return failure(); + remappedOperands.push_back(getValueOrCreateConstantIndexOp(state.rewriter, loc, *remapped)); + } + + return getAsOpFoldResult( + createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func)); +} + +FailureOr createProjectionLaneValueForKeys(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Location loc) { + if (!sourceClass.isBatch) + return sourceClass.op->emitError("projection lane mapping expects a batch materialized class"); + + auto batch = cast(sourceClass.op); + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("missing lane argument for projected batch host publication"); + + if (keys.size() == 1) { + if (keys.front().instance.laneCount != 1) + return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); + return getOrCreateIndexConstant(state.constantFolder, sourceClass.op, keys.front().instance.laneStart); + } + + if (keys.size() != sourceClass.cpus.size()) + return batch.emitOpError("projected batch host publication expected one producer key per materialized batch lane"); + + SmallVector sourceLanes; + sourceLanes.reserve(keys.size()); + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); + sourceLanes.push_back(key.instance.laneStart); + } + + return createIndexedIndexValue(state, sourceClass.op, sourceLanes, *laneArg, loc, std::nullopt, true); +} + FailureOr> getPeerLogicalInstances(MaterializerState& 
state, const MaterializedClass& materializedClass, SlotId logicalSlot) { SmallVector peers; @@ -1610,7 +2505,7 @@ Value createOriginalLaneValue(MaterializerState& state, if (!materializedClass.isBatch) return getOrCreateIndexConstant(state.constantFolder, materializedClass.op, peers.front().laneStart); - auto batch = cast(materializedClass.op); + auto batch = cast(materializedClass.op); auto laneArg = batch.getLaneArgument(); assert(laneArg && "expected materialized compute_batch lane argument"); @@ -1647,6 +2542,99 @@ bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) return false; } +bool hasRealComputeConsumer(Value value, const DenseSet& oldComputeOps) { + SmallVector worklist {value}; + DenseSet visited; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current).second) + continue; + + for (OpOperand& use : current.getUses()) { + Operation* owner = use.getOwner(); + if (isInsideOldCompute(owner, oldComputeOps)) + continue; + if (isa(owner)) { + for (Value result : owner->getResults()) + worklist.push_back(result); + continue; + } + if (isa(owner)) + continue; + return true; + } + } + + return false; +} + +FailureOr +getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); + +bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps) { + auto batch = dyn_cast_or_null(output.getDefiningOp()); + if (!batch || batch.getNumResults() == 0) + return false; + if (!hasLiveExternalUse(output, oldComputeOps)) + return false; + return !hasRealComputeConsumer(output, oldComputeOps); +} + +bool isProjectedTerminalBatchHostOutput(Value output, const DenseSet& oldComputeOps) { + if (!isTerminalHostBatchOutput(output, oldComputeOps)) + return false; + + auto batch = dyn_cast_or_null(output.getDefiningOp()); + auto originalResult = dyn_cast(output); + if (!batch || !originalResult) + return false; + + FailureOr projection = + getBatchResultProjectionInsert(batch, 
originalResult.getResultNumber()); + if (failed(projection)) + return false; + + return projection->getSource().getType() != output.getType(); +} + +LogicalResult emitBatchToScalarDestinationDiagnostic(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value originalOutput) { + auto diag = sourceClass.op->emitError("resultful compute_batch output would enter batch-to-scalar class fanout"); + diag << " sourceClassId=" << sourceClass.id << " sourceKind=" << (sourceClass.isBatch ? "batch" : "scalar"); + diag << " liveExternalUse=" << (hasLiveExternalUseCached(state, originalOutput) ? "true" : "false"); + diag << " terminalHostBatch=" << (isTerminalHostBatchOutput(originalOutput, state.oldComputeOps) ? "true" : "false"); + diag << " originalDef=" + << (originalOutput.getDefiningOp() ? originalOutput.getDefiningOp()->getName().getStringRef() : StringRef("")); + + bool first = true; + diag << " destinationClasses=["; + auto destIt = state.producerDestClasses.find(keys.front()); + ArrayRef destinations = destIt == state.producerDestClasses.end() ? ArrayRef {} : ArrayRef(destIt->second); + for (ClassId classId : destinations) { + if (!first) + diag << ", "; + first = false; + const MaterializedClass& destClass = state.classes[classId]; + diag << classId << ":" << (destClass.isBatch ? 
"batch" : "scalar"); + } + diag << "]"; + + diag << " producerKeys=["; + first = true; + for (ProducerKey key : keys) { + if (!first) + diag << ", "; + first = false; + diag << key.instance.op->getName().getStringRef() << ":r" << key.resultIndex << ":laneStart=" << key.instance.laneStart + << ":laneCount=" << key.instance.laneCount; + } + diag << "]"; + return failure(); +} + void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) { SmallVector& destinations = state.producerDestClasses[key]; if (!llvm::is_contained(destinations, classId)) @@ -1717,6 +2705,34 @@ bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceTyp return true; } + +bool isStaticSliceContainedIn(ArrayRef innerOffsets, + ArrayRef innerSizes, + ArrayRef outerOffsets, + ArrayRef outerSizes) { + if (innerOffsets.size() != innerSizes.size() || outerOffsets.size() != outerSizes.size() + || innerOffsets.size() != outerOffsets.size()) + return false; + + for (size_t dim = 0; dim < innerOffsets.size(); ++dim) { + if (innerSizes[dim] < 0 || outerSizes[dim] < 0) + return false; + + int64_t innerBegin = innerOffsets[dim]; + int64_t innerEnd = innerBegin + innerSizes[dim]; + int64_t outerBegin = outerOffsets[dim]; + int64_t outerEnd = outerBegin + outerSizes[dim]; + if (innerBegin < outerBegin || innerEnd > outerEnd) + return false; + } + + return true; +} + +bool areAllUnitStrides(ArrayRef strides) { + return llvm::all_of(strides, [](int64_t stride) { return stride == 1; }); +} + static std::optional getStaticForTripCount(scf::ForOp loop) { std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); std::optional upperBound = matchConstantIndexValue(loop.getUpperBound()); @@ -2127,6 +3143,9 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { getProjectedInputSliceMatch(state, batch, static_cast(inputIndex)); if (!match) continue; + if (!isProjectedInputSliceCompatibleWithProducerFragments( + batch, *match, producer, 
logicalConsumer.laneStart)) + continue; PendingProjectedTransferDescriptor& descriptor = pending[producer][targetClassId]; if (descriptor.fragmentOffsetsByLane.empty()) { @@ -2210,6 +3229,15 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { if (payloadFragmentCount == 0) continue; + // Batch-target projected replacements currently select fragments with the + // local materialization-run slot index. That is only unambiguous when each + // target lane receives one projected fragment. Multi-fragment payloads + // need an explicit producer-key to payload-slot mapping; otherwise two + // independently materialized runs can both select fragment 0 from the same + // packed receive and duplicate rows. + if (payloadFragmentCount != 1) + continue; + bool uniform = true; for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) { if (laneFragments.size() != payloadFragmentCount) { @@ -2331,10 +3359,609 @@ ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey ke return it->second; } +std::optional getKnownMinimumIndexValue(Value value) { + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; + + if (auto blockArg = dyn_cast(value)) { + if (blockArg.getArgNumber() == 0) { + if (auto loop = dyn_cast_or_null(blockArg.getOwner()->getParentOp())) + return matchConstantIndexValue(loop.getLowerBound()); + } + return std::nullopt; + } + + if (auto add = value.getDefiningOp()) { + std::optional lhs = getKnownMinimumIndexValue(add.getLhs()); + std::optional rhs = getKnownMinimumIndexValue(add.getRhs()); + if (lhs && rhs) + return *lhs + *rhs; + return std::nullopt; + } + + if (auto mul = value.getDefiningOp()) { + std::optional lhs = getKnownMinimumIndexValue(mul.getLhs()); + std::optional rhs = getKnownMinimumIndexValue(mul.getRhs()); + if (!lhs || !rhs) + return std::nullopt; + if (*lhs >= 0 && *rhs >= 0) + return *lhs * *rhs; + return std::nullopt; + } + + auto affineApply = value.getDefiningOp(); + if 
(!affineApply || affineApply.getAffineMap().getNumResults() != 1) + return std::nullopt; + + SmallVector operands; + operands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + std::optional minimum = getKnownMinimumIndexValue(operand); + if (!minimum) + return std::nullopt; + operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *minimum)); + } + + SmallVector results; + if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) + return std::nullopt; + + auto intAttr = dyn_cast(results.front()); + if (!intAttr) + return std::nullopt; + return intAttr.getInt(); +} + +std::optional getKnownMinimumCommunicationChannelId(Operation* op) { + if (auto send = dyn_cast(op)) + return getKnownMinimumIndexValue(send.getChannelId()); + if (auto receive = dyn_cast(op)) + return getKnownMinimumIndexValue(receive.getChannelId()); + + std::optional minimum; + op->walk([&](Operation* nested) { + if (nested == op) + return; + std::optional nestedMinimum = getKnownMinimumCommunicationChannelId(nested); + if (!nestedMinimum) + return; + if (!minimum || *nestedMinimum < *minimum) + minimum = *nestedMinimum; + }); + return minimum; +} + +void setInsertionPointForScalarReceive(MaterializerState& state, + MaterializedClass& targetClass, + int64_t channelId) { + assert(!targetClass.isBatch && "scalar receive ordering expects a scalar target class"); + + for (Operation& op : *targetClass.body) { + if (op.hasTrait()) + break; + + std::optional existingChannel = getKnownMinimumCommunicationChannelId(&op); + if (existingChannel && *existingChannel > channelId) { + state.rewriter.setInsertionPoint(&op); + return; + } + } + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); +} + // ----------------------------------------------------------------------------- // Communication materialization helpers. 
// ----------------------------------------------------------------------------- +constexpr const char* kRaptorMinChannelIdAttr = "raptor.min_channel_id"; +constexpr const char* kRaptorMaterializerAttr = "raptor.materializer"; +constexpr const char* kRaptorCommTraceIdAttr = "raptor.comm_trace_id"; +constexpr const char* kRaptorCommTraceKindAttr = "raptor.comm_trace_kind"; +constexpr const char* kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase"; +constexpr const char* kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id"; +constexpr const char* kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind"; +constexpr const char* kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal"; +constexpr const char* kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload"; +constexpr const char* kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages"; +constexpr const char* kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op"; +constexpr const char* kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op"; + +int64_t getMinimumChannelId(ArrayRef channelIds) { + assert(!channelIds.empty() && "expected at least one channel id"); + int64_t minChannelId = channelIds.front(); + for (int64_t channelId : channelIds.drop_front()) + if (channelId < minChannelId) + minChannelId = channelId; + return minChannelId; +} + +SmallVector getScalarSendChannelOrder(const MessageVector& messages) { + SmallVector order; + order.reserve(messages.size()); + for (size_t i = 0, e = messages.size(); i < e; ++i) + order.push_back(i); + + llvm::sort(order, [&](size_t lhs, size_t rhs) { + if (messages.channelIds[lhs] != messages.channelIds[rhs]) + return messages.channelIds[lhs] < messages.channelIds[rhs]; + if (messages.sourceCoreIds[lhs] != messages.sourceCoreIds[rhs]) + return messages.sourceCoreIds[lhs] < messages.sourceCoreIds[rhs]; + return messages.targetCoreIds[lhs] < messages.targetCoreIds[rhs]; + }); + return order; +} + +MessageVector reorderMessages(const 
MessageVector& messages, ArrayRef order) { + MessageVector reordered; + reordered.channelIds.reserve(messages.size()); + reordered.sourceCoreIds.reserve(messages.size()); + reordered.targetCoreIds.reserve(messages.size()); + for (size_t index : order) + reordered.append(messages.channelIds[index], messages.sourceCoreIds[index], messages.targetCoreIds[index]); + return reordered; +} + +MessageVector reorderScalarSendMessagesByChannel(const MessageVector& messages) { + return reorderMessages(messages, getScalarSendChannelOrder(messages)); +} + +ProjectedTransferDescriptor reorderProjectedDescriptorByMessageOrder(const ProjectedTransferDescriptor& descriptor, + ArrayRef order) { + ProjectedTransferDescriptor reordered = descriptor; + size_t payloadFragmentCount = static_cast(descriptor.layout.payloadFragmentCount); + reordered.fragmentOffsets.clear(); + reordered.fragmentOffsets.reserve(descriptor.fragmentOffsets.size()); + for (size_t messageIndex : order) { + size_t offset = messageIndex * payloadFragmentCount; + for (size_t fragmentIndex = 0; fragmentIndex < payloadFragmentCount; ++fragmentIndex) + reordered.fragmentOffsets.push_back(descriptor.fragmentOffsets[offset + fragmentIndex]); + } + reordered.fragmentOffsetsByDim.clear(); + return reordered; +} + + +Operation* getPayloadDefiningOpInClassBlock(Value payload, MaterializedClass& materializedClass) { + Operation* definingOp = payload.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != materializedClass.body) + return nullptr; + return definingOp; +} + +Operation* findScalarCommunicationInsertionPoint(MaterializedClass& materializedClass, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + Operation* terminator = materializedClass.body->getTerminator(); + bool afterLowerBound = lowerBound == nullptr; + + for (Operation& op : *materializedClass.body) { + if (&op == terminator) + break; + + if (!afterLowerBound) { + if (&op == lowerBound) + afterLowerBound = true; + continue; + } + + if 
(&op == lowerBound) + continue; + + auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); + if (existingMinChannel && existingMinChannel.getInt() > minChannelId) + return &op; + } + + return terminator; +} + +void setInsertionPointForScalarCommunication(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + state.rewriter.setInsertionPoint( + findScalarCommunicationInsertionPoint(materializedClass, minChannelId, lowerBound)); +} + +constexpr const char kRaptorCommOrderAttr[] = "raptor.comm_order"; + +int64_t computeBlockingCommunicationOrderKey(int32_t sourceCoreId, int32_t targetCoreId, int64_t channelId) { + int64_t lowCore = std::min(sourceCoreId, targetCoreId); + int64_t highCore = std::max(sourceCoreId, targetCoreId); + int64_t directionPhase = sourceCoreId <= targetCoreId ? 0 : 1; + return (((lowCore * 1000000LL + highCore) * 2LL + directionPhase) * 1000000000LL) + channelId; +} + +int64_t getMinimumBlockingCommunicationOrderKey(const MessageVector& messages) { + assert(!messages.empty() && "expected at least one message"); + int64_t best = computeBlockingCommunicationOrderKey( + messages.sourceCoreIds.front(), messages.targetCoreIds.front(), messages.channelIds.front()); + for (size_t index = 1, end = messages.size(); index < end; ++index) { + best = std::min(best, computeBlockingCommunicationOrderKey( + messages.sourceCoreIds[index], messages.targetCoreIds[index], messages.channelIds[index])); + } + return best; +} + +Operation* findScalarCommunicationInsertionPointByOrder(MaterializedClass& materializedClass, + int64_t orderKey, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + Operation* terminator = materializedClass.body->getTerminator(); + bool afterLowerBound = lowerBound == nullptr; + + for (Operation& op : *materializedClass.body) { + if (&op == terminator) + break; + + if (!afterLowerBound) { + if (&op == lowerBound) + afterLowerBound = true; + 
continue; + } + + if (&op == lowerBound) + continue; + + if (auto existingOrder = op.getAttrOfType(kRaptorCommOrderAttr)) { + if (existingOrder.getInt() > orderKey) + return &op; + continue; + } + + auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); + if (existingMinChannel && existingMinChannel.getInt() > minChannelId) + return &op; + } + + return terminator; +} + +void setInsertionPointForScalarCommunicationOrder(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t orderKey, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + if (!pimMaterializeScalarFanoutGlobalOrder) { + setInsertionPointForScalarCommunication(state, materializedClass, minChannelId, lowerBound); + return; + } + + state.rewriter.setInsertionPoint( + findScalarCommunicationInsertionPointByOrder(materializedClass, orderKey, minChannelId, lowerBound)); +} + +void markScalarCommunication(Operation* op, int64_t minChannelId, StringRef materializer = StringRef()) { + if (!op) + return; + op->setAttr(kRaptorMinChannelIdAttr, + IntegerAttr::get(IndexType::get(op->getContext()), minChannelId)); + if (!materializer.empty()) + op->setAttr(kRaptorMaterializerAttr, StringAttr::get(op->getContext(), materializer)); +} + +void markScalarCommunicationOrder(Operation* op, int64_t orderKey) { + if (!op) + return; + op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(IndexType::get(op->getContext()), orderKey)); +} + +std::optional getOperationOrdinalInBlock(Operation* op) { + if (!op || !op->getBlock()) + return std::nullopt; + + int64_t ordinal = 0; + for (Operation& candidate : *op->getBlock()) { + if (&candidate == op) + return ordinal; + ++ordinal; + } + return std::nullopt; +} + +std::string formatOperationForTrace(Operation* op) { + if (!op) + return ""; + + std::string text; + llvm::raw_string_ostream os(text); + os << op->getName().getStringRef(); + if (auto ordinal = getOperationOrdinalInBlock(op)) + os << "@" << *ordinal; + return os.str(); +} + 
+std::string formatValueForTrace(Value value, Block* localBody) { + if (!value) + return ""; + + std::string text; + llvm::raw_string_ostream os(text); + if (auto arg = dyn_cast(value)) { + os << "block_arg#" << arg.getArgNumber(); + return os.str(); + } + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) { + os << "external"; + return os.str(); + } + + os << definingOp->getName().getStringRef(); + if (definingOp->getBlock() == localBody) { + if (auto ordinal = getOperationOrdinalInBlock(definingOp)) + os << "@" << *ordinal; + } + else { + os << "@external-block"; + } + return os.str(); +} + +std::string formatClassForTrace(const MaterializedClass& materializedClass) { + std::string text; + llvm::raw_string_ostream os(text); + os << (materializedClass.isBatch ? "batch" : "scalar") << " class " << materializedClass.id << " cpus=["; + for (auto [index, cpu] : llvm::enumerate(materializedClass.cpus)) { + if (index) + os << ","; + os << cpu; + } + os << "]"; + return os.str(); +} + +std::string formatMessagesForTrace(const MessageVector& messages, unsigned maxMessages = 8) { + std::string text; + llvm::raw_string_ostream os(text); + os << "count=" << messages.size() << " ["; + unsigned limit = std::min(maxMessages, messages.size()); + for (unsigned index = 0; index < limit; ++index) { + if (index) + os << "; "; + os << "c" << messages.channelIds[index] << ":" << messages.sourceCoreIds[index] + << "->" << messages.targetCoreIds[index]; + } + if (messages.size() > limit) + os << "; ..."; + os << "]"; + return os.str(); +} + +void annotateCommunicationMaterialization(MaterializerState& state, + MaterializedClass& materializedClass, + Operation* op, + StringRef kind, + StringRef materializer, + StringRef phase, + std::optional minChannelId, + std::optional orderKey, + Value payload = Value(), + const MessageVector* messages = nullptr) { + if (!op) + return; + + MLIRContext* context = op->getContext(); + int64_t traceId = state.nextCommunicationTraceId++; 
+ auto indexType = IndexType::get(context); + op->setAttr(kRaptorCommTraceIdAttr, IntegerAttr::get(indexType, traceId)); + op->setAttr(kRaptorCommTraceKindAttr, StringAttr::get(context, kind)); + op->setAttr(kRaptorCommTracePhaseAttr, StringAttr::get(context, phase)); + op->setAttr(kRaptorCommTraceClassIdAttr, IntegerAttr::get(indexType, materializedClass.id)); + op->setAttr(kRaptorCommTraceClassKindAttr, + StringAttr::get(context, materializedClass.isBatch ? "batch" : "scalar")); + if (!materializer.empty()) + op->setAttr(kRaptorMaterializerAttr, StringAttr::get(context, materializer)); + if (minChannelId) + op->setAttr(kRaptorMinChannelIdAttr, IntegerAttr::get(indexType, *minChannelId)); + if (orderKey) + op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(indexType, *orderKey)); + if (auto ordinal = getOperationOrdinalInBlock(op)) + op->setAttr(kRaptorCommTraceBlockOrdinalAttr, IntegerAttr::get(indexType, *ordinal)); + op->setAttr(kRaptorCommTracePayloadAttr, + StringAttr::get(context, formatValueForTrace(payload, materializedClass.body))); + if (messages) + op->setAttr(kRaptorCommTraceMessagesAttr, StringAttr::get(context, formatMessagesForTrace(*messages))); + + Operation* prev = op->getPrevNode(); + Operation* next = op->getNextNode(); + op->setAttr(kRaptorCommTracePrevOpAttr, StringAttr::get(context, formatOperationForTrace(prev))); + op->setAttr(kRaptorCommTraceNextOpAttr, StringAttr::get(context, formatOperationForTrace(next))); + + if (!pimTraceCommunicationMaterialization) + return; + + llvm::errs() << "[raptor:comm-materializer] #" << traceId << " " << kind + << " via " << materializer << " phase=" << phase << " " + << formatClassForTrace(materializedClass); + if (minChannelId) + llvm::errs() << " min_channel=" << *minChannelId; + if (orderKey) + llvm::errs() << " order=" << *orderKey; + if (auto ordinal = getOperationOrdinalInBlock(op)) + llvm::errs() << " block_ordinal=" << *ordinal; + llvm::errs() << " payload=" << formatValueForTrace(payload, 
materializedClass.body); + if (messages) + llvm::errs() << " messages=" << formatMessagesForTrace(*messages); + llvm::errs() << " prev=" << formatOperationForTrace(prev) + << " next=" << formatOperationForTrace(next) << "\n"; +} + +void setInsertionPointForEarlyCommunication(MaterializerState& state, MaterializedClass& materializedClass) { + auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); + if (lateIt != state.firstLateCommunicationOps.end() && lateIt->second && lateIt->second->getBlock()) { + state.rewriter.setInsertionPoint(lateIt->second); + return; + } + + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); +} + +void setInsertionPointForLateCommunication(MaterializerState& state, MaterializedClass& materializedClass) { + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); +} + + +Operation* findLateScalarCommunicationInsertionPoint(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t minChannelId) { + Operation* terminator = materializedClass.body->getTerminator(); + auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); + Operation* firstLate = lateIt == state.firstLateCommunicationOps.end() ? 
nullptr : lateIt->second; + if (!firstLate || firstLate->getBlock() != materializedClass.body) + return terminator; + + bool inLateRegion = false; + for (Operation& op : *materializedClass.body) { + if (&op == terminator) + break; + + if (!inLateRegion) { + if (&op == firstLate) + inLateRegion = true; + else + continue; + } + + auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); + if (existingMinChannel && existingMinChannel.getInt() > minChannelId) + return &op; + } + + return terminator; +} + +void setInsertionPointForLateScalarCommunication(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t minChannelId) { + state.rewriter.setInsertionPoint( + findLateScalarCommunicationInsertionPoint(state, materializedClass, minChannelId)); +} + +void rememberLateCommunicationOp(MaterializerState& state, MaterializedClass& materializedClass, Operation* op) { + if (!op || op->getBlock() != materializedClass.body) + return; + + Operation*& firstLate = state.firstLateCommunicationOps[materializedClass.id]; + if (!firstLate || firstLate->getBlock() != materializedClass.body || op->isBeforeInBlock(firstLate)) + firstLate = op; +} + + + +constexpr const char kMinCommunicationChannelIdAttr[] = "raptor.min_channel_id"; + +std::optional getConstantIndexValue(Value value) { + APInt constant; + if (matchPattern(value, m_ConstantInt(&constant))) + return constant.getSExtValue(); + return std::nullopt; +} + +std::optional getCommunicationChannelId(Operation& op) { + if (auto attr = op.getAttrOfType(kMinCommunicationChannelIdAttr)) + return attr.getInt(); + + if (auto send = dyn_cast(&op)) + return getConstantIndexValue(send.getChannelId()); + if (auto receive = dyn_cast(&op)) + return getConstantIndexValue(receive.getChannelId()); + + return std::nullopt; +} + +int64_t getMinimumCommunicationChannelId(const MessageVector& messages) { + assert(!messages.empty() && "expected at least one message"); + return 
*std::min_element(messages.channelIds.begin(), messages.channelIds.end()); +} + +void markCommunicationChannelId(Operation* op, int64_t channelId) { + if (!op) + return; + op->setAttr(kMinCommunicationChannelIdAttr, + IntegerAttr::get(IntegerType::get(op->getContext(), 64), channelId)); +} + +Operation* getSameBlockDefiningOp(Value value, Block* block) { + Operation* definingOp = value.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != block) + return nullptr; + return definingOp; +} + + +bool valueDependsOnChannelReceive(Value root) { + SmallVector worklist; + DenseSet visitedValues; + DenseSet visitedOps; + worklist.push_back(root); + + auto visitOperand = [&](Value value) { + if (value && visitedValues.insert(value).second) + worklist.push_back(value); + }; + + while (!worklist.empty()) { + Value value = worklist.pop_back_val(); + Operation* definingOp = value.getDefiningOp(); + if (!definingOp || !visitedOps.insert(definingOp).second) + continue; + + if (isa(definingOp)) + return true; + + for (Value operand : definingOp->getOperands()) + visitOperand(operand); + + for (Region& region : definingOp->getRegions()) { + for (Block& block : region) { + for (Operation& nested : block) { + for (Value operand : nested.getOperands()) + visitOperand(operand); + } + } + } + } + + return false; +} + +bool shouldDelayScalarSendUntilAfterReceives(Value payload, int32_t sourceCoreId, int32_t targetCoreId) { + if (sourceCoreId <= targetCoreId) + return false; + return valueDependsOnChannelReceive(payload); +} + +void partitionScalarMessagesByReceiveDependency(Value payload, + const MessageVector& messages, + MessageVector& earlyMessages, + MessageVector& lateMessages) { + for (size_t i = 0, e = messages.size(); i < e; ++i) { + MessageVector& bucket = shouldDelayScalarSendUntilAfterReceives( + payload, messages.sourceCoreIds[i], messages.targetCoreIds[i]) + ? 
lateMessages + : earlyMessages; + bucket.append(messages.channelIds[i], messages.sourceCoreIds[i], messages.targetCoreIds[i]); + } +} + +void setInsertionPointForScalarSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + int64_t minChannelId, + bool late) { + if (late) { + setInsertionPointForLateScalarCommunication(state, sourceClass, minChannelId); + return; + } + + setInsertionPointForScalarCommunication( + state, sourceClass, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); +} + + void appendScalarSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -2344,24 +3971,43 @@ void appendScalarSend(MaterializerState& state, Location loc) { assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + bool late = shouldDelayScalarSendUntilAfterReceives(payload, sourceCoreId, targetCoreId); + int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); + if (pimMaterializeScalarFanoutGlobalOrder) + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, orderKey, channelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + else + setInsertionPointForScalarSend(state, sourceClass, payload, channelId, late); Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId); Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId); Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId); - SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); + auto send = SpatChannelSendOp::create( + state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); + markScalarCommunication(send.getOperation(), channelId, "appendScalarSend"); + 
markScalarCommunicationOrder(send.getOperation(), orderKey); + MessageVector traceMessages; + traceMessages.append(channelId, sourceCoreId, targetCoreId); + annotateCommunicationMaterialization(state, + sourceClass, + send.getOperation(), + "send", + "appendScalarSend", + late ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), + channelId, + orderKey, + payload, + &traceMessages); + if (late && !pimMaterializeScalarFanoutGlobalOrder) + rememberLateCommunicationOp(state, sourceClass, send.getOperation()); } -LogicalResult appendScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); - assert(messages.size() > 1 && "send loop is only useful for multiple sends"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - +LogicalResult emitScalarSendLoopAtInsertionPoint(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + int64_t minChannelId, + int64_t orderKey, + Location loc) { Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); @@ -2383,27 +4029,69 @@ LogicalResult appendScalarSendLoop(MaterializerState& state, }); if (failed(sendLoop)) return failure(); + markScalarCommunication(sendLoop->loop.getOperation(), minChannelId, "appendScalarSendLoop"); + markScalarCommunicationOrder(sendLoop->loop.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + sourceClass, + sendLoop->loop.getOperation(), + "send-loop", + "appendScalarSendLoop", + "loop", + minChannelId, + orderKey, + payload, + &messages); return success(); } +LogicalResult 
appendScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); + assert(messages.size() > 1 && "send loop is only useful for multiple sends"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + + MessageVector orderedMessages = reorderScalarSendMessagesByChannel(messages); + if (pimMaterializeScalarFanoutGlobalOrder) { + for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) + appendScalarSend(state, + sourceClass, + payload, + orderedMessages.channelIds[index], + orderedMessages.sourceCoreIds[index], + orderedMessages.targetCoreIds[index], + loc); + return success(); + } + + int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); + int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + return emitScalarSendLoopAtInsertionPoint(state, sourceClass, payload, orderedMessages, minChannelId, orderKey, loc); +} + + FailureOr buildProjectedPackedPayload(MaterializerState& state, - Operation* anchor, + MaterializedClass& targetClass, Value fullPayload, const ProjectedTransferDescriptor& descriptor, Value messageIndex, Location loc) { - if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) return failure(); if (descriptor.layout.payloadFragmentCount == 1) - return anchor->emitError("projected packed payload builder expects a packed payload"); + return targetClass.op->emitError("projected packed payload builder expects a packed payload"); Value init = tensor::EmptyOp::create( state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) 
.getResult(); - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.layout.payloadFragmentCount); - Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); auto loop = buildNormalizedScfFor( state.rewriter, @@ -2415,19 +4103,30 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value acc = iterArgs.front(); Value payloadFragmentCount = - getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.layout.payloadFragmentCount); - Value flatBase = arith::MulIOp::create(state.rewriter, loc, messageIndex, payloadFragmentCount).getResult(); + getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); + FailureOr localMessageIndex = rematerializeIndexValueInClass(state, targetClass, messageIndex, loc); + if (failed(localMessageIndex)) + return failure(); + Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localMessageIndex, payloadFragmentCount).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - SmallVector fragmentOffsets = - buildProjectedFragmentOffsets(state, anchor, descriptor, flatIndex, loc); - Value fragment = - createStaticExtractSlice(state, loc, fullPayload, fragmentOffsets, descriptor.layout.fragmentShape); + FailureOr> fragmentOffsets = + buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc); + if (failed(fragmentOffsets)) + return failure(); + FailureOr fragment = 
createStaticExtractSliceInClass( + state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); + if (failed(fragment)) + return failure(); - Value packedOffset = - scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); - Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset); - yielded.push_back(next); + FailureOr packedOffset = scaleIndexByDim0SizeInClass( + state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); + if (failed(packedOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, *fragment, acc, *packedOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); return success(); }); if (failed(loop)) @@ -2436,21 +4135,24 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, } FailureOr buildProjectedPayloadForMessage(MaterializerState& state, - Operation* anchor, + MaterializedClass& targetClass, Value fullPayload, const ProjectedTransferDescriptor& descriptor, Value messageIndex, Location loc) { - if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) return failure(); if (descriptor.layout.payloadFragmentCount == 1) { - SmallVector fragmentOffsets = - buildProjectedFragmentOffsets(state, anchor, descriptor, messageIndex, loc); - return createStaticExtractSlice(state, loc, fullPayload, fragmentOffsets, descriptor.layout.fragmentShape); + FailureOr> fragmentOffsets = + buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, messageIndex, loc); + if (failed(fragmentOffsets)) + return failure(); + return createStaticExtractSliceInClass( + state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); } - return buildProjectedPackedPayload(state, anchor, fullPayload, descriptor, messageIndex, loc); + return 
buildProjectedPackedPayload(state, targetClass, fullPayload, descriptor, messageIndex, loc); } LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, @@ -2461,27 +4163,59 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, Location loc) { assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - if (failed(verifyProjectedSendDescriptor(sourceClass.op, descriptor, messages))) + + SmallVector messageOrder = getScalarSendChannelOrder(messages); + MessageVector orderedMessages = reorderMessages(messages, messageOrder); + ProjectedTransferDescriptor orderedDescriptor = reorderProjectedDescriptorByMessageOrder(descriptor, messageOrder); + if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, orderedDescriptor))) + return failure(); + if (failed(verifyProjectedSendDescriptor(sourceClass.op, orderedDescriptor, orderedMessages))) return failure(); - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); + int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - if (messages.size() == 1) { - Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); - Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); - Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); - Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass.op, payload, descriptor, messageIndex, loc); - if 
(failed(sendPayload)) - return failure(); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); + if (orderedMessages.size() == 1 || pimMaterializeScalarFanoutGlobalOrder) { + for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) { + int64_t channel = orderedMessages.channelIds[index]; + int32_t sourceCore = orderedMessages.sourceCoreIds[index]; + int32_t targetCore = orderedMessages.targetCoreIds[index]; + int64_t localOrderKey = computeBlockingCommunicationOrderKey(sourceCore, targetCore, channel); + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, localOrderKey, channel, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channel); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCore); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCore); + Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(index)); + FailureOr sendPayload = + buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, messageIndex, loc); + if (failed(sendPayload)) + return failure(); + auto send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); + markScalarCommunication(send.getOperation(), channel, "appendProjectedScalarSendLoop.single"); + markScalarCommunicationOrder(send.getOperation(), localOrderKey); + MessageVector traceMessages; + traceMessages.append(channel, sourceCore, targetCore); + annotateCommunicationMaterialization(state, + sourceClass, + send.getOperation(), + "send", + "appendProjectedScalarSendLoop.single", + "projected-single", + channel, + localOrderKey, + *sendPayload, + &traceMessages); + } return success(); } Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 
0); Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(orderedMessages.size())); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); auto projectedSendLoop = buildNormalizedScfFor( @@ -2492,11 +4226,11 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, step, ValueRange {}, [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); + Value channelId = createIndexedChannelId(state, sourceClass.op, orderedMessages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, orderedMessages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, orderedMessages, index, loc); FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass.op, payload, descriptor, index, loc); + buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, index, loc); if (failed(sendPayload)) return failure(); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); @@ -2504,9 +4238,22 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, }); if (failed(projectedSendLoop)) return failure(); + markScalarCommunication(projectedSendLoop->loop.getOperation(), minChannelId, "appendProjectedScalarSendLoop.loop"); + markScalarCommunicationOrder(projectedSendLoop->loop.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + sourceClass, + projectedSendLoop->loop.getOperation(), + "send-loop", + "appendProjectedScalarSendLoop.loop", + "projected-loop", + 
minChannelId, + orderKey, + payload, + &orderedMessages); return success(); } + LogicalResult appendSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -2521,7 +4268,21 @@ LogicalResult appendSend(MaterializerState& state, Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + auto send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + int64_t minChannelId = getMinimumChannelId(messages.channelIds); + int64_t orderKey = getMinimumBlockingCommunicationOrderKey(messages); + markScalarCommunication(send.getOperation(), minChannelId, "appendSend.batch"); + markScalarCommunicationOrder(send.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + sourceClass, + send.getOperation(), + "send", + "appendSend.batch", + "batch-lane-indexed", + minChannelId, + orderKey, + payload, + &messages); return success(); } @@ -2545,29 +4306,74 @@ Value appendScalarReceive(MaterializerState& state, int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId, - Location loc) { + Location loc, + bool lateReceive = false) { assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); + if (lateReceive) + setInsertionPointForLateScalarCommunication(state, targetClass, channelId); + else + setInsertionPointForScalarCommunicationOrder(state, targetClass, orderKey, channelId); Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); 
Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); - return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue) - .getOutput(); + auto receive = SpatChannelReceiveOp::create( + state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue); + markScalarCommunication(receive.getOperation(), channelId, + lateReceive ? "appendScalarReceive.late" : "appendScalarReceive"); + markScalarCommunicationOrder(receive.getOperation(), orderKey); + MessageVector traceMessages; + traceMessages.append(channelId, sourceCoreId, targetCoreId); + annotateCommunicationMaterialization(state, + targetClass, + receive.getOperation(), + "receive", + lateReceive ? "appendScalarReceive.late" : "appendScalarReceive", + lateReceive ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), + channelId, + orderKey, + Value(), + &traceMessages); + return receive.getOutput(); } + Value appendReceive( - MaterializerState& state, MaterializedClass& targetClass, Type type, const MessageVector& messages, Location loc) { + MaterializerState& state, + MaterializedClass& targetClass, + Type type, + const MessageVector& messages, + Location loc, + bool lateReceive = false) { assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); assert(!messages.empty() && "expected at least one receive"); - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + if (lateReceive) + setInsertionPointForLateScalarCommunication(state, targetClass, getMinimumChannelId(messages.channelIds)); + else + setInsertionPointForEarlyCommunication(state, targetClass); if (targetClass.isBatch) { Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, 
targetClass, messages.sourceCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); - return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput(); + auto receive = SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId); + int64_t minChannelId = getMinimumChannelId(messages.channelIds); + int64_t orderKey = getMinimumBlockingCommunicationOrderKey(messages); + markScalarCommunication(receive.getOperation(), minChannelId, "appendReceive.batch"); + markScalarCommunicationOrder(receive.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + targetClass, + receive.getOperation(), + "receive", + "appendReceive.batch", + lateReceive ? "late-batch" : "early-batch", + minChannelId, + orderKey, + Value(), + &messages); + return receive.getOutput(); } assert(messages.size() == 1 && "scalar target class can only receive one message at a time"); @@ -2577,7 +4383,141 @@ Value appendReceive( messages.channelIds.front(), messages.sourceCoreIds.front(), messages.targetCoreIds.front(), - loc); + loc, + lateReceive); +} + +Value appendScalarReceiveAtCurrentInsertionPoint(MaterializerState& state, + MaterializedClass& targetClass, + Type type, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + assert(!targetClass.isBatch && "demand scalar receive expects a scalar target class"); + + int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); + auto receive = SpatChannelReceiveOp::create( + state.rewriter, loc, type, channelIdValue, 
sourceCoreIdValue, targetCoreIdValue); + markScalarCommunication(receive.getOperation(), channelId, "appendScalarReceive.demand"); + markScalarCommunicationOrder(receive.getOperation(), orderKey); + MessageVector traceMessages; + traceMessages.append(channelId, sourceCoreId, targetCoreId); + annotateCommunicationMaterialization(state, + targetClass, + receive.getOperation(), + "receive", + "appendScalarReceive.demand", + "demand", + channelId, + orderKey, + Value(), + &traceMessages); + return receive.getOutput(); +} + +std::optional lookupPendingScalarReceiveIndex(MaterializerState& state, + ProducerKey key, + ClassId targetClassId) { + auto keyIt = state.pendingScalarReceiveLookup.find(key); + if (keyIt == state.pendingScalarReceiveLookup.end()) + return std::nullopt; + + auto classIt = keyIt->second.find(targetClassId); + if (classIt == keyIt->second.end()) + return std::nullopt; + return classIt->second; +} + +void recordPendingScalarReceive(MaterializerState& state, + ClassId targetClassId, + ArrayRef keys, + Type receiveType, + const MessageVector& messages, + Location loc) { + if (keys.empty()) + return; + + if (lookupPendingScalarReceiveIndex(state, keys.front(), targetClassId)) + return; + + size_t recordIndex = state.pendingScalarReceives.size(); + state.pendingScalarReceives.emplace_back(keys, targetClassId, receiveType, messages, loc); + + for (ProducerKey key : keys) + state.pendingScalarReceiveLookup[key][targetClassId] = recordIndex; +} + +FailureOr materializePendingScalarReceive(MaterializerState& state, + MaterializedClass& targetClass, + size_t recordIndex, + Location loc) { + if (recordIndex >= state.pendingScalarReceives.size()) + return targetClass.op->emitError("pending scalar receive index is out of bounds"); + + PendingScalarReceiveRecord& record = state.pendingScalarReceives[recordIndex]; + if (record.targetClassId != targetClass.id) + return targetClass.op->emitError("pending scalar receive target class mismatch"); + + if 
(record.materialized) + return record.value; + + if (targetClass.isBatch) + return targetClass.op->emitError("pending scalar receive cannot materialize into a batch class"); + if (record.messages.size() != 1) + return targetClass.op->emitError("pending scalar receive expected exactly one scalar message"); + + Location receiveLoc = loc; + Value received = appendScalarReceiveAtCurrentInsertionPoint(state, + targetClass, + record.receiveType, + record.messages.channelIds.front(), + record.messages.sourceCoreIds.front(), + record.messages.targetCoreIds.front(), + receiveLoc); + record.materialized = true; + record.value = received; + + for (ProducerKey key : record.keys) + state.availableValues.record(key, targetClass.id, received); + + return received; +} + + +LogicalResult materializePendingScalarReceivesForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey wholeBatchKey, + Location loc) { + if (targetClass.isBatch || !isWholeBatchProducerKey(wholeBatchKey)) + return success(); + + SmallVector pendingIndices; + for (auto [recordIndex, record] : llvm::enumerate(state.pendingScalarReceives)) { + if (record.targetClassId != targetClass.id || record.materialized) + continue; + + bool contributesToWholeBatch = llvm::any_of(record.keys, [&](ProducerKey fragmentKey) { + return containsProducerKey(wholeBatchKey, fragmentKey); + }); + if (contributesToWholeBatch) + pendingIndices.push_back(recordIndex); + } + + if (pendingIndices.empty()) + return success(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + for (size_t recordIndex : pendingIndices) { + FailureOr received = materializePendingScalarReceive(state, targetClass, recordIndex, loc); + if (failed(received)) + return failure(); + } + + return success(); } LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, @@ -2653,6 +4593,7 @@ struct ScalarSourceReceivePlan { Type receiveType; Operation* projectedExtractOp = nullptr; 
ProjectedFragmentLayout projectedLayout; + std::optional projectedDescriptor; }; struct ProjectedScalarSendGroup { @@ -2786,6 +4727,7 @@ FailureOr buildScalarSourceFanoutPlan(MaterializerState& receivePlan.receiveType = projectedDescriptor.payloadType; receivePlan.projectedExtractOp = projectedDescriptor.extractOp; receivePlan.projectedLayout = projectedDescriptor.layout; + receivePlan.projectedDescriptor = projectedDescriptor; auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ProjectedScalarSendGroup& group) { return hasSameProjectedSendCompatibility(group.descriptor, projectedDescriptor); @@ -2837,6 +4779,145 @@ LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, return success(); } + +struct GloballyOrderedScalarFanoutEvent { + size_t receivePlanIndex = 0; + int64_t minChannelId = 0; + int64_t orderKey = 0; + int32_t minSourceCoreId = 0; + int32_t minTargetCoreId = 0; +}; + +GloballyOrderedScalarFanoutEvent makeGloballyOrderedScalarFanoutEvent(size_t receivePlanIndex, + const ScalarSourceReceivePlan& plan) { + assert(!plan.messages.empty() && "expected a communication event with at least one message"); + GloballyOrderedScalarFanoutEvent event; + event.receivePlanIndex = receivePlanIndex; + event.minChannelId = plan.messages.channelIds.front(); + event.orderKey = getMinimumBlockingCommunicationOrderKey(plan.messages); + event.minSourceCoreId = plan.messages.sourceCoreIds.front(); + event.minTargetCoreId = plan.messages.targetCoreIds.front(); + + for (size_t index = 1, end = plan.messages.size(); index < end; ++index) { + event.minChannelId = std::min(event.minChannelId, plan.messages.channelIds[index]); + event.minSourceCoreId = std::min(event.minSourceCoreId, plan.messages.sourceCoreIds[index]); + event.minTargetCoreId = std::min(event.minTargetCoreId, plan.messages.targetCoreIds[index]); + } + + return event; +} + +SmallVector +collectGloballyOrderedScalarFanoutEvents(const ScalarSourceFanoutPlan& plan) { + SmallVector 
events; + events.reserve(plan.receivePlans.size()); + + for (auto [index, receivePlan] : llvm::enumerate(plan.receivePlans)) + if (!receivePlan.messages.empty()) + events.push_back(makeGloballyOrderedScalarFanoutEvent(index, receivePlan)); + + llvm::sort(events, [](const GloballyOrderedScalarFanoutEvent& lhs, + const GloballyOrderedScalarFanoutEvent& rhs) { + if (lhs.orderKey != rhs.orderKey) + return lhs.orderKey < rhs.orderKey; + if (lhs.minChannelId != rhs.minChannelId) + return lhs.minChannelId < rhs.minChannelId; + if (lhs.minSourceCoreId != rhs.minSourceCoreId) + return lhs.minSourceCoreId < rhs.minSourceCoreId; + if (lhs.minTargetCoreId != rhs.minTargetCoreId) + return lhs.minTargetCoreId < rhs.minTargetCoreId; + return lhs.receivePlanIndex < rhs.receivePlanIndex; + }); + + return events; +} + +LogicalResult emitGloballyOrderedScalarFanoutSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ScalarSourceReceivePlan& plan, + Location loc) { + if (plan.projectedDescriptor) + return appendProjectedScalarSendLoop(state, sourceClass, payload, *plan.projectedDescriptor, plan.messages, loc); + + return appendSend(state, sourceClass, payload, plan.messages, loc); +} + +bool isMaterializedBlockingCommunication(Operation& op) { + return isa(&op) || op.hasAttr(kRaptorMinChannelIdAttr) + || op.hasAttr(kRaptorCommOrderAttr); +} + +bool payloadIsAvailableOnlyAfterPriorCommunication(Value payload, MaterializedClass& sourceClass) { + Operation* lowerBound = getPayloadDefiningOpInClassBlock(payload, sourceClass); + if (!lowerBound) + return false; + + bool sawPriorCommunication = false; + Operation* terminator = sourceClass.body->getTerminator(); + for (Operation& op : *sourceClass.body) { + if (&op == terminator) + break; + + if (&op == lowerBound) + return sawPriorCommunication || isMaterializedBlockingCommunication(op); + + if (isMaterializedBlockingCommunication(op)) + sawPriorCommunication = true; + } + + return sawPriorCommunication; 
+} + +bool shouldPlaceMatchingScalarFanoutReceiveLate(MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages) { + if (payloadIsAvailableOnlyAfterPriorCommunication(payload, sourceClass)) + return true; + + for (size_t index = 0, end = messages.size(); index < end; ++index) + if (shouldDelayScalarSendUntilAfterReceives( + payload, messages.sourceCoreIds[index], messages.targetCoreIds[index])) + return true; + return false; +} + +LogicalResult emitGloballyOrderedScalarSourceFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + const ScalarSourceFanoutPlan& plan, + Location loc) { + SmallVector events = collectGloballyOrderedScalarFanoutEvents(plan); + + for (const GloballyOrderedScalarFanoutEvent& event : events) { + const ScalarSourceReceivePlan& planEntry = plan.receivePlans[event.receivePlanIndex]; + MaterializedClass& targetClass = state.classes[planEntry.targetClass]; + + if (failed(emitGloballyOrderedScalarFanoutSend(state, sourceClass, payload, planEntry, loc))) + return failure(); + + if (!targetClass.isBatch && !planEntry.projectedExtractOp) { + recordPendingScalarReceive(state, targetClass.id, keys, planEntry.receiveType, planEntry.messages, loc); + continue; + } + + bool lateReceive = shouldPlaceMatchingScalarFanoutReceiveLate(sourceClass, payload, planEntry.messages); + Value received = appendReceive(state, targetClass, planEntry.receiveType, planEntry.messages, loc, lateReceive); + + if (planEntry.projectedExtractOp) { + state.projectedExtractReplacements[planEntry.projectedExtractOp][planEntry.targetClass] = + ProjectedExtractReplacement {received, planEntry.projectedLayout}; + continue; + } + + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + } + + return success(); +} + LogicalResult emitScalarSourceCommunication( MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Location loc) { 
assert(!sourceClass.isBatch && "scalar-source communication expects a scalar source class"); @@ -2848,6 +4929,9 @@ LogicalResult emitScalarSourceCommunication( auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, payload); if (failed(fanoutPlan)) return failure(); + if (pimMaterializeScalarFanoutGlobalOrder) + return emitGloballyOrderedScalarSourceFanout(state, sourceClass, keys, payload, *fanoutPlan, loc); + if (failed(emitScalarSourceFanoutSends(state, sourceClass, payload, *fanoutPlan, loc))) return failure(); @@ -2869,6 +4953,112 @@ LogicalResult emitScalarSourceCommunication( return success(); } +FailureOr emitOrderedBatchToBatchCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(sourceClass.isBatch && targetClass.isBatch && "ordered batch communication expects two batch classes"); + if (failed(messages.verify(sourceClass.op))) + return failure(); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape()) + return sourceClass.op->emitError("ordered batch communication expects a static ranked tensor payload"); + + auto makeEmpty = [&](MaterializedClass& materializedClass) -> Value { + return tensor::EmptyOp::create( + state.rewriter, loc, payloadType.getShape(), payloadType.getElementType()) + .getResult(); + }; + + setInsertionPointForEarlyCommunication(state, sourceClass); + Value sendChannelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); + Value sendSourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); + Value sendTargetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); + Value sendEarlyCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sle, + sendSourceCoreId, + sendTargetCoreId) + .getResult(); + auto 
earlySendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendEarlyCond, /*withElseRegion=*/false); + state.rewriter.setInsertionPoint(earlySendIf.thenBlock()->getTerminator()); + auto earlySend = SpatChannelSendOp::create( + state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); + markScalarCommunication( + earlySend.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlySend"); + + setInsertionPointForLateCommunication(state, sourceClass); + Value sendLateCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sgt, + sendSourceCoreId, + sendTargetCoreId) + .getResult(); + auto lateSendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendLateCond, /*withElseRegion=*/false); + rememberLateCommunicationOp(state, sourceClass, lateSendIf.getOperation()); + state.rewriter.setInsertionPoint(lateSendIf.thenBlock()->getTerminator()); + auto lateSend = SpatChannelSendOp::create( + state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); + markScalarCommunication( + lateSend.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.lateSend"); + + setInsertionPointForEarlyCommunication(state, targetClass); + Value recvChannelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); + Value recvSourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); + Value recvTargetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); + Value recvEarlyCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sle, + recvSourceCoreId, + recvTargetCoreId) + .getResult(); + auto earlyReceiveIf = scf::IfOp::create( + state.rewriter, loc, TypeRange {payload.getType()}, recvEarlyCond, /*withElseRegion=*/true); + Operation* earlyThenYield = earlyReceiveIf.thenBlock()->getTerminator(); + 
state.rewriter.setInsertionPoint(earlyThenYield); + auto earlyReceive = SpatChannelReceiveOp::create( + state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); + markScalarCommunication( + earlyReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlyReceive"); + Value earlyReceived = earlyReceive.getOutput(); + state.rewriter.modifyOpInPlace(earlyThenYield, [&] { earlyThenYield->setOperands(ValueRange {earlyReceived}); }); + Operation* earlyElseYield = earlyReceiveIf.elseBlock()->getTerminator(); + state.rewriter.setInsertionPoint(earlyElseYield); + Value empty = makeEmpty(targetClass); + state.rewriter.modifyOpInPlace(earlyElseYield, [&] { earlyElseYield->setOperands(ValueRange {empty}); }); + + setInsertionPointForLateCommunication(state, targetClass); + Value recvLateCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sgt, + recvSourceCoreId, + recvTargetCoreId) + .getResult(); + auto lateReceiveIf = scf::IfOp::create( + state.rewriter, loc, TypeRange {payload.getType()}, recvLateCond, /*withElseRegion=*/true); + rememberLateCommunicationOp(state, targetClass, lateReceiveIf.getOperation()); + Operation* lateThenYield = lateReceiveIf.thenBlock()->getTerminator(); + state.rewriter.setInsertionPoint(lateThenYield); + auto lateReceive = SpatChannelReceiveOp::create( + state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); + markScalarCommunication( + lateReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.lateReceive"); + Value lateReceived = lateReceive.getOutput(); + state.rewriter.modifyOpInPlace(lateThenYield, [&] { lateThenYield->setOperands(ValueRange {lateReceived}); }); + Operation* lateElseYield = lateReceiveIf.elseBlock()->getTerminator(); + state.rewriter.modifyOpInPlace( + lateElseYield, [&] { lateElseYield->setOperands(ValueRange 
{earlyReceiveIf.getResult(0)}); }); + + return lateReceiveIf.getResult(0); +} + LogicalResult emitClassToClassCommunication(MaterializerState& state, MaterializedClass& sourceClass, MaterializedClass& targetClass, @@ -2885,11 +5075,6 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, return sourceClass.op->emitError("scalar-source communication must be emitted through the scalar fanout planner"); if (!targetClass.isBatch) { - std::optional packedKey = getContiguousProducerRangeForKeys(keys); - if (!packedKey) - return sourceClass.op->emitError( - "cannot materialize batch-to-scalar communication because source lanes are not contiguous"); - MessageVector messages; messages.channelIds.reserve(sourceClass.cpus.size()); messages.sourceCoreIds.reserve(sourceClass.cpus.size()); @@ -2936,12 +5121,13 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); } - if (failed(appendSend(state, sourceClass, payload, messages, loc))) + FailureOr received = + emitOrderedBatchToBatchCommunication(state, sourceClass, targetClass, payload, messages, loc); + if (failed(received)) return failure(); - Value received = appendReceive(state, targetClass, payload.getType(), messages, loc); for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, received); + state.availableValues.record(key, targetClass.id, *received); return success(); } @@ -2950,10 +5136,16 @@ LogicalResult setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); if (resultIt == sourceClass.hostOutputToResultIndex.end()) - return sourceClass.op->emitError("missing host result slot for materialized output"); + return sourceClass.op->emitError("missing host result slot for materialized output") + << " ownerKind=" << (sourceClass.isBatch ? 
"batch" : "scalar") + << " hostOutputs=" << sourceClass.hostOutputs.size() + << " originalDef=" << (originalOutput.getDefiningOp() ? originalOutput.getDefiningOp()->getName().getStringRef() + : StringRef("")); unsigned resultIndex = resultIt->second; - state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); + if (payload.getType() != originalOutput.getType()) + return sourceClass.op->emitError("cannot set host output from fragment payload without projection") + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); if (!sourceClass.isBatch) { auto yieldOp = dyn_cast(sourceClass.body->getTerminator()); @@ -2963,10 +5155,11 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val return sourceClass.op->emitError("host result index out of range for materialized compute"); state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); + state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); return success(); } - auto batch = cast(sourceClass.op); + auto batch = cast(sourceClass.op); auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); if (!inParallelOp) return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); @@ -2986,15 +5179,852 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); + state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); return success(); } +FailureOr +getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); + +LogicalResult emitProjectedBatchHostOutput(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value originalOutput, + Value payload, + Location loc) { + if (!sourceClass.isBatch) + 
return sourceClass.op->emitError("projected batch host publication expects a batch owner class"); + auto batch = cast(sourceClass.op); + + auto ownerIt = sourceClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == sourceClass.hostOutputToResultIndex.end()) + return sourceClass.op->emitError("missing host result slot for projected batch output"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + return sourceClass.op->emitError("projected batch host publication expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for publication"); + + auto sourceLaneArg = sourceBatch.getLaneArgument(); + if (!sourceLaneArg) + return sourceBatch.emitOpError("missing source compute_batch lane argument for host projection"); + + // The projection coordinates are part of the source batch publication. + // Build any affine/index helper ops in the source batch body, not at the + // caller's current insertion point. Otherwise a scalar host-owner body may + // accidentally capture the source scheduled_compute_batch lane argument. 
+ OpBuilder::InsertionGuard projectionGuard(state.rewriter); + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + FailureOr projectionLaneValue = createProjectionLaneValueForKeys(state, sourceClass, keys, loc); + if (failed(projectionLaneValue)) + return failure(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(projection->getMixedOffsets().size()); + sizes.reserve(projection->getMixedSizes().size()); + strides.reserve(projection->getMixedStrides().size()); + + for (OpFoldResult offset : projection->getMixedOffsets()) { + FailureOr remapped = + remapProjectionIndexLike(state, sourceClass.op, offset, *sourceLaneArg, *projectionLaneValue, loc); + if (failed(remapped)) + return sourceClass.op->emitError("failed to remap projected batch host offsets"); + offsets.push_back(*remapped); + } + for (OpFoldResult size : projection->getMixedSizes()) { + FailureOr remapped = + remapProjectionIndexLike(state, sourceClass.op, size, *sourceLaneArg, *projectionLaneValue, loc); + if (failed(remapped)) + return sourceClass.op->emitError("failed to remap projected batch host sizes"); + sizes.push_back(*remapped); + } + for (OpFoldResult stride : projection->getMixedStrides()) { + FailureOr remapped = + remapProjectionIndexLike(state, sourceClass.op, stride, *sourceLaneArg, *projectionLaneValue, loc); + if (failed(remapped)) + return sourceClass.op->emitError("failed to remap projected batch host strides"); + strides.push_back(*remapped); + } + + auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); + if (!inParallelOp) + return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); + + auto outputArg = batch.getOutputArgument(ownerIt->second); + if (!outputArg) + return batch.emitOpError("missing host output block argument for projected batch publication"); + + state.hostReplacements[originalOutput] = sourceClass.op->getResult(ownerIt->second); + 
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + tensor::ParallelInsertSliceOp::create(state.rewriter, loc, payload, *outputArg, offsets, sizes, strides); + return success(); +} + +FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane); + +FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) { + if (value == laneArg) + return static_cast(lane); + + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; + + auto affineApply = value.getDefiningOp(); + if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) + return failure(); + + SmallVector operands; + operands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr evaluated = evaluateProjectionIndexLike(operand, laneArg, lane); + if (failed(evaluated)) + return failure(); + operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *evaluated)); + } + + SmallVector results; + if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) + return failure(); + + auto intAttr = dyn_cast(results.front()); + if (!intAttr) + return failure(); + return intAttr.getInt(); +} + +FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane) { + if (auto attr = llvm::dyn_cast(value)) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return failure(); + return intAttr.getInt(); + } + return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane); +} + +FailureOr +getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { + auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); + if (!inParallel) + return failure(); + + auto firstOutputArg = batch.getOutputArgument(0); + if (!firstOutputArg) + return failure(); + + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert) + continue; 
+ + auto outputArg = dyn_cast(insert.getDest()); + if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) + continue; + + unsigned candidateIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); + if (candidateIndex == resultIndex) + return insert; + } + + return failure(); +} + +FailureOr> +evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, uint32_t lane) { + SmallVector evaluated; + evaluated.reserve(values.size()); + for (OpFoldResult value : values) { + FailureOr index = evaluateProjectionIndexLike(value, laneArg, lane); + if (failed(index)) + return failure(); + evaluated.push_back(*index); + } + return evaluated; +} + + +bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane) { + auto producerBatch = dyn_cast_or_null(producer.instance.op); + if (!producerBatch) + return true; + + FailureOr producerProjection = + getBatchResultProjectionInsert(producerBatch, producer.resultIndex); + if (failed(producerProjection)) + return true; + + std::optional producerLaneArg = producerBatch.getLaneArgument(); + std::optional consumerLaneArg = consumerBatch.getLaneArgument(); + if (!producerLaneArg || !consumerLaneArg) + return false; + + SmallVector consumerSizes(match.fragmentShape.begin(), match.fragmentShape.end()); + SmallVector loopIterationIndices(match.loops.size(), 0); + + const auto consumerSliceFitsOneProducerFragment = [&]() -> bool { + SmallVector consumerOffsets; + consumerOffsets.reserve(match.offsets.size()); + for (OpFoldResult offset : match.offsets) { + FailureOr evaluated = + evaluateProjectedOffsetValue(offset, *consumerLaneArg, consumerLane, match.loops, loopIterationIndices); + if (failed(evaluated)) + return false; + consumerOffsets.push_back(*evaluated); + } + + uint32_t producerLaneEnd = producer.instance.laneStart + producer.instance.laneCount; + for (uint32_t producerLane = 
producer.instance.laneStart; producerLane < producerLaneEnd; ++producerLane) { + FailureOr> producerOffsets = + evaluateStaticProjectionIndices(producerProjection->getMixedOffsets(), *producerLaneArg, producerLane); + FailureOr> producerSizes = + evaluateStaticProjectionIndices(producerProjection->getMixedSizes(), *producerLaneArg, producerLane); + FailureOr> producerStrides = + evaluateStaticProjectionIndices(producerProjection->getMixedStrides(), *producerLaneArg, producerLane); + if (failed(producerOffsets) || failed(producerSizes) || failed(producerStrides)) + return false; + if (!areAllUnitStrides(*producerStrides)) + return false; + if (isStaticSliceContainedIn(consumerOffsets, consumerSizes, *producerOffsets, *producerSizes)) + return true; + } + + return false; + }; + + if (match.loops.empty()) + return consumerSliceFitsOneProducerFragment(); + + const auto recurse = [&](auto&& self, size_t loopIndex) -> bool { + if (loopIndex == match.loops.size()) + return consumerSliceFitsOneProducerFragment(); + + for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { + loopIterationIndices[loopIndex] = iteration; + if (!self(self, loopIndex + 1)) + return false; + } + return true; + }; + + return recurse(recurse, 0); +} + +LogicalResult insertProjectedBatchHostFragment(MaterializerState& state, + MaterializedClass& ownerClass, + Value originalOutput, + uint32_t lane, + Value payload) { + if (ownerClass.isBatch) + return ownerClass.op->emitError("projected batch host fallback expects a scalar owner class"); + + auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == ownerClass.hostOutputToResultIndex.end()) + return ownerClass.op->emitError("missing host result slot for projected batch host fragment"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + 
return ownerClass.op->emitError("projected batch host fallback expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for materialization"); + + auto laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) + return sourceBatch.emitOpError("missing compute_batch lane argument for host projection"); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); + if (failed(offsets) || failed(sizes) || failed(strides)) + return ownerClass.op->emitError("failed to evaluate batch host projection coordinates"); + + auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); + if (!yieldOp) + return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); + + unsigned hostResultIndex = ownerIt->second; + if (hostResultIndex >= yieldOp.getNumOperands()) + return ownerClass.op->emitError("host result index out of range for projected batch host fragment"); + if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) + return ownerClass.op->emitError("projected batch host fragment expected a full host accumulator tensor") + << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() + << " outputType=" << originalOutput.getType(); + + state.rewriter.setInsertionPoint(yieldOp); + Value updated = tensor::InsertSliceOp::create(state.rewriter, + payload.getLoc(), + payload, + yieldOp.getOperand(hostResultIndex), + ValueRange {}, + ValueRange {}, + ValueRange {}, + *offsets, + *sizes, + *strides) + .getResult(); + state.rewriter.modifyOpInPlace(yieldOp, [&] { 
yieldOp->setOperand(hostResultIndex, updated); }); + state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); + return success(); +} + + +LogicalResult emitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, + MaterializedClass& ownerClass, + Value originalOutput, + ArrayRef keys, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + if (ownerClass.isBatch) + return ownerClass.op->emitError("projected batch host receive loop expects a scalar owner class"); + if (keys.empty()) + return success(); + if (keys.size() != messages.size()) + return ownerClass.op->emitError("projected batch host receive loop message metadata is inconsistent"); + + auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == ownerClass.hostOutputToResultIndex.end()) + return ownerClass.op->emitError("missing host result slot for projected batch host receive loop"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + return ownerClass.op->emitError("projected batch host receive loop expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for receive loop"); + + auto laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) + return sourceBatch.emitOpError("missing compute_batch lane argument for projected host receive loop"); + + auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); + if (!yieldOp) + return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); + + unsigned hostResultIndex = ownerIt->second; + if (hostResultIndex >= yieldOp.getNumOperands()) + return ownerClass.op->emitError("host result index out of range for 
projected batch host receive loop"); + if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) + return ownerClass.op->emitError("projected batch host receive loop expected a full host accumulator tensor") + << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() + << " outputType=" << originalOutput.getType(); + + unsigned rank = projection->getMixedOffsets().size(); + SmallVector, 4> offsetsByDim(rank); + SmallVector, 4> sizesByDim(rank); + SmallVector, 4> stridesByDim(rank); + for (ProducerKey key : keys) { + if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() + || key.instance.laneCount != 1) + return ownerClass.op->emitError("projected batch host receive loop expects one-lane fragments from one output"); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) + return ownerClass.op->emitError("failed to evaluate projected batch host receive loop coordinates"); + if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) + return ownerClass.op->emitError("projected batch host receive loop coordinate rank mismatch"); + + for (unsigned dim = 0; dim < rank; ++dim) { + offsetsByDim[dim].push_back((*offsets)[dim]); + sizesByDim[dim].push_back((*sizes)[dim]); + stridesByDim[dim].push_back((*strides)[dim]); + } + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 1); 
+ + state.rewriter.setInsertionPoint(yieldOp); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {yieldOp.getOperand(hostResultIndex)}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value channelId = createIndexedChannelId(state, ownerClass.op, messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, ownerClass.op, messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, ownerClass.op, messages, flatIndex, loc); + Value fragment = SpatChannelReceiveOp::create( + state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); + sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); + strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); + } + + Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) + .getResult(); + yielded.push_back(updated); + return success(); + }); + if (failed(loop)) + return failure(); + markScalarCommunication( + loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitProjectedBatchHostReceiveInsertLoop"); + + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); + state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); + return success(); +} + +std::optional tryEmitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, + MaterializedClass& ownerClass, + Value originalOutput, + ArrayRef keys, + Location 
loc) { + if (keys.empty()) + return success(); + + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(keys.front(), ownerClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); + for (size_t runIndex : runIndices) { + PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); + if (run.kind != PackedScalarRunKind::DeferredReceive) + continue; + SmallVector runKeys = flattenPackedScalarRunKeys(run); + if (!llvm::equal(runKeys, keys)) + continue; + return emitProjectedBatchHostReceiveInsertLoop( + state, ownerClass, originalOutput, runKeys, run.fragmentType, run.messages, loc); + } + + return std::nullopt; +} + +FailureOr getLeadingPackedFragmentType(Operation* anchor, Value payload, size_t fragmentCount) { + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) + return failure(); + if (payloadType.getDimSize(0) != static_cast(fragmentCount)) + return failure(); + + SmallVector fragmentShape(payloadType.getShape().begin(), payloadType.getShape().end()); + fragmentShape[0] = 1; + return RankedTensorType::get(fragmentShape, payloadType.getElementType(), payloadType.getEncoding()); +} + +LogicalResult emitScalarPackedProjectedHostSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "packed projected host send loop expects a scalar source"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) + return sourceClass.op->emitError("packed projected host send loop expects a static ranked payload"); + + setInsertionPointForScalarCommunication(state, sourceClass, 
getMinimumChannelId(messages.channelIds)); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); + + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(payloadType.getRank()); + sizes.reserve(payloadType.getRank()); + strides.reserve(payloadType.getRank()); + offsets.push_back(index); + sizes.push_back(state.rewriter.getIndexAttr(1)); + strides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { + offsets.push_back(state.rewriter.getIndexAttr(0)); + sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); + strides.push_back(state.rewriter.getIndexAttr(1)); + } + + Value fragment = tensor::ExtractSliceOp::create( + state.rewriter, loc, fragmentType, payload, offsets, sizes, strides) + .getResult(); + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, fragment); + return success(); + }); + if (failed(loop)) + return failure(); + markScalarCommunication( + loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitScalarPackedProjectedHostSendLoop"); + return success(); +} + +LogicalResult emitScalarPackedProjectedHostLocalInsertLoop(MaterializerState& state, + MaterializedClass& ownerClass, + ArrayRef keys, + Value payload, + 
Value originalOutput, + RankedTensorType fragmentType, + Location loc) { + if (ownerClass.isBatch) + return ownerClass.op->emitError("packed projected host local insert loop expects a scalar owner class"); + if (keys.empty()) + return success(); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) + return ownerClass.op->emitError("packed projected host local insert loop expects a static ranked payload"); + if (payloadType.getDimSize(0) != static_cast(keys.size())) + return ownerClass.op->emitError("packed projected host local insert loop payload/key count mismatch"); + + auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == ownerClass.hostOutputToResultIndex.end()) + return ownerClass.op->emitError("missing host result slot for packed projected host local insert loop"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + return ownerClass.op->emitError("packed projected host local insert loop expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for local insert loop"); + + auto laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) + return sourceBatch.emitOpError("missing compute_batch lane argument for packed projected host local insert loop"); + + auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); + if (!yieldOp) + return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); + + unsigned hostResultIndex = ownerIt->second; + if (hostResultIndex >= yieldOp.getNumOperands()) + return ownerClass.op->emitError("host result index out of range for packed projected host local 
insert loop"); + if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) + return ownerClass.op->emitError("packed projected host local insert loop expected a full host accumulator tensor") + << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() + << " outputType=" << originalOutput.getType(); + + unsigned rank = projection->getMixedOffsets().size(); + SmallVector, 4> offsetsByDim(rank); + SmallVector, 4> sizesByDim(rank); + SmallVector, 4> stridesByDim(rank); + for (ProducerKey key : keys) { + if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() + || key.instance.laneCount != 1) + return ownerClass.op->emitError("packed projected host local insert loop expects one-lane fragments from one output"); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) + return ownerClass.op->emitError("failed to evaluate packed projected host local insert loop coordinates"); + if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) + return ownerClass.op->emitError("packed projected host local insert loop coordinate rank mismatch"); + + for (unsigned dim = 0; dim < rank; ++dim) { + offsetsByDim[dim].push_back((*offsets)[dim]); + sizesByDim[dim].push_back((*sizes)[dim]); + stridesByDim[dim].push_back((*strides)[dim]); + } + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 
1); + + state.rewriter.setInsertionPoint(yieldOp); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {yieldOp.getOperand(hostResultIndex)}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + extractOffsets.reserve(payloadType.getRank()); + extractSizes.reserve(payloadType.getRank()); + extractStrides.reserve(payloadType.getRank()); + extractOffsets.push_back(flatIndex); + extractSizes.push_back(state.rewriter.getIndexAttr(1)); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { + extractOffsets.push_back(state.rewriter.getIndexAttr(0)); + extractSizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + } + + Value fragment = tensor::ExtractSliceOp::create( + state.rewriter, loc, fragmentType, payload, extractOffsets, extractSizes, extractStrides) + .getResult(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); + sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); + strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); + } + + Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) + .getResult(); + yielded.push_back(updated); + return success(); + }); + if (failed(loop)) + return failure(); + + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); + 
state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); + return success(); +} + +std::optional tryEmitScalarPackedProjectedHostPublication(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { + if (sourceClass.isBatch || keys.size() <= 1) + return std::nullopt; + + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for projected batch output"); + + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (ownerClass.isBatch) + return ownerClass.op->emitError( + "projected batch host output reached a batch owner without an explicit batch publication path"); + FailureOr fragmentType = getLeadingPackedFragmentType(sourceClass.op, payload, keys.size()); + if (failed(fragmentType)) + return std::nullopt; + + if (ownerClass.id == sourceClass.id) + return emitScalarPackedProjectedHostLocalInsertLoop( + state, ownerClass, keys, payload, originalOutput, *fragmentType, loc); + + auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); + auto targetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); + if (failed(sourceCpu) || failed(targetCpu)) + return failure(); + + MessageVector messages; + for ([[maybe_unused]] ProducerKey key : keys) + messages.append(state.nextChannelId++, *sourceCpu, *targetCpu); + + if (failed(messages.verify(sourceClass.op))) + return failure(); + + if (failed(emitScalarPackedProjectedHostSendLoop(state, sourceClass, payload, *fragmentType, messages, loc))) + return failure(); + + return emitProjectedBatchHostReceiveInsertLoop( + state, ownerClass, originalOutput, keys, *fragmentType, messages, loc); +} + +void appendPendingProjectedHostReceive(MaterializerState& state, + MaterializedClass& ownerClass, + Value 
originalOutput, + ProducerKey key, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + assert(messages.size() == 1 && "pending projected host receive records one message at a time"); + for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { + if (group.originalOutput != originalOutput || group.ownerClassId != ownerClass.id || group.fragmentType != fragmentType) + continue; + group.keys.push_back(key); + group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); + return; + } + + PendingProjectedHostReceiveGroup group { + originalOutput, + ownerClass.id, + fragmentType, + SmallVector{key}, + MessageVector{}, + loc + }; + group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); + state.pendingProjectedHostReceives.push_back(std::move(group)); +} + +LogicalResult flushPendingProjectedHostReceives(MaterializerState& state) { + for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { + if (group.ownerClassId >= state.classes.size()) + return state.func.emitError("pending projected host receive has invalid owner class"); + MaterializedClass& ownerClass = state.classes[group.ownerClassId]; + if (failed(group.messages.verify(ownerClass.op))) + return failure(); + if (group.keys.empty()) + continue; + if (failed(emitProjectedBatchHostReceiveInsertLoop( + state, ownerClass, group.originalOutput, group.keys, group.fragmentType, group.messages, group.loc))) + return failure(); + } + state.pendingProjectedHostReceives.clear(); + return success(); +} + +LogicalResult emitProjectedBatchHostFragment(MaterializerState& state, + MaterializedClass& sourceClass, + ProducerKey key, + Value payload, + Value originalOutput, + Location loc) { + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for projected batch output"); 
+ + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + Value ownerPayload = payload; + if (sourceClass.id != ownerClass.id) { + if (ownerClass.isBatch) { + return ownerClass.op->emitError( + "projected batch host fragment reached a batch owner without an explicit batch publication path"); + } + + MessageVector messages; + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); + if (failed(checkedTargetCpu)) + return failure(); + if (!sourceClass.isBatch) { + if (failed(checkedSourceCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + + if (failed(appendSend(state, sourceClass, payload, messages, loc))) + return failure(); + + auto fragmentType = dyn_cast(payload.getType()); + if (!fragmentType) + return sourceClass.op->emitError("projected terminal batch host fragment expects ranked tensor payload"); + appendPendingProjectedHostReceive(state, ownerClass, originalOutput, key, fragmentType, messages, loc); + return success(); + } + else { + ComputeInstance scheduledInstance = getScheduledChunkForLogicalInstance(state, key.instance); + auto sourceCpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); + if (sourceCpuIt == state.schedule.computeToCpuMap.end()) + return sourceClass.op->emitError("missing CPU assignment for projected batch host source"); + + auto localLaneIt = sourceClass.cpuToLane.find(sourceCpuIt->second); + if (localLaneIt == sourceClass.cpuToLane.end()) + return sourceClass.op->emitError("missing local batch lane for projected batch host source"); + + if (failed(checkedSourceCpu = getCheckedCoreId(sourceClass.op, + sourceCpuIt->second, + "projected host source core id"))) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + + auto batch = 
cast(sourceClass.op); + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("missing lane argument for projected batch host source"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value localLane = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, localLaneIt->second); + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); + Value isSourceLane = arith::CmpIOp::create(state.rewriter, loc, arith::CmpIPredicate::eq, *laneArg, localLane); + auto ifOp = scf::IfOp::create(state.rewriter, loc, TypeRange {}, isSourceLane, /*withElseRegion=*/false); + state.rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator()); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, loc); + } + } + + return insertProjectedBatchHostFragment(state, ownerClass, originalOutput, key.instance.laneStart, ownerPayload); +} + LogicalResult emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) { if (!hasLiveExternalUseCached(state, originalOutput)) return success(); - return setHostOutputValue(state, sourceClass, originalOutput, payload); + if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) + return sourceClass.op->emitError("cannot set projected terminal batch host output through the generic host path"); + + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for live external output"); + + MaterializedClass& 
ownerClass = state.classes[ownerIt->second]; + if (sourceClass.id == ownerClass.id) + return setHostOutputValue(state, ownerClass, originalOutput, payload); + + if (sourceClass.isBatch) + return sourceClass.op->emitError("batch host publication must be routed through a projection-aware or owning path"); + if (ownerClass.isBatch) + return ownerClass.op->emitError("generic host publication does not support batch host owners"); + + MessageVector messages; + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "host source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "host target core id"); + if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + + if (failed(appendSend(state, sourceClass, payload, messages, payload.getLoc()))) + return failure(); + Value ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, payload.getLoc()); + return setHostOutputValue(state, ownerClass, originalOutput, ownerPayload); } LogicalResult emitOutputFanout(MaterializerState& state, @@ -3010,6 +6040,21 @@ LogicalResult emitOutputFanout(MaterializerState& state, if (failed(emitScalarSourceCommunication(state, sourceClass, keys, payload, loc))) return failure(); + if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { + std::optional loopedHostPublication = + tryEmitScalarPackedProjectedHostPublication(state, sourceClass, keys, payload, originalOutput, loc); + if (loopedHostPublication) + return *loopedHostPublication; + + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return sourceClass.op->emitError("projected terminal batch host output expects one logical lane per fragment"); + if (failed(emitProjectedBatchHostFragment(state, sourceClass, key, payload, originalOutput, loc))) + return failure(); + } + return success(); + } + return 
emitHostCommunication(state, sourceClass, payload, originalOutput); } @@ -3017,12 +6062,62 @@ LogicalResult emitOutputFanout(MaterializerState& state, return sourceClass.op->emitError( "cannot materialize batched output whose lanes have different destination equivalence classes"); + if (auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp())) { + if (sourceBatch.getNumResults() != 0 && isTerminalHostBatchOutput(originalOutput, state.oldComputeOps)) { + for (ClassId destinationClass : getDestinationClasses(state, keys.front())) + if (!state.classes[destinationClass].isBatch) + return emitBatchToScalarDestinationDiagnostic(state, sourceClass, keys, originalOutput); + } + } + for (ClassId destinationClass : getDestinationClasses(state, keys.front())) if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); - if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + if (sourceBatch && sourceBatch.getNumResults() != 0 && hasLiveExternalUseCached(state, originalOutput)) { + if (sourceClass.hostOutputToResultIndex.contains(originalOutput)) { + if (failed(emitProjectedBatchHostOutput(state, sourceClass, keys, originalOutput, payload, loc))) + return failure(); + } + else { + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for projected batch output"); + + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (ownerClass.isBatch) + return ownerClass.op->emitError( + "projected batch host output reached a batch owner without an explicit batch publication path"); + + if (sourceClass.id != ownerClass.id + && failed(emitClassToClassCommunication(state, sourceClass, ownerClass, keys, payload, loc))) + return failure(); + + std::optional loopedHostPublication = + 
tryEmitProjectedBatchHostReceiveInsertLoop(state, ownerClass, originalOutput, keys, loc); + if (loopedHostPublication) { + if (failed(*loopedHostPublication)) + return failure(); + } + else { + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return sourceClass.op->emitError("projected batch host output expects one logical lane per fragment"); + + std::optional ownerPayload = state.availableValues.lookup(state, key, ownerClass.id); + if (!ownerPayload) + return ownerClass.op->emitError("failed to recover projected batch host fragment after communication"); + + if (failed(insertProjectedBatchHostFragment( + state, ownerClass, originalOutput, key.instance.laneStart, *ownerPayload))) + return failure(); + } + } + } + } else if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) { return failure(); + } for (ProducerKey key : keys) state.availableValues.record(key, sourceClass.id, payload); @@ -3054,6 +6149,34 @@ struct WholeBatchFragmentGroup { RankedTensorType slotPackedType; SmallVector slotIndices; SmallVector, 16> directFragments; + SmallVector redundantReceives; +}; + +enum class ProjectedWholeBatchFragmentSourceKind { + DeferredReceive, + PackedValue, + DirectValue +}; + +struct ProjectedWholeBatchDirectFragment { + Value fragment; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + +struct ProjectedWholeBatchFragmentGroup { + ProjectedWholeBatchFragmentSourceKind kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; + RankedTensorType fragmentType; + SmallVector, 4> offsetsByDim; + SmallVector, 4> sizesByDim; + SmallVector, 4> stridesByDim; + MessageVector messages; + SmallVector redundantOps; + Value packed; + RankedTensorType packedSourceType; + SmallVector packedIndices; + SmallVector directFragments; }; struct WholeBatchAssemblyPlan { @@ -3135,19 +6258,26 @@ validateWholeBatchFragmentType(RankedTensorType resultType, RankedTensorType fra // Packed run tensor assembly helpers. 
// ----------------------------------------------------------------------------- -Value insertFragmentIntoWholeBatch( - MaterializerState& state, Value fragment, Value destination, OpFoldResult firstOffset, Location loc) { - return createDim0InsertSlice(state, loc, fragment, destination, firstOffset); +FailureOr insertFragmentIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value fragment, + Value destination, + OpFoldResult firstOffset, + Location loc) { + return createDim0InsertSliceInClass(state, targetClass, loc, fragment, destination, firstOffset); } -Value extractPackedSlotForIndex(MaterializerState& state, - Operation* anchor, +FailureOr extractPackedSlotForIndex(MaterializerState& state, + MaterializedClass& targetClass, Value packed, RankedTensorType slotPackedType, Value slotIndex, Location loc) { - Value firstOffset = scaleIndexByDim0Size(state, anchor, slotIndex, slotPackedType.getDimSize(0), loc); - return createDim0ExtractSlice(state, loc, packed, firstOffset, slotPackedType.getDimSize(0)); + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0)); } SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run) { @@ -3157,17 +6287,62 @@ SmallVector flattenPackedScalarRunKeys(const PackedScalarRunVal return keys; } +bool packedScalarRunSlotsMatch(const PackedScalarRunValue& lhs, const PackedScalarRunValue& rhs) { + if (lhs.slots.size() != rhs.slots.size()) + return false; + + for (auto [lhsSlot, rhsSlot] : llvm::zip(lhs.slots, rhs.slots)) { + if (lhsSlot.keys.size() != rhsSlot.keys.size()) + return false; + if (!llvm::equal(lhsSlot.keys, rhsSlot.keys)) + return false; + } + + return true; +} + + +bool appendConstantChannelReceiveMessage(MessageVector& messages, SpatChannelReceiveOp receive) 
{ + std::optional channelId = getConstantIndexValue(receive.getChannelId()); + std::optional sourceCoreId = getConstantIndexValue(receive.getSourceCoreId()); + std::optional targetCoreId = getConstantIndexValue(receive.getTargetCoreId()); + if (!channelId || !sourceCoreId || !targetCoreId) + return false; + messages.append(*channelId, static_cast(*sourceCoreId), static_cast(*targetCoreId)); + return true; +} + +PackedScalarRunValue* findDeferredReceiveAlternativeForPackedRun(MaterializerState& state, + const MaterializedClass& targetClass, + const PackedScalarRunValue& run) { + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(run.sourceOp, run.resultIndex, targetClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); + + for (size_t runIndex : runIndices) { + PackedScalarRunValue& candidate = state.availableValues.getPackedRun(runIndex); + if (&candidate == &run || candidate.kind != PackedScalarRunKind::DeferredReceive) + continue; + if (candidate.fragmentType != run.fragmentType) + continue; + if (!packedScalarRunSlotsMatch(candidate, run)) + continue; + return &candidate; + } + + return nullptr; +} + FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, - Operation* anchor, - Operation* insertionPoint, + MaterializedClass& targetClass, Value destination, int64_t itemCount, IndexedFragmentBuilder buildFragment, IndexedInsertOffsetBuilder buildOffset, Location loc) { - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, itemCount); - Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, itemCount); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + Operation* 
insertionPoint = targetClass.body->getTerminator(); state.rewriter.setInsertionPoint(insertionPoint); auto loop = buildNormalizedScfFor( @@ -3184,7 +6359,11 @@ FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, FailureOr offset = buildOffset(flatIndex); if (failed(offset)) return failure(); - yielded.push_back(insertFragmentIntoWholeBatch(state, *fragment, iterArgs.front(), *offset, loc)); + FailureOr next = + insertFragmentIntoWholeBatch(state, targetClass, *fragment, iterArgs.front(), *offset, loc); + if (failed(next)) + return failure(); + yielded.push_back(*next); return success(); }); if (failed(loop)) @@ -3258,9 +6437,14 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& if (failed(produced) || produced->size() != 1) return failure(); - Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, run.fragmentType.getDimSize(0), loc); - Value next = createDim0InsertSlice(state, loc, produced->front(), acc, firstOffset); - yielded.push_back(next); + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); return success(); }); if (failed(loop)) @@ -3339,6 +6523,8 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex || candidateKey.instance.laneCount == 0) continue; + if (!isTensorValueLocalToMaterializedClass(record.value, targetClass)) + continue; if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount)) continue; @@ -3404,6 +6590,22 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, if (run->fragmentType.getDimSize(0) != 
plan.rowsPerLane) return failure(); + if (run->kind == PackedScalarRunKind::Materialized && run->packed + && !isTensorValueLocalToMaterializedClass(run->packed, targetClass)) { + if (PackedScalarRunValue* deferredRun = findDeferredReceiveAlternativeForPackedRun(state, targetClass, *run)) + run = deferredRun; + else { + SmallVector keys = flattenPackedScalarRunKeys(*run); + std::optional packedKey = getContiguousProducerRangeForKeys(keys); + emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, + targetClass, + "whole-batch assembly tried to reuse non-local PackedValue", + run->packed, + packedKey); + return failure(); + } + } + if (run->kind == PackedScalarRunKind::DeferredReceive) { if (failed(validatePackedScalarRunMetadata(targetClass.op, *run))) return failure(); @@ -3459,46 +6661,103 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, if (!sourceBatch || !run->packed) return failure(); - size_t slotLaneCount = run->slots.front().keys.size(); - if (slotLaneCount == 0) - return failure(); - FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slotLaneCount); - if (failed(slotPackedType)) - return failure(); - - auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { - return group.kind == WholeBatchFragmentSourceKind::PackedValue && group.fragmentType == run->fragmentType - && group.packed == run->packed && group.slotPackedType == *slotPackedType; - }); - if (groupIt == groups.end()) { - WholeBatchFragmentGroup group; - group.kind = WholeBatchFragmentSourceKind::PackedValue; - group.fragmentType = run->fragmentType; - group.packed = run->packed; - group.slotPackedType = *slotPackedType; - groups.push_back(std::move(group)); - groupIt = std::prev(groups.end()); - } + auto getOrCreatePackedValueGroup = [&](RankedTensorType slotPackedType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == 
WholeBatchFragmentSourceKind::PackedValue && group.fragmentType == run->fragmentType + && group.packed == run->packed && group.slotPackedType == slotPackedType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::PackedValue; + group.fragmentType = run->fragmentType; + group.packed = run->packed; + group.slotPackedType = slotPackedType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + size_t flattenedIndexBase = 0; for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { - std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); - if (!contiguousKey) - return failure(); - groupIt->slotIndices.push_back(slotIndex); - groupIt->outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); + std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); + if (contiguousKey) { + FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return failure(); + WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType); + group.slotIndices.push_back(slotIndex); + group.outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); + flattenedIndexBase += slot.keys.size(); + continue; + } + + WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(run->fragmentType); + for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) { + group.slotIndices.push_back(flattenedIndexBase + keyIndex); + group.outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + } + flattenedIndexBase += slot.keys.size(); } } + auto getOrCreateDeferredReceiveGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == 
WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + + auto getOrCreateDirectValueGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DirectValue; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + for (const DirectWholeBatchFragment& fragment : plan.directFragments) { - WholeBatchFragmentGroup group; + if (!isTensorValueLocalToMaterializedClass(fragment.fragment, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, + targetClass, + "whole-batch assembly tried to reuse non-local DirectValue", + fragment.fragment, + fragment.key); + return failure(); + } + auto fragmentType = dyn_cast(fragment.fragment.getType()); if (!fragmentType) return failure(); - group.kind = WholeBatchFragmentSourceKind::DirectValue; - group.fragmentType = fragmentType; - group.directFragments.push_back( - {fragment.fragment, static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane}); - groups.push_back(std::move(group)); + + int64_t outputOffset = static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane; + + if (auto receive = fragment.fragment.getDefiningOp()) { + if (fragment.fragment.use_empty()) { + WholeBatchFragmentGroup& group = getOrCreateDeferredReceiveGroup(fragmentType); + if (appendConstantChannelReceiveMessage(group.messages, 
receive)) { + group.outputOffsets.push_back(outputOffset); + group.redundantReceives.push_back(receive.getOperation()); + continue; + } + } + } + + WholeBatchFragmentGroup& group = getOrCreateDirectValueGroup(fragmentType); + group.directFragments.push_back({fragment.fragment, outputOffset}); } return success(); @@ -3510,11 +6769,10 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, const WholeBatchFragmentGroup& group, Location loc) { switch (group.kind) { - case WholeBatchFragmentSourceKind::DeferredReceive: - return emitIndexedFragmentInsertLoop( + case WholeBatchFragmentSourceKind::DeferredReceive: { + FailureOr updated = emitIndexedFragmentInsertLoop( state, - targetClass.op, - targetClass.body->getTerminator(), + targetClass, destination, static_cast(group.outputOffsets.size()), [&](Value flatIndex) -> FailureOr { @@ -3529,12 +6787,20 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); }, loc); + if (failed(updated)) + return failure(); + + for (Operation* receive : group.redundantReceives) + if (receive && receive->use_empty()) + receive->erase(); + + return *updated; + } case WholeBatchFragmentSourceKind::DeferredLocalCompute: { SmallVector resultIndices {group.resultIndex}; return emitIndexedFragmentInsertLoop( state, - targetClass.op, - targetClass.body->getTerminator(), + targetClass, destination, static_cast(group.outputOffsets.size()), [&](Value flatIndex) -> FailureOr { @@ -3558,14 +6824,20 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, case WholeBatchFragmentSourceKind::PackedValue: return emitIndexedFragmentInsertLoop( state, - targetClass.op, - targetClass.body->getTerminator(), + targetClass, destination, static_cast(group.slotIndices.size()), [&](Value flatIndex) -> FailureOr { Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc); - return 
extractPackedSlotForIndex( - state, targetClass.op, group.packed, group.slotPackedType, packedSlotIndex, loc); + FailureOr packed = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + group.packed, + targetClass.op, + "whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(packed)) + return failure(); + return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc); }, [&](Value flatIndex) -> FailureOr { return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); @@ -3574,8 +6846,23 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, case WholeBatchFragmentSourceKind::DirectValue: for (const auto& [fragment, offset] : group.directFragments) { state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - destination = insertFragmentIntoWholeBatch( - state, fragment, destination, getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset), loc); + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment, + targetClass.op, + "whole-batch direct fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(localFragment)) + return failure(); + FailureOr updated = createDim0InsertSliceInClass(state, + targetClass, + loc, + *localFragment, + destination, + getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset)); + if (failed(updated)) + return failure(); + destination = *updated; } return destination; } @@ -3583,6 +6870,100 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, return failure(); } +FailureOr emitProjectedWholeBatchFragmentInsertLoop( + MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const ProjectedWholeBatchFragmentGroup& group, + llvm::function_ref(Value)> buildFragment, + Location loc) { + assert(group.fragmentType && "expected 
projected fragment type"); + assert(!group.offsetsByDim.empty() && "expected projected insert coordinates"); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, group.offsetsByDim.front().size()); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + FailureOr fragment = buildFragment(flatIndex); + if (failed(fragment)) + return failure(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + unsigned rank = group.offsetsByDim.size(); + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + offsets.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.offsetsByDim[dim], flatIndex, loc)); + sizes.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.sizesByDim[dim], flatIndex, loc)); + strides.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.stridesByDim[dim], flatIndex, loc)); + } + + Value updated = + tensor::InsertSliceOp::create(state.rewriter, loc, *fragment, iterArgs.front(), offsets, sizes, strides) + .getResult(); + yielded.push_back(updated); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + +std::optional getStaticProjectedPackedFragmentIndex(tensor::ExtractSliceOp extract) { + auto sourceType = dyn_cast(extract.getSource().getType()); + auto resultType = dyn_cast(extract.getResult().getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() + || sourceType.getRank() == 0 || sourceType.getRank() 
!= resultType.getRank()) + return std::nullopt; + + std::optional firstOffset = getConstantIndex(extract.getMixedOffsets().front()); + if (!firstOffset) + return std::nullopt; + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + std::optional offset = getConstantIndex(extract.getMixedOffsets()[dim]); + std::optional size = getConstantIndex(extract.getMixedSizes()[dim]); + std::optional stride = getConstantIndex(extract.getMixedStrides()[dim]); + if (!offset || !size || !stride || *stride != 1 || *size != resultType.getDimSize(dim)) + return std::nullopt; + if (dim != 0 && *offset != 0) + return std::nullopt; + } + + return *firstOffset; +} + +void appendProjectedInsertCoordinates(ProjectedWholeBatchFragmentGroup& group, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + if (group.offsetsByDim.empty()) { + size_t rank = offsets.size(); + group.offsetsByDim.resize(rank); + group.sizesByDim.resize(rank); + group.stridesByDim.resize(rank); + } + + for (size_t dim = 0; dim < offsets.size(); ++dim) { + group.offsetsByDim[dim].push_back(offsets[dim]); + group.sizesByDim[dim].push_back(sizes[dim]); + group.stridesByDim[dim].push_back(strides[dim]); + } +} + FailureOr buildWholeBatchAssemblyPlan(MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, @@ -3643,13 +7024,301 @@ FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, // Run materialization helpers. 
// ----------------------------------------------------------------------------- -FailureOr materializeWholeBatchInput( - MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { - FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType); - if (failed(plan)) +FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType, + Location loc) { + auto batch = dyn_cast_or_null(key.instance.op); + auto resultTensorType = dyn_cast(resultType); + if (!batch || !resultTensorType || !resultTensorType.hasStaticShape()) return failure(); - return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc); + FailureOr projection = getBatchResultProjectionInsert(batch, key.resultIndex); + if (failed(projection)) + return failure(); + + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("missing compute_batch lane argument while materializing projected whole-batch input"); + + uint32_t laneEnd = key.instance.laneStart + key.instance.laneCount; + if (laneEnd > static_cast(batch.getLaneCount())) + return failure(); + + if (targetClass.isBatch) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) + .getResult(); + + for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { + ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); + std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); + if (!fragment) + return failure(); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); + FailureOr> strides = + 
evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); + if (failed(offsets) || failed(sizes) || failed(strides)) + return failure(); + + SmallVector offsetAttrs; + SmallVector sizeAttrs; + SmallVector strideAttrs; + offsetAttrs.reserve(offsets->size()); + sizeAttrs.reserve(sizes->size()); + strideAttrs.reserve(strides->size()); + for (auto [offset, size, stride] : llvm::zip(*offsets, *sizes, *strides)) { + offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); + sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); + strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); + } + + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + *fragment, + targetClass.op, + "projected whole-batch assembly tried to reuse a tensor from another materialized class", + laneKey); + if (failed(localFragment)) + return failure(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + result = tensor::InsertSliceOp::create( + state.rewriter, loc, *localFragment, result, offsetAttrs, sizeAttrs, strideAttrs) + .getResult(); + } + + state.availableValues.record(key, targetClass.id, result); + return result; + } + + SmallVector groups; + auto getOrCreateReceiveGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { + return group.kind == ProjectedWholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + ProjectedWholeBatchFragmentGroup group; + group.kind = ProjectedWholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + auto getOrCreatePackedGroup = [&](Value packed, + RankedTensorType packedSourceType, + RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { 
+ auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { + return group.kind == ProjectedWholeBatchFragmentSourceKind::PackedValue && group.fragmentType == fragmentType + && group.packed == packed && group.packedSourceType == packedSourceType; + }); + if (groupIt == groups.end()) { + ProjectedWholeBatchFragmentGroup group; + group.kind = ProjectedWholeBatchFragmentSourceKind::PackedValue; + group.fragmentType = fragmentType; + group.packed = packed; + group.packedSourceType = packedSourceType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + auto getOrCreateDirectGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { + return group.kind == ProjectedWholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + ProjectedWholeBatchFragmentGroup group; + group.kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + + for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { + ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); + if (failed(offsets) || failed(sizes) || failed(strides)) + return failure(); + + bool grouped = false; + if (std::optional exact = state.availableValues.lookupExact(laneKey, targetClass.id)) { + if (auto receive = exact->getDefiningOp()) { + auto fragmentType = 
dyn_cast(receive.getOutput().getType()); + if (fragmentType && receive.getOutput().use_empty()) { + ProjectedWholeBatchFragmentGroup& group = getOrCreateReceiveGroup(fragmentType); + if (appendConstantChannelReceiveMessage(group.messages, receive)) { + appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); + group.redundantOps.push_back(receive.getOperation()); + grouped = true; + } + } + } + } + + if (grouped) + continue; + + std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); + if (!fragment) + return failure(); + + auto fragmentType = dyn_cast(fragment->getType()); + if (!fragmentType) + return failure(); + + if (auto extract = fragment->getDefiningOp()) { + if (std::optional packedIndex = getStaticProjectedPackedFragmentIndex(extract)) { + auto packedSourceType = dyn_cast(extract.getSource().getType()); + if (packedSourceType) { + ProjectedWholeBatchFragmentGroup& group = + getOrCreatePackedGroup(extract.getSource(), packedSourceType, fragmentType); + group.packedIndices.push_back(*packedIndex); + appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); + group.redundantOps.push_back(extract.getOperation()); + continue; + } + } + } + + ProjectedWholeBatchFragmentGroup& group = getOrCreateDirectGroup(fragmentType); + group.directFragments.push_back({*fragment, *offsets, *sizes, *strides}); + } + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) + .getResult(); + + for (const ProjectedWholeBatchFragmentGroup& group : groups) { + FailureOr updated = failure(); + switch (group.kind) { + case ProjectedWholeBatchFragmentSourceKind::DeferredReceive: + updated = emitProjectedWholeBatchFragmentInsertLoop( + state, + targetClass, + result, + group, + [&](Value flatIndex) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, 
group.messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); + return SpatChannelReceiveOp::create( + state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + loc); + break; + case ProjectedWholeBatchFragmentSourceKind::PackedValue: + updated = emitProjectedWholeBatchFragmentInsertLoop( + state, + targetClass, + result, + group, + [&](Value flatIndex) -> FailureOr { + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + extractOffsets.reserve(group.packedSourceType.getRank()); + extractSizes.reserve(group.packedSourceType.getRank()); + extractStrides.reserve(group.packedSourceType.getRank()); + extractOffsets.push_back(createIndexedOrStaticIndex( + state, targetClass.op, group.packedIndices, flatIndex, loc)); + extractSizes.push_back(state.rewriter.getIndexAttr(1)); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < group.packedSourceType.getRank(); ++dim) { + extractOffsets.push_back(state.rewriter.getIndexAttr(0)); + extractSizes.push_back(state.rewriter.getIndexAttr(group.packedSourceType.getDimSize(dim))); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + } + + FailureOr packed = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + group.packed, + targetClass.op, + "projected whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(packed)) + return failure(); + + return tensor::ExtractSliceOp::create( + state.rewriter, + loc, + group.fragmentType, + *packed, + extractOffsets, + extractSizes, + extractStrides) + .getResult(); + }, + loc); + break; + case ProjectedWholeBatchFragmentSourceKind::DirectValue: { + updated = result; + for (const ProjectedWholeBatchDirectFragment& fragment 
: group.directFragments) { + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment.fragment, + targetClass.op, + "projected whole-batch assembly tried to reuse a tensor from another materialized class"); + if (failed(localFragment)) + return failure(); + + SmallVector offsetAttrs; + SmallVector sizeAttrs; + SmallVector strideAttrs; + for (auto [offset, size, stride] : llvm::zip(fragment.offsets, fragment.sizes, fragment.strides)) { + offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); + sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); + strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); + } + updated = tensor::InsertSliceOp::create( + state.rewriter, loc, *localFragment, *updated, offsetAttrs, sizeAttrs, strideAttrs) + .getResult(); + } + break; + } + } + if (failed(updated)) + return failure(); + result = *updated; + } + + for (const ProjectedWholeBatchFragmentGroup& group : groups) + for (Operation* redundantOp : group.redundantOps) + if (redundantOp && redundantOp->use_empty()) + redundantOp->erase(); + + state.availableValues.record(key, targetClass.id, result); + return result; +} + +FailureOr materializeWholeBatchInput( + MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { + if (failed(materializePendingScalarReceivesForWholeBatchInput(state, targetClass, key, loc))) + return failure(); + + FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType); + if (succeeded(plan)) + return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc); + + return materializeProjectedWholeBatchInputFromFragments(state, targetClass, key, resultType, loc); } FailureOr resolveInputValue(MaterializerState& state, @@ -3657,18 +7326,43 @@ FailureOr resolveInputValue(MaterializerState& state, Value input, const ComputeInstance& consumerInstance, CloneIndexingContext indexing) { + auto rejectNonLocalResolvedValue = 
[&](Value resolved) -> FailureOr { + if (!isTensorValueDefinedInDifferentMaterializedClass(resolved, targetClass)) + return resolved; + + std::optional producer = getInputRequestProducerKey(input, consumerInstance); + emitNonLocalMaterializedClassValueDiagnostic(consumerInstance.op, + targetClass, + "input resolution tried to reuse a tensor from another materialized class", + resolved, + producer); + return failure(); + }; + if (isConstantLike(input)) return input; if (std::optional producer = getInputRequestProducerKey(input, consumerInstance)) { if (indexing.runSlotIndex) { - if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) - return materializeIndexedBatchRunReceive( + if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { + FailureOr received = materializeIndexedBatchRunReceive( state, targetClass, *indexedRun, *indexing.runSlotIndex, consumerInstance.op->getLoc()); + if (failed(received)) + return failure(); + return rejectNonLocalResolvedValue(*received); + } } if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) - return *value; + return rejectNonLocalResolvedValue(*value); + + if (auto pendingReceive = lookupPendingScalarReceiveIndex(state, *producer, targetClass.id)) { + FailureOr received = + materializePendingScalarReceive(state, targetClass, *pendingReceive, consumerInstance.op->getLoc()); + if (failed(received)) + return failure(); + return rejectNonLocalResolvedValue(*received); + } if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { size_t laneCount = targetClass.cpus.size(); @@ -3681,7 +7375,7 @@ FailureOr resolveInputValue(MaterializerState& state, appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); for (ProducerKey slotKey : slot.keys) state.availableValues.record(slotKey, 
targetClass.id, received); - return received; + return rejectNonLocalResolvedValue(received); } } @@ -3692,7 +7386,9 @@ FailureOr resolveInputValue(MaterializerState& state, consumerInstance.op->emitError("failed to materialize whole-batch input") << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; - return wholeBatch; + if (failed(wholeBatch)) + return failure(); + return rejectNonLocalResolvedValue(*wholeBatch); } consumerInstance.op->emitError("failed to resolve producer value") @@ -3701,6 +7397,15 @@ FailureOr resolveInputValue(MaterializerState& state, return failure(); } + if (isTensorValueDefinedInDifferentMaterializedClass(input, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic( + consumerInstance.op, + targetClass, + "input resolution tried to append a tensor from another materialized class as a normal input", + input); + return failure(); + } + return appendInput(state, targetClass, input); } @@ -3746,6 +7451,15 @@ LogicalResult mapInputs(MaterializerState& state, const ComputeInstance& instance, IRMapping& mapper, CloneIndexingContext indexing) { + auto mapResolvedInput = [&](Value resolved) -> FailureOr { + return materializeTensorValueForMaterializedClassUse( + state, + targetClass, + resolved, + targetClass.op, + "input mapping tried to reuse a tensor from another materialized class"); + }; + Operation* op = instance.op; if (auto compute = dyn_cast(op)) { for (auto [index, input] : llvm::enumerate(compute.getInputs())) { @@ -3762,7 +7476,16 @@ LogicalResult mapInputs(MaterializerState& state, auto inputArg = compute.getInputArgument(index); if (!inputArg) return compute.emitOpError("expected compute input block argument while materializing inputs"); - mapper.map(*inputArg, *mapped); + FailureOr remapped = mapResolvedInput(*mapped); + if (failed(remapped)) { + 
emitNonLocalMaterializedClassValueDiagnostic(compute, + targetClass, + "mapInputs tried to append a tensor from another materialized class", + *mapped, + getInputRequestProducerKey(input, instance)); + return failure(); + } + mapper.map(*inputArg, *remapped); } return success(); } @@ -3778,7 +7501,16 @@ LogicalResult mapInputs(MaterializerState& state, auto inputArg = batch.getInputArgument(index); if (!inputArg) return batch.emitOpError("expected compute_batch input block argument while materializing inputs"); - mapper.map(*inputArg, *mapped); + FailureOr remapped = mapResolvedInput(*mapped); + if (failed(remapped)) { + emitNonLocalMaterializedClassValueDiagnostic(batch, + targetClass, + "mapInputs tried to append a tensor from another materialized class", + *mapped, + getInputRequestProducerKey(input, instance)); + return failure(); + } + mapper.map(*inputArg, *remapped); } return success(); } @@ -3867,11 +7599,35 @@ std::optional lookupProjectedExtractReplacement(Mat return classIt->second; } +bool requiresConstantProjectionSlotIndex(MaterializerState& state, + MaterializedClass& targetClass, + Operation* sourceOp) { + bool requiresConstantIndex = false; + sourceOp->walk([&](tensor::ExtractSliceOp extract) { + if (requiresConstantIndex) + return WalkResult::interrupt(); + + std::optional replacement = + lookupProjectedExtractReplacement(state, targetClass, extract); + if (!replacement) + return WalkResult::advance(); + + if (replacement->layout.payloadFragmentCount != replacement->layout.fragmentsPerLogicalSlot) { + requiresConstantIndex = true; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return requiresConstantIndex; +} + LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state, MaterializedClass& targetClass, Operation& originalOp, Operation& clonedOp, - CloneIndexingContext indexing) { + CloneIndexingContext indexing, + IRMapping& mapper) { if (auto originalExtract = dyn_cast(&originalOp)) { if 
(std::optional replacement = lookupProjectedExtractReplacement(state, targetClass, originalExtract)) { @@ -3881,7 +7637,7 @@ LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& sta state.rewriter.setInsertionPoint(clonedExtract); FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, clonedExtract, *replacement, indexing.projectionSlotIndex); + state, targetClass, clonedExtract, *replacement, indexing.projectionSlotIndex, &mapper); if (failed(projected)) return failure(); @@ -3906,7 +7662,7 @@ LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& sta Operation& originalNestedOp = *originalIt++; Operation* currentClonedOp = &*clonedIt++; if (failed(applyProjectedExtractReplacementsInClonedOp( - state, targetClass, originalNestedOp, *currentClonedOp, indexing))) + state, targetClass, originalNestedOp, *currentClonedOp, indexing, mapper))) return failure(); } if (originalIt != originalBlock.end() || clonedIt != clonedBlock.end()) @@ -3917,6 +7673,40 @@ LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& sta return success(); } +LogicalResult mapClonedRegionBlockArguments(Operation& originalOp, Operation& clonedOp, IRMapping& mapper) { + if (originalOp.getNumRegions() != clonedOp.getNumRegions()) + return clonedOp.emitError("cloned operation has a different number of regions than the source operation"); + + for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { + if (std::distance(originalRegion.begin(), originalRegion.end()) + != std::distance(clonedRegion.begin(), clonedRegion.end())) + return clonedOp.emitError("cloned operation has a different number of blocks than the source operation"); + + for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { + if (originalBlock.getNumArguments() != clonedBlock.getNumArguments()) + return clonedOp.emitError("cloned operation 
block has a different number of arguments than the source block"); + + for (auto [originalArg, clonedArg] : llvm::zip(originalBlock.getArguments(), clonedBlock.getArguments())) + if (!mapper.contains(originalArg)) + mapper.map(originalArg, clonedArg); + + if (std::distance(originalBlock.begin(), originalBlock.end()) != std::distance(clonedBlock.begin(), clonedBlock.end())) + return clonedOp.emitError("cloned operation block has a different number of operations than the source block"); + + auto originalIt = originalBlock.begin(); + auto clonedIt = clonedBlock.begin(); + while (originalIt != originalBlock.end()) { + Operation& originalNestedOp = *originalIt++; + Operation& clonedNestedOp = *clonedIt++; + if (failed(mapClonedRegionBlockArguments(originalNestedOp, clonedNestedOp, mapper))) + return failure(); + } + } + } + + return success(); +} + LogicalResult cloneComputeTemplateBody(MaterializerState& state, MaterializedClass& targetClass, const ComputeInstance& instance, @@ -3928,7 +7718,7 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state, if (std::optional replacement = lookupProjectedExtractReplacement(state, targetClass, extract)) { FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, extract, *replacement, indexing.projectionSlotIndex); + state, targetClass, extract, *replacement, indexing.projectionSlotIndex, &mapper); if (failed(projected)) return failure(); @@ -3937,9 +7727,31 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state, } } + for (Value operand : op.getOperands()) { + if (mapper.contains(operand)) + continue; + + FailureOr localized = localizeMaterializedClassOperand( + state, + targetClass, + operand, + &op, + "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", + "cloneComputeTemplateBody produced an unsupported external non-tensor operand", + &mapper); + if (failed(localized)) + return failure(); + if (*localized != operand) + mapper.map(operand, 
*localized); + } + Operation* cloned = state.rewriter.clone(op, mapper); + if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper))) + return failure(); + if (failed(localizeCapturesInClonedOp(state, targetClass, *cloned, &mapper))) + return failure(); if (op.getNumRegions() != 0 - && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing))) + && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing, mapper))) return failure(); for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) mapper.map(oldResult, newResult); @@ -3952,11 +7764,25 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state MaterializedClass& targetClass, tensor::ExtractSliceOp extract, const ProjectedExtractReplacement& replacement, - std::optional projectionSlotIndex) { + std::optional projectionSlotIndex, + IRMapping* mapper) { if (failed(verifyProjectedFragmentLayout(targetClass.op, replacement.layout))) return failure(); + + FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + replacement.payload, + targetClass.op, + "projected extract replacement tried to reuse a tensor from another materialized class", + std::nullopt, + mapper); + if (failed(localizedPayload)) + return failure(); + Value payload = *localizedPayload; + if (replacement.layout.payloadFragmentCount == 1) - return replacement.payload; + return payload; if (replacement.layout.payloadFragmentCount < replacement.layout.fragmentsPerLogicalSlot) return targetClass.op->emitError("projected replacement payload is smaller than one logical slot"); @@ -3980,7 +7806,11 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state Value linearizedIndex = intraSlotFragmentIndex; for (auto [index, loop] : llvm::enumerate(surroundingLoops)) { - Value iv = loop.getInductionVar(); + FailureOr localizedIv = + rematerializeIndexValueInClass(state, 
targetClass, loop.getInductionVar(), extract.getLoc(), mapper); + if (failed(localizedIv)) + return failure(); + Value iv = *localizedIv; Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]); @@ -4012,10 +7842,16 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state if (!projectionSlotIndex) return targetClass.op->emitError("packed projected extract replacement requires a fragment slot index"); + FailureOr localProjectionSlotIndex = + rematerializeIndexValueInClass(state, targetClass, *projectionSlotIndex, extract.getLoc(), mapper); + if (failed(localProjectionSlotIndex)) + return failure(); + Value fragmentsPerLogicalSlot = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot); - Value base = arith::MulIOp::create(state.rewriter, extract.getLoc(), *projectionSlotIndex, fragmentsPerLogicalSlot) - .getResult(); + Value base = + arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot) + .getResult(); return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult(); }; @@ -4023,10 +7859,12 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state if (failed(packedFragmentIndex)) return failure(); - Value packedOffset = scaleIndexByDim0Size( - state, targetClass.op, *packedFragmentIndex, replacement.layout.fragmentType.getDimSize(0), extract.getLoc()); - return createDim0ExtractSlice( - state, extract.getLoc(), replacement.payload, packedOffset, replacement.layout.fragmentType.getDimSize(0)); + FailureOr packedOffset = scaleIndexByDim0SizeInClass( + state, targetClass, *packedFragmentIndex, replacement.layout.fragmentType.getDimSize(0), extract.getLoc()); + if (failed(packedOffset)) + return failure(); + 
return createDim0ExtractSliceInClass( + state, targetClass, extract.getLoc(), payload, *packedOffset, replacement.layout.fragmentType.getDimSize(0)); } FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, @@ -4050,6 +7888,102 @@ FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, .getOutput(); } +LogicalResult localizeCapturesInOperationTree(MaterializerState& state, + MaterializedClass& targetClass, + Operation& root, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper = nullptr) { + WalkResult walkResult = root.walk([&](Operation* nestedOp) -> WalkResult { + for (OpOperand& operand : nestedOp->getOpOperands()) { + Value current = operand.get(); + if (isValueLegalInMaterializedClassBody(current, targetClass)) + continue; + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(nestedOp); + FailureOr localized = + localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper); + if (failed(localized)) { + InFlightDiagnostic diagnostic = targetClass.op->emitError( + "RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand"); + diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() + << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() + << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) + << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) + << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; + diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; + attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR"); + attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands"); + 
attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain"); + attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); + attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return WalkResult::interrupt(); + } + operand.set(*localized); + } + return WalkResult::advance(); + }); + + return walkResult.wasInterrupted() ? failure() : success(); +} + +LogicalResult localizeCapturesInClonedOp(MaterializerState& state, + MaterializedClass& targetClass, + Operation& clonedOp, + IRMapping* mapper) { + return localizeCapturesInOperationTree( + state, + targetClass, + clonedOp, + "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", + "cloneComputeTemplateBody produced an unsupported external non-tensor operand", + mapper); +} + +LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass) { + SmallVector bodyOps; + for (Operation& op : *targetClass.body) + op.walk([&](Operation* nestedOp) { bodyOps.push_back(nestedOp); }); + + for (Operation* nestedOp : bodyOps) { + if (nestedOp->getBlock() == nullptr) + continue; + for (OpOperand& operand : nestedOp->getOpOperands()) { + Value current = operand.get(); + if (isValueLegalInMaterializedClassBody(current, targetClass)) + continue; + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(nestedOp); + FailureOr localized = localizeMaterializedClassOperand( + state, + targetClass, + current, + nestedOp, + "final scheduled body capture localization tried to reuse a tensor from another materialized class", + "final scheduled body capture localization found an unsupported external non-tensor operand"); + if (failed(localized)) { + InFlightDiagnostic diagnostic = targetClass.op->emitError( + "RAPTOR_MATERIALIZER_DEBUG 
failed to localize final scheduled-body operand"); + diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() + << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() + << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) + << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) + << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; + diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; + attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); + } + operand.set(*localized); + } + } + + return success(); +} + FailureOr> cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef peers, @@ -4431,8 +8365,14 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta for (auto [outputIndex, output] : llvm::enumerate(*produced)) { auto fragmentType = cast(output.getType()); Value acc = iterArgs[outputIndex]; - Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); - yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, fragmentType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, output, acc, *firstOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); } return success(); }); @@ -4572,6 +8512,20 @@ bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, return false; } +bool canRegisterDeferredLocalPackedRun(MaterializerState& state, ArrayRef run) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + 
for (Value input : getComputeInstanceInputs(peer)) { + std::optional producer = getInputRequestProducerKey(input, peer); + if (producer && isWholeBatchProducerKey(*producer)) + return false; + } + } + } + + return true; +} + void markMaterializationRunSlots(MaterializerState& state, ClassId classId, SlotId startSlot, @@ -4595,11 +8549,13 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, auto sourceBatch = cast(getMaterializationRunSourceOp(run)); SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); Location loc = getMaterializationRunLoc(run); + bool canDeferLocalPackedRun = canRegisterDeferredLocalPackedRun(state, run); for (const OutputDestinationGroup& group : groups) { - if (run.size() > 1 && group.destinationClasses.empty() - && !hasMaterializationRunGroupLiveExternalUse(state, run, group) - && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group)) { + bool canUseLocalOnlyPackedRun = run.size() > 1 && group.destinationClasses.empty() + && !hasMaterializationRunGroupLiveExternalUse(state, run, group) + && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group); + if (canUseLocalOnlyPackedRun && canDeferLocalPackedRun) { for (size_t resultIndex : group.resultIndices) { if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); @@ -4630,12 +8586,30 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, continue; } + if (canUseLocalOnlyPackedRun) { + if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) + return failure(); + continue; + } + if (failed(emitPackedRunFanout(state, targetClass, group.destinationClasses, keys, packed, fragmentType, loc))) return failure(); if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) return failure(); + Value representativeOutput = 
firstOriginalOutputs[resultIndex]; + if (hasLiveExternalUseCached(state, representativeOutput) + && isProjectedTerminalBatchHostOutput(representativeOutput, state.oldComputeOps)) { + std::optional groupedHostPublication = + tryEmitScalarPackedProjectedHostPublication(state, targetClass, keys, packed, representativeOutput, loc); + if (groupedHostPublication) { + if (failed(*groupedHostPublication)) + return failure(); + continue; + } + } + auto rankedFragmentType = cast(fragmentType); for (auto [runIndex, slot] : llvm::enumerate(run)) { assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); @@ -4647,9 +8621,19 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, continue; state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value fragment = getPackedSliceForRunIndex(state, targetClass.op, packed, rankedFragmentType, runIndex, loc); + FailureOr fragment = + getPackedSliceForRunIndex(state, targetClass, packed, rankedFragmentType, runIndex, loc); + if (failed(fragment)) + return failure(); - if (failed(emitHostCommunication(state, targetClass, fragment, originalOutput))) + if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { + ProducerKey key {slot.peers.front(), resultIndex}; + if (failed(emitProjectedBatchHostFragment(state, targetClass, key, *fragment, originalOutput, loc))) + return failure(); + continue; + } + + if (failed(emitHostCommunication(state, targetClass, *fragment, originalOutput))) return failure(); } } @@ -4705,8 +8689,90 @@ bool canCompactBatchClassRun(MaterializerState& state, return true; } +LogicalResult registerMaterializedBatchRunHostOutputs(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + const OutputDestinationGroup& group) { + ArrayRef originalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= originalOutputs.size()) + return 
targetClass.op->emitError("batch materialization host output index out of range"); + + Value originalOutput = originalOutputs[resultIndex]; + if (!hasLiveExternalUseCached(state, originalOutput)) + continue; + + auto resultIt = targetClass.hostOutputToResultIndex.find(originalOutput); + if (resultIt == targetClass.hostOutputToResultIndex.end()) + return targetClass.op->emitError("missing host result slot for materialized batch output"); + + state.hostReplacements[originalOutput] = targetClass.op->getResult(resultIt->second); + } + + return success(); +} + +LogicalResult verifyMaterializedHostOutputs(MaterializerState& state) { + for (SpatCompute compute : state.func.getOps()) { + auto yieldOp = dyn_cast_or_null(compute.getBody().front().getTerminator()); + if (!yieldOp) + return compute.emitOpError("expected spat.yield terminator in materialized compute"); + if (compute.getNumResults() != yieldOp.getNumOperands()) + return compute.emitOpError("materialized compute result count does not match spat.yield operand count"); + for (auto [result, yielded] : llvm::zip(compute.getResults(), yieldOp.getOperands())) + if (result.getType() != yielded.getType()) + return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand"); + } + + for (SpatChannelReceiveOp receive : state.func.getOps()) { + if (!receive.getOutput().use_empty()) + continue; + return receive.emitOpError("materialized channel_receive result must have at least one use"); + } + + for (const MaterializedClass& materializedClass : state.classes) { + if (!materializedClass.isBatch || materializedClass.hostOutputs.empty()) + continue; + + auto batch = dyn_cast(materializedClass.op); + auto inParallel = dyn_cast_or_null(materializedClass.body->getTerminator()); + if (!batch || !inParallel) + return materializedClass.op->emitError("expected resultful materialized compute_batch host owner"); + + for (Value hostOutput : materializedClass.hostOutputs) { + auto ownerIt = 
materializedClass.hostOutputToResultIndex.find(hostOutput); + if (ownerIt == materializedClass.hostOutputToResultIndex.end()) + return materializedClass.op->emitError("missing host result slot for materialized compute_batch host output"); + + auto outputArg = batch.getOutputArgument(ownerIt->second); + if (!outputArg) + return batch.emitOpError("missing output block argument for materialized compute_batch host output"); + + bool foundProjection = false; + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert || insert.getDest() != *outputArg) + continue; + foundProjection = true; + break; + } + + if (!foundProjection) + return batch.emitOpError( + "materialized terminal compute_batch host output is missing tensor.parallel_insert_slice publication"); + } + } + + for (const auto& [originalOutput, replacement] : state.hostReplacements) + if (originalOutput.getType() != replacement.getType()) + return replacement.getDefiningOp()->emitOpError("host output replacement type does not match original output type") + << " replacementType=" << replacement.getType() << " outputType=" << originalOutput.getType(); + + return success(); +} + Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { - auto batch = cast(targetClass.op); + auto batch = cast(targetClass.op); auto laneArg = batch.getLaneArgument(); assert(laneArg && "expected materialized compute_batch lane argument"); @@ -4911,6 +8977,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, auto sourceBatch = cast(run.front().peers.front().op); SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); Location loc = sourceBatch.getLoc(); + bool constantProjectionSlotIndex = requiresConstantProjectionSlotIndex(state, targetClass, sourceBatch); for (const OutputDestinationGroup& group : groups) { SmallVector sendPlans; @@ -4921,17 +8988,15 @@ LogicalResult 
materializeBatchClassRun(MaterializerState& state, Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { - Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + if (constantProjectionSlotIndex) { + for (auto [slotIndex, slot] : llvm::enumerate(run)) { + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + Value slotIndexValue = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(slotIndex)); + Value sourceLane = getOrCreateIndexConstant(state.constantFolder, targetClass.op, slot.peers.front().laneStart); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndexValue, loc); FailureOr> produced = cloneBatchBodyForLane(state, @@ -4939,7 +9004,8 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, getScheduledChunkForLogicalInstance(state, run.front().peers.front()), sourceLane, group.resultIndices, - CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); + CloneIndexingContext {.runSlotIndex = slotIndexValue, + .projectionSlotIndex = slotIndexValue}); if (failed(produced)) return failure(); @@ -4951,10 +9017,43 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); } - return success(); - }); - if (failed(loop)) 
- return failure(); + } + } else { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { + Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + getScheduledChunkForLogicalInstance(state, run.front().peers.front()), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); + if (failed(produced)) + return failure(); + + for (const BatchRunSendPlan& plan : sendPlans) { + auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); + if (resultIt == group.resultIndices.end()) + return failure(); + + size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); + appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); + } + return success(); + }); + if (failed(loop)) + return failure(); + } for (const BatchRunSendPlan& plan : sendPlans) { if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) @@ -4963,6 +9062,9 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) return failure(); } + + if (failed(registerMaterializedBatchRunHostOutputs(state, targetClass, run, group))) + return failure(); } return success(); @@ -5037,37 +9139,300 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, } FailureOr createReceiveConcatLoop(MaterializerState& state, - Operation* anchor, - Operation* insertionPoint, + MaterializedClass& targetClass, RankedTensorType concatType, RankedTensorType 
fragmentType, const MessageVector& messages, Location loc) { - assert(succeeded(messages.verify(anchor)) && "message metadata is inconsistent"); + assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); assert(!messages.empty() && "expected at least one receive"); + Operation* insertionPoint = targetClass.body->getTerminator(); state.rewriter.setInsertionPoint(insertionPoint); Value init = tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); return emitIndexedFragmentInsertLoop( state, - anchor, - insertionPoint, + targetClass, init, static_cast(messages.size()), [&](Value index) -> FailureOr { - Value channelId = createIndexedChannelId(state, anchor, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, anchor, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, anchor, messages, index, loc); + Value channelId = createIndexedChannelId(state, targetClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, messages, index, loc); return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) .getOutput(); }, [&](Value index) -> FailureOr { - return scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); + return scaleIndexByDim0SizeInClass(state, targetClass, index, fragmentType.getDimSize(0), loc); }, loc); } + +std::optional getDirectCommunicationOrderKey(Operation* op) { + if (!op) + return std::nullopt; + + Value channelId; + Value sourceCoreId; + Value targetCoreId; + if (auto send = dyn_cast(op)) { + channelId = send.getChannelId(); + sourceCoreId = send.getSourceCoreId(); + targetCoreId = send.getTargetCoreId(); + } + else if (auto receive = dyn_cast(op)) { + channelId = receive.getChannelId(); + sourceCoreId = 
receive.getSourceCoreId(); + targetCoreId = receive.getTargetCoreId(); + } + else { + return std::nullopt; + } + + auto channel = getConstantIndexValue(channelId); + auto source = getConstantIndexValue(sourceCoreId); + auto target = getConstantIndexValue(targetCoreId); + if (!channel || !source || !target) + return std::nullopt; + + return computeBlockingCommunicationOrderKey( + static_cast(*source), static_cast(*target), *channel); +} + +std::optional getScalarCommunicationOrderKey(Operation* op) { + if (!op) + return std::nullopt; + if (auto order = op->getAttrOfType(kRaptorCommOrderAttr)) + return order.getInt(); + if (auto directOrder = getDirectCommunicationOrderKey(op)) + return directOrder; + if (auto channel = op->getAttrOfType(kRaptorMinChannelIdAttr)) + return channel.getInt(); + return std::nullopt; +} + +bool isReorderableScalarCommunication(Operation* op) { + if (!getScalarCommunicationOrderKey(op).has_value()) + return false; + + // The global-order repair is intentionally conservative: it may reorder + // send-side projections, but it must not move receives or any other + // communication op that defines SSA values. Moving a receive after one of + // its users breaks MLIR dominance; moving it before the source can produce + // the payload can also create a receive/receive deadlock. Receives therefore + // have to be placed correctly by the materializer when they are created. + // Direct spat.channel_send operations are included even when they were not + // produced by appendScalarSendLoop and therefore do not carry raptor.* + // attributes yet. This is needed for large scalar-to-scalar payload transfers + // that must be hoisted before reciprocal receives. + return isa(op) || (op->getNumResults() == 0 && op->hasAttr(kRaptorMinChannelIdAttr)); +} + +Operation* getLaterOperationInBlock(Operation* lhs, Operation* rhs) { + if (!lhs) + return rhs; + if (!rhs) + return lhs; + return lhs->isBeforeInBlock(rhs) ? 
rhs : lhs; +} + +Operation* getNextInsertionPointAfter(Operation* op, Block& block) { + if (!op) + return &block.front(); + Operation* next = op->getNextNode(); + return next ? next : block.getTerminator(); +} + +bool hasConstantRoutingOperands(SpatChannelSendOp send) { + return getConstantIndexValue(send.getChannelId()).has_value() + && getConstantIndexValue(send.getSourceCoreId()).has_value() + && getConstantIndexValue(send.getTargetCoreId()).has_value(); +} + +Operation* getLatestSameBlockOperandDefinition(Operation* root, Block& block) { + Operation* latest = nullptr; + + auto consider = [&](Value value) { + Operation* definingOp = value.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != &block || definingOp == root) + return; + latest = getLaterOperationInBlock(latest, definingOp); + }; + + // For direct sends with constant routing operands, only the payload is a real + // scheduling dependency. The channel/source/target constants can be + // rematerialized at the new insertion point. Treating those constants as hard + // dependencies prevents the repair from hoisting a ready send above an early + // receive, which is exactly the receive/receive deadlock pattern reported by + // the static communication checker. 
+ if (auto send = dyn_cast(root)) { + if (hasConstantRoutingOperands(send)) { + consider(send.getInput()); + return latest; + } + } + + for (Value operand : root->getOperands()) + consider(operand); + + for (Region& region : root->getRegions()) { + region.walk([&](Operation* nested) { + if (nested == root) + return; + for (Value operand : nested->getOperands()) + consider(operand); + }); + } + + return latest; +} + +void rematerializeDirectSendRoutingConstantsAt(MaterializerState& state, + SpatChannelSendOp send, + Operation* insertionPoint) { + if (!send || !insertionPoint || !hasConstantRoutingOperands(send)) + return; + + auto channel = getConstantIndexValue(send.getChannelId()); + auto source = getConstantIndexValue(send.getSourceCoreId()); + auto target = getConstantIndexValue(send.getTargetCoreId()); + if (!channel || !source || !target) + return; + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(insertionPoint); + Location loc = send.getLoc(); + Value newChannel = arith::ConstantIndexOp::create(state.rewriter, loc, *channel); + Value newSource = arith::ConstantIndexOp::create(state.rewriter, loc, *source); + Value newTarget = arith::ConstantIndexOp::create(state.rewriter, loc, *target); + send->setOperand(0, newChannel); + send->setOperand(1, newSource); + send->setOperand(2, newTarget); +} + +LogicalResult reorderScalarClassCommunicationByGlobalOrder(MaterializerState& state, + MaterializedClass& materializedClass) { + if (materializedClass.isBatch) + return success(); + + Block& block = *materializedClass.body; + Operation* terminator = block.getTerminator(); + SmallVector communicationOps; + for (Operation& op : block) { + if (&op == terminator) + break; + if (isReorderableScalarCommunication(&op)) + communicationOps.push_back(&op); + } + + if (communicationOps.size() < 2) + return success(); + + llvm::stable_sort(communicationOps, [](Operation* lhs, Operation* rhs) { + std::optional lhsOrder = 
getScalarCommunicationOrderKey(lhs); + std::optional rhsOrder = getScalarCommunicationOrderKey(rhs); + if (lhsOrder != rhsOrder) + return lhsOrder.value_or(std::numeric_limits::max()) + < rhsOrder.value_or(std::numeric_limits::max()); + return lhs->isBeforeInBlock(rhs); + }); + + Operation* lastPlacedCommunication = nullptr; + for (Operation* communication : communicationOps) { + if (communication->getBlock() != &block) + return materializedClass.op->emitError("scalar communication global-order repair saw a moved operation"); + + Operation* dependency = getLatestSameBlockOperandDefinition(communication, block); + Operation* anchor = getLaterOperationInBlock(lastPlacedCommunication, dependency); + Operation* insertionPoint = getNextInsertionPointAfter(anchor, block); + + if (insertionPoint != communication && communication->getNextNode() != insertionPoint) { + if (auto send = dyn_cast(communication)) + rematerializeDirectSendRoutingConstantsAt(state, send, insertionPoint); + communication->moveBefore(insertionPoint); + } + + lastPlacedCommunication = communication; + } + + return success(); +} + +LogicalResult reorderScalarCommunicationsByGlobalOrder(MaterializerState& state) { + for (MaterializedClass& materializedClass : state.classes) + if (failed(reorderScalarClassCommunicationByGlobalOrder(state, materializedClass))) + return failure(); + return success(); +} + + +Operation* getEarliestOperationInBlock(Operation* lhs, Operation* rhs) { + if (!lhs) + return rhs; + if (!rhs) + return lhs; + return lhs->isBeforeInBlock(rhs) ? 
lhs : rhs; +} + +Operation* getTopLevelOperationInBlock(Operation* op, Block& block) { + for (Operation* current = op; current; current = current->getParentOp()) { + if (current->getBlock() == &block) + return current; + } + return nullptr; +} + +Operation* findEarliestTopLevelUse(Operation* producer, Block& block) { + Operation* earliest = nullptr; + for (Value result : producer->getResults()) { + for (Operation* user : result.getUsers()) { + Operation* topLevelUser = getTopLevelOperationInBlock(user, block); + if (!topLevelUser || topLevelUser == producer) + continue; + earliest = getEarliestOperationInBlock(earliest, topLevelUser); + } + } + return earliest; +} + +LogicalResult sinkScalarReceivesToFirstUse(MaterializerState& state) { + for (MaterializedClass& materializedClass : state.classes) { + if (materializedClass.isBatch) + continue; + + Block& block = *materializedClass.body; + Operation* terminator = block.getTerminator(); + SmallVector receives; + for (Operation& op : block) { + if (&op == terminator) + break; + if (isa(&op)) + receives.push_back(&op); + } + + for (Operation* receive : receives) { + if (receive->getBlock() != &block) + continue; + + Operation* firstUse = findEarliestTopLevelUse(receive, block); + if (!firstUse || firstUse == receive || firstUse->getBlock() != &block) + continue; + + if (!receive->isBeforeInBlock(firstUse)) + continue; + + if (receive->getNextNode() == firstUse) + continue; + + receive->setAttr("raptor.receive_sunk_to_first_use", UnitAttr::get(receive->getContext())); + receive->moveBefore(firstUse); + } + } + return success(); +} + void replaceHostUses(MaterializerState& state) { for (const auto& [oldValue, replacement] : state.hostReplacements) replaceLiveExternalUses(oldValue, replacement, state.oldComputeOps); @@ -5114,6 +9479,23 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch if (failed(materializeInstanceSlot(state, instance))) return failure(); + for (MaterializedClass& 
materializedClass : state.classes) + if (failed(localizeAllScheduledBodyCaptures(state, materializedClass))) + return failure(); + + if (failed(flushPendingProjectedHostReceives(state))) + return failure(); + + if (pimMaterializeScalarFanoutGlobalOrder) { + if (failed(sinkScalarReceivesToFirstUse(state))) + return failure(); + if (failed(reorderScalarCommunicationsByGlobalOrder(state))) + return failure(); + } + + if (failed(verifyMaterializedHostOutputs(state))) + return failure(); + replaceHostUses(state); if (failed(eraseOldComputeOps(state))) return failure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 1747a6d..6f718bd 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -40,10 +40,10 @@ using namespace mlir; namespace onnx_mlir { namespace { using namespace onnx_mlir::compact_asm; -using SpatCompute = spatial::SpatCompute; -using SpatComputeBatch = spatial::SpatComputeBatch; +using SpatCompute = spatial::SpatGraphCompute; +using SpatComputeBatch = spatial::SpatGraphComputeBatch; -static std::optional getComputeCoreId(SpatCompute compute) { +static std::optional getComputeCoreId(spatial::SpatScheduledCompute compute) { if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id"); if (failed(checkedCoreId)) @@ -209,7 +209,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu }; for (Operation& op : funcOp.getBody().front()) { - if (auto spatCompute = dyn_cast(&op)) { + if (auto spatCompute = dyn_cast(&op)) { uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody()); uint64_t perInstanceCrossbarCount = 
getPerInstanceCrossbarCount(spatCompute.getOperation()); SmallVector coreIds; @@ -229,7 +229,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu totalCrossbarCount += perInstanceCrossbarCount; continue; } - if (auto batch = dyn_cast(&op)) { + if (auto batch = dyn_cast(&op)) { uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody()); uint64_t logicalCount = static_cast(batch.getLaneCount()); uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation()); @@ -353,7 +353,17 @@ public: void runOnOperation() override { func::FuncOp func = getOperation(); + if (failed(verifyLogicalSpatialGraphInvariants(func))) { + func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes"); + signalPassFailure(); + return; + } mergeTriviallyConnectedComputes(func); + if (failed(verifyLogicalSpatialGraphInvariants(func))) { + func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification"); + signalPassFailure(); + return; + } const spatial::MergeScheduleResult* analysisResult = nullptr; analysisResult = &getAnalysis().getResult(); @@ -367,8 +377,8 @@ public: signalPassFailure(); return; } - if (failed(verifySpatialCommunicationInvariants(func))) { - func.emitOpError("merged Spatial communication invariant verification failed"); + if (failed(verifyScheduledSpatialInvariants(func))) { + func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization"); signalPassFailure(); return; } diff --git a/src/PIM/Pass/PIMPasses.h b/src/PIM/Pass/PIMPasses.h index 515f459..5e03534 100644 --- a/src/PIM/Pass/PIMPasses.h +++ b/src/PIM/Pass/PIMPasses.h @@ -8,6 +8,8 @@ namespace onnx_mlir { std::unique_ptr createONNXToSpatialPass(); +std::unique_ptr createSpatialLayoutPlanningPass(); +std::unique_ptr createLowerSpatialPlansPass(); std::unique_ptr createSpatialToGraphvizPass(); diff --git 
a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 5afc09a..eaa44b2 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -72,6 +72,8 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const { void PimAccelerator::registerPasses(int optLevel) const { LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); registerPass(createONNXToSpatialPass); + registerPass(createSpatialLayoutPlanningPass); + registerPass(createLowerSpatialPlansPass); registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToPimPass); registerPass(createPimBufferizationPass);