diff --git a/AGENTS.md b/AGENTS.md index 9cadae5..e45d878 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -6,6 +6,7 @@ * Always try the release build first before building with the debug version * Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively * The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations +* Always prepend rtk to shell commands if missing and if rtk is available # Core engineering philosophy diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp index b6b5162..e05bdff 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp @@ -113,6 +113,19 @@ void verifyScheduledInputs(ComputeOpTy compute, } } +template +void verifyNoNestedFragmentAssemblyReconciliators(ComputeOpTy compute, + pim::CappedDiagnosticReporter& diagnostics) { + compute.getBody().walk([&](spatial::SpatReconciliatorOp reconciliator) { + std::optional mode = reconciliator.getMode(); + if (!mode || *mode != "fragment_assembly") + return; + diagnostics.report(reconciliator.getOperation(), [&](Operation* illegalOp) { + illegalOp->emitOpError("fragment assembly reconciliator must be host-level after merge materialization"); + }); + }); +} + void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) { for (Operation& op : funcOp.getOps()) { if (isa()) + for (auto compute : funcOp.getOps()) { verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics); - for (auto batch : funcOp.getOps()) + verifyNoNestedFragmentAssemblyReconciliators(compute, diagnostics); + } + for (auto batch : funcOp.getOps()) { verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics); + verifyNoNestedFragmentAssemblyReconciliators(batch, diagnostics); + } if (failed(verifyNoComputeBodyCaptures(funcOp))) return failure(); diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed"); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 14c2b80..2709f1e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1702,8 +1702,6 @@ static bool canUseStructuredRewrite(const ConvLoweringState& state) { state.outWidth); if (!tiling) return false; - if (tiling->numChannelTiles > static_cast(crossbarCountInCore.getValue())) - return false; if (!state.hasBias) return true; diff --git a/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp index 9c300dc..8845f76 100644 --- a/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp @@ -107,6 +107,7 @@ static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewr nullptr, nullptr, nullptr, + nullptr, nullptr); } diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 4a3b01b..8aea396 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -131,6 +131,151 @@ static FailureOr getDirectReturnOperandIndex(OpResult result) { return result.getUses().begin()->getOperandNumber(); } +struct BatchFragmentAssemblyPlan { + unsigned returnIndex = 0; + int64_t localSourceElementOffset = 0; + int64_t fragmentByteSize = 0; + SmallVector hostOffsetsByLane; +}; + +static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef values, Location loc) { + assert(!values.empty() && "expected lane-indexed values"); + if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); })) + return getOrCreateIndexConstant(rewriter, anchor, values.front()); + + if (values.size() >= 2) { + int64_t step = values[1] - values[0]; + bool arithmetic = llvm::all_of(llvm::seq(2, values.size()), [&](size_t index) { + return values[index] == values.front() + static_cast(index) * step; + }); + if (arithmetic) { + Value base = getOrCreateIndexConstant(rewriter, anchor, values.front()); + if (step == 0) + return base; + Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step); + Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult(); + return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult(); + } + } + + Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front()); + for (auto [lane, value] : llvm::enumerate(values.drop_front())) { + Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast(lane + 1)); + Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue); + Value candidate = getOrCreateIndexConstant(rewriter, anchor, value); + selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected); + } + return selected; +} + +static FailureOr> +analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) { + SmallVector plans; + if (!packedResultType.hasStaticShape() || laneCount == 0) + return failure(); + + int64_t packedElementCount = packedResultType.getNumElements(); + if (packedElementCount % static_cast(laneCount) != 0) + return failure(); + int64_t payloadElementCount = packedElementCount / static_cast(laneCount); + size_t elementSize = getElementTypeSizeInBytes(packedResultType.getElementType()); + + for (OpOperand& use : result.getUses()) { + auto reconciliator = dyn_cast(use.getOwner()); + if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType()) + return failure(); + std::optional mode = reconciliator.getMode(); + std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); + std::optional> stridesAttr = reconciliator.getFragmentStrides(); + if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) + return failure(); + if (!reconciliator.getOutput().hasOneUse() || !isa(*reconciliator.getOutput().getUsers().begin())) + return failure(); + + unsigned returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber(); + auto hostResultType = dyn_cast(reconciliator.getOutput().getType()); + if (!hostResultType || !hostResultType.hasStaticShape()) + return failure(); + + ArrayRef operandIndices = *operandIndicesAttr; + ArrayRef sourceOffsets = *sourceOffsetsAttr; + ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); + ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatStrides = *stridesAttr; + int64_t rank = hostResultType.getRank(); + SmallVector fragmentOperands {reconciliator.getInput()}; + llvm::append_range(fragmentOperands, reconciliator.getFragments()); + if (failed(validateFragmentAssemblyMetadata(reconciliator, + rank, + fragmentOperands.size(), + operandIndices, + sourceOffsets, + flatOffsets, + flatSizes, + flatStrides))) + return failure(); + for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { + if (operandIndices[fragmentIndex] != static_cast(use.getOperandNumber())) + continue; + + int64_t sourceElementOffset = sourceOffsets[fragmentIndex]; + int64_t lane = sourceElementOffset / payloadElementCount; + int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount; + if (lane < 0 || lane >= static_cast(laneCount)) + return failure(); + + SmallVector fragmentOffsets; + SmallVector fragmentSizes; + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t flatIndex = fragmentIndex * rank + dim; + if (flatStrides[flatIndex] != 1) + return failure(); + fragmentOffsets.push_back(flatOffsets[flatIndex]); + fragmentSizes.push_back(flatSizes[flatIndex]); + } + + if (failed(forEachContiguousDestinationChunk( + hostResultType.getShape(), + fragmentOffsets, + fragmentSizes, + [&](ArrayRef chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult { + int64_t hostElementOffset = 0; + SmallVector hostStrides = computeRowMajorStrides(hostResultType.getShape()); + for (auto [dim, offset] : llvm::enumerate(chunkOffsets)) + hostElementOffset += offset * hostStrides[dim]; + int64_t hostByteOffset = hostElementOffset * static_cast(elementSize); + int64_t fragmentByteSize = chunkElements * static_cast(elementSize); + int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset; + + auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) { + return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset + && plan.fragmentByteSize == fragmentByteSize; + }); + if (planIt == plans.end()) { + BatchFragmentAssemblyPlan plan; + plan.returnIndex = returnIndex; + plan.localSourceElementOffset = chunkSourceOffset; + plan.fragmentByteSize = fragmentByteSize; + plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits::min()); + plan.hostOffsetsByLane[static_cast(lane)] = hostByteOffset; + plans.push_back(std::move(plan)); + return success(); + } + + planIt->hostOffsetsByLane[static_cast(lane)] = hostByteOffset; + return success(); + }))) + return failure(); + } + } + + for (const BatchFragmentAssemblyPlan& plan : plans) + if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits::min(); })) + return failure(); + return plans; +} + static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) { if (scale == 1) return base; @@ -250,6 +395,10 @@ static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, "fragment assembly lowering requires explicit operand indices and unit strides"); ArrayRef operandIndices = *operandIndicesAttr; + std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); + if (!sourceOffsetsAttr) + return reconciliator.emitOpError("fragment assembly lowering requires explicit source offsets"); + ArrayRef sourceOffsets = *sourceOffsetsAttr; ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); ArrayRef flatSizes = reconciliator.getFragmentSizes(); ArrayRef flatStrides = *fragmentStridesAttr; @@ -257,21 +406,25 @@ static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, SmallVector fragmentOperands {reconciliator.getInput()}; llvm::append_range(fragmentOperands, reconciliator.getFragments()); + if (failed(validateFragmentAssemblyMetadata(reconciliator, + rank, + fragmentOperands.size(), + operandIndices, + sourceOffsets, + flatOffsets, + flatSizes, + flatStrides))) + return failure(); - DenseMap packedFragmentOrdinals; for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { int64_t operandIndex = operandIndices[fragmentIndex]; - if (operandIndex < 0 || operandIndex >= static_cast(fragmentOperands.size())) - return reconciliator.emitOpError("fragment assembly operand index is out of range"); SmallVector fragmentOffsets; - int64_t fragmentElements = 1; for (int64_t dim = 0; dim < rank; ++dim) { int64_t flatIndex = fragmentIndex * rank + dim; if (flatStrides[flatIndex] != 1) return reconciliator.emitOpError("fragment assembly lowering only supports unit strides"); fragmentOffsets.push_back(flatOffsets[flatIndex]); - fragmentElements *= flatSizes[flatIndex]; } Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]); @@ -279,20 +432,21 @@ static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, if (!sourceType || !sourceType.hasStaticShape()) return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands"); - int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++; SmallVector fragmentShape; fragmentShape.reserve(rank); for (int64_t dim = 0; dim < rank; ++dim) fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]); Value fragment = source; - if (llvm::to_vector(sourceType.getShape()) != fragmentShape) { - SmallVector extractOffsets(rank, 0); - extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0]; - fragment = tensor::ExtractSliceOp::create(rewriter, - reconciliator.getLoc(), + if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) { + FailureOr> extractOffsets = getStaticSliceOffsetsForElementOffset( + reconciliator, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice"); + if (failed(extractOffsets)) + return failure(); + fragment = tensor::ExtractSliceOp::create(rewriter, + reconciliator.getLoc(), source, - getStaticIndexAttrs(rewriter, extractOffsets), + getStaticIndexAttrs(rewriter, *extractOffsets), getStaticIndexAttrs(rewriter, fragmentShape), getUnitStrides(rewriter, rank)); } @@ -351,16 +505,29 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds)); SmallVector returnOperandIndices; + SmallVector, 4> fragmentAssemblyPlansByResult; if (computeBatchOp.getNumResults() != 0) { returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits::max()); + fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults()); for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) { if (result.use_empty()) continue; FailureOr returnOperandIndex = getDirectReturnOperandIndex(cast(result)); - if (failed(returnOperandIndex)) + if (succeeded(returnOperandIndex)) { + returnOperandIndices[resultIndex] = *returnOperandIndex; + continue; + } + + auto resultType = dyn_cast(result.getType()); + if (!resultType || !resultType.hasStaticShape()) + return computeBatchOp.emitOpError( + "resultful compute_batch publication lowering requires static ranked tensor results"); + FailureOr> fragmentAssemblyPlans = + analyzeTopLevelFragmentAssemblyUses(cast(result), resultType, computeBatchOp.getLaneCount()); + if (failed(fragmentAssemblyPlans)) return computeBatchOp.emitOpError( "resultful compute_batch lowering currently requires each result to be used directly by func.return"); - returnOperandIndices[resultIndex] = *returnOperandIndex; + fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end()); } } @@ -446,9 +613,44 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); if (resultIndex >= returnOperandIndices.size()) return insertSlice.emitOpError("result index out of range while lowering host batch output"); - if (returnOperandIndices[resultIndex] == std::numeric_limits::max()) + bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits::max(); + bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size() + && !fragmentAssemblyPlansByResult[resultIndex].empty(); + if (!hasDirectReturn && !hasFragmentAssembly) continue; + Value mappedSource = mapper.lookup(insertSlice.getSource()); + + if (hasFragmentAssembly) { + BlockArgument laneArg = coreBatchOp.getLaneArgument(); + auto mappedSourceType = dyn_cast(mappedSource.getType()); + if (!mappedSourceType || !mappedSourceType.hasStaticShape()) + return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source"); + for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) { + Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc()); + auto sizeAttr = pim::getCheckedI32Attr( + rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size"); + if (failed(sizeAttr)) + return failure(); + Value hostTargetOffset = + createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc()); + Value deviceSourceOffset = getOrCreateIndexConstant( + rewriter, coreBatchOp.getOperation(), + plan.localSourceElementOffset * static_cast(getElementTypeSizeInBytes(mappedSourceType.getElementType()))); + outputTensor = + pim::PimMemCopyDevToHostOp::create(rewriter, + insertSlice.getLoc(), + outputTensor.getType(), + hostTargetOffset, + deviceSourceOffset, + outputTensor, + mappedSource, + *sizeAttr) + .getOutput(); + } + continue; + } + Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc()); auto hostTargetType = cast(hostTarget.getType()); if (auto reconciliator = @@ -467,7 +669,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul } } - Value mappedSource = mapper.lookup(insertSlice.getSource()); Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource); diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 72480c1..970063c 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -5,6 +5,7 @@ #include #include "Common.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" using namespace llvm; @@ -72,4 +73,117 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType()); } +LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator, + int64_t resultRank, + size_t operandCount, + ArrayRef operandIndices, + ArrayRef sourceOffsets, + ArrayRef flatOffsets, + ArrayRef flatSizes, + ArrayRef flatStrides) { + if (operandIndices.size() != sourceOffsets.size()) + return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match"); + if (flatOffsets.size() != flatSizes.size()) + return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths"); + if (flatStrides.size() != flatOffsets.size()) + return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths"); + if (flatOffsets.size() != operandIndices.size() * static_cast(resultRank)) + return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment"); + + for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) { + if (operandIndex < 0 || operandIndex >= static_cast(operandCount)) + return reconciliator.emitOpError("fragment assembly operand index is out of range"); + if (sourceOffsets[fragmentIndex] < 0) + return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative"); + } + + return success(); +} + +static SmallVector expandFlatElementIndex(int64_t flatIndex, ArrayRef shape) { + SmallVector indices(shape.size(), 0); + for (int64_t dim = static_cast(shape.size()) - 1; dim >= 0; --dim) { + indices[dim] = flatIndex % shape[dim]; + flatIndex /= shape[dim]; + } + return indices; +} + +FailureOr> +getStaticSliceOffsetsForElementOffset(Operation* anchor, + ShapedType sourceType, + ArrayRef fragmentShape, + int64_t sourceElementOffset, + StringRef fieldName) { + if (!sourceType.hasStaticShape()) + return (anchor->emitOpError() << fieldName << " requires a static source shape"), failure(); + if (sourceElementOffset < 0) + return (anchor->emitOpError() << fieldName << " requires a nonnegative source element offset"), failure(); + if (sourceType.getRank() != static_cast(fragmentShape.size())) + return (anchor->emitOpError() << fieldName << " requires fragment rank to match source rank"), failure(); + + int64_t sourceElementCount = sourceType.getNumElements(); + int64_t fragmentElementCount = 1; + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + if (fragmentShape[dim] < 0) + return (anchor->emitOpError() << fieldName << " requires nonnegative fragment sizes"), failure(); + fragmentElementCount *= fragmentShape[dim]; + } + if (sourceElementOffset + fragmentElementCount > sourceElementCount) + return (anchor->emitOpError() << fieldName << " exceeds the source tensor bounds"), failure(); + + SmallVector sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape()); + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + if (sliceOffsets[dim] + fragmentShape[dim] > sourceType.getDimSize(dim)) + return (anchor->emitOpError() << fieldName << " does not describe a valid unit-stride slice"), failure(); + } + return sliceOffsets; +} + +LogicalResult +forEachContiguousDestinationChunk(ArrayRef destShape, + ArrayRef baseOffsets, + ArrayRef sizes, + llvm::function_ref, int64_t, int64_t)> callback) { + int64_t rank = static_cast(sizes.size()); + int64_t suffixStart = rank - 1; + while (suffixStart > 0 && sizes[suffixStart] == destShape[suffixStart]) + --suffixStart; + if (sizes[suffixStart] == destShape[suffixStart] && suffixStart == 0) + suffixStart = 0; + else + ++suffixStart; + + int64_t chunkElements = 1; + for (int64_t dim = suffixStart; dim < rank; ++dim) + chunkElements *= sizes[dim]; + + SmallVector prefixExtents(sizes.begin(), sizes.begin() + suffixStart); + SmallVector current(prefixExtents.size(), 0); + int64_t sourceChunkOrdinal = 0; + + auto visit = [&](auto&& visit, int64_t dim) -> LogicalResult { + if (dim == static_cast(prefixExtents.size())) { + SmallVector chunkOffsets(baseOffsets.begin(), baseOffsets.end()); + for (int64_t prefixDim = 0; prefixDim < static_cast(current.size()); ++prefixDim) + chunkOffsets[prefixDim] += current[prefixDim]; + if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements))) + return failure(); + ++sourceChunkOrdinal; + return success(); + } + + for (int64_t index = 0; index < prefixExtents[dim]; ++index) { + current[dim] = index; + if (failed(visit(visit, dim + 1))) + return failure(); + } + return success(); + }; + + if (prefixExtents.empty()) + return callback(baseOffsets, 0, chunkElements); + return visit(visit, 0); +} + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index 49fe3ec..aa1c7cb 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -1,10 +1,17 @@ #pragma once +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" + #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Support/LogicalResult.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +namespace onnx_mlir::spatial { +class SpatReconciliatorOp; +} + namespace onnx_mlir { mlir::FailureOr @@ -29,6 +36,29 @@ mlir::SmallVector getOpOperandsSortedByUses(mlir::Operation* operat mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation); +mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatReconciliatorOp reconciliator, + int64_t resultRank, + size_t operandCount, + llvm::ArrayRef operandIndices, + llvm::ArrayRef sourceOffsets, + llvm::ArrayRef flatOffsets, + llvm::ArrayRef flatSizes, + llvm::ArrayRef flatStrides); + +mlir::FailureOr> +getStaticSliceOffsetsForElementOffset(mlir::Operation* anchor, + mlir::ShapedType sourceType, + llvm::ArrayRef fragmentShape, + int64_t sourceElementOffset, + llvm::StringRef fieldName); + +mlir::LogicalResult +forEachContiguousDestinationChunk(llvm::ArrayRef destShape, + llvm::ArrayRef baseOffsets, + llvm::ArrayRef sizes, + llvm::function_ref, int64_t, int64_t)> + callback); + inline mlir::tensor::EmptyOp createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) { return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType()); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 366a259..ca8c844 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -52,11 +52,14 @@ static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, std::optional modeAttr = reconciliator.getMode(); std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); std::optional> fragmentStridesAttr = reconciliator.getFragmentStrides(); - if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr) + if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr + || !fragmentStridesAttr) return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata"); ArrayRef operandIndices = *operandIndicesAttr; + ArrayRef sourceOffsets = *sourceOffsetsAttr; ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); ArrayRef flatSizes = reconciliator.getFragmentSizes(); ArrayRef flatStrides = *fragmentStridesAttr; @@ -64,13 +67,19 @@ static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, SmallVector fragmentOperands {reconciliator.getInput()}; llvm::append_range(fragmentOperands, reconciliator.getFragments()); + if (failed(validateFragmentAssemblyMetadata(reconciliator, + rank, + fragmentOperands.size(), + operandIndices, + sourceOffsets, + flatOffsets, + flatSizes, + flatStrides))) + return failure(); Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType); - DenseMap packedFragmentOrdinals; for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { int64_t operandIndex = operandIndices[fragmentIndex]; - if (operandIndex < 0 || operandIndex >= static_cast(fragmentOperands.size())) - return reconciliator.emitOpError("fragment assembly operand index is out of range"); SmallVector fragmentOffsets; int64_t fragmentElements = 1; @@ -96,11 +105,16 @@ static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, if (failed(sizeAttr)) return failure(); - int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++; Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets); + auto deviceSourceOffsetBytes = pim::checkedMul(static_cast(sourceOffsets[fragmentIndex]), + static_cast(getElementTypeSizeInBytes(sourceType.getElementType())), + reconciliator, + "fragment assembly device source offset"); + if (failed(deviceSourceOffsetBytes)) + return failure(); Value deviceSourceOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), - packedFragmentOrdinal * fragmentBytes); + static_cast(*deviceSourceOffsetBytes)); currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter, reconciliator.getLoc(), currentOutput.getType(), diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index d5a57a1..b675ff6 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -47,11 +47,13 @@ struct LowerFragmentAssemblyReconciliatorPattern return op.emitOpError("fragment assembly lowering requires a static ranked tensor result"); std::optional> operandIndicesAttr = op.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = op.getFragmentSourceOffsets(); std::optional> fragmentStridesAttr = op.getFragmentStrides(); - if (!operandIndicesAttr || !fragmentStridesAttr) + if (!operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr) return op.emitOpError("fragment assembly lowering requires explicit fragment metadata"); ArrayRef operandIndices = *operandIndicesAttr; + ArrayRef sourceOffsets = *sourceOffsetsAttr; ArrayRef flatOffsets = op.getFragmentOffsets(); ArrayRef flatSizes = op.getFragmentSizes(); ArrayRef flatStrides = *fragmentStridesAttr; @@ -59,23 +61,21 @@ struct LowerFragmentAssemblyReconciliatorPattern SmallVector fragmentOperands {adaptor.getInput()}; llvm::append_range(fragmentOperands, adaptor.getFragments()); + if (failed(validateFragmentAssemblyMetadata( + op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides))) + return failure(); Value currentOutput = tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult(); - DenseMap packedFragmentOrdinals; for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { int64_t operandIndex = operandIndices[fragmentIndex]; - if (operandIndex < 0 || operandIndex >= static_cast(fragmentOperands.size())) - return op.emitOpError("fragment assembly operand index is out of range"); SmallVector fragmentOffsets; - int64_t fragmentElements = 1; for (int64_t dim = 0; dim < rank; ++dim) { int64_t flatIndex = fragmentIndex * rank + dim; if (flatStrides[flatIndex] != 1) return op.emitOpError("fragment assembly lowering only supports unit strides"); fragmentOffsets.push_back(flatOffsets[flatIndex]); - fragmentElements *= flatSizes[flatIndex]; } Value source = fragmentOperands[operandIndex]; @@ -83,20 +83,21 @@ struct LowerFragmentAssemblyReconciliatorPattern if (!sourceType || !sourceType.hasStaticShape()) return op.emitOpError("fragment assembly lowering requires static ranked tensor operands"); - int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++; SmallVector fragmentShape; fragmentShape.reserve(rank); for (int64_t dim = 0; dim < rank; ++dim) fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]); Value fragment = source; - if (llvm::to_vector(sourceType.getShape()) != fragmentShape) { - SmallVector extractOffsets(rank, 0); - extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0]; + if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) { + FailureOr> extractOffsets = getStaticSliceOffsetsForElementOffset( + op, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice"); + if (failed(extractOffsets)) + return failure(); fragment = tensor::ExtractSliceOp::create(rewriter, op.getLoc(), source, - getStaticIndexAttrs(rewriter, extractOffsets), + getStaticIndexAttrs(rewriter, *extractOffsets), getStaticIndexAttrs(rewriter, fragmentShape), getUnitStrides(rewriter, rank)); } diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index ef23565..3a83b08 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -149,6 +149,40 @@ static std::optional analyzeReturnUse(Value value) { }; } +static FailureOr, 4>> +analyzeTopLevelFragmentAssemblyUses(Value value) { + SmallVector, 4> uses; + for (OpOperand& use : value.getUses()) { + auto reconciliator = dyn_cast(use.getOwner()); + if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType()) + return failure(); + std::optional mode = reconciliator.getMode(); + if (!mode || *mode != "fragment_assembly") + return failure(); + if (!reconciliator.getOutput().hasOneUse() || !isa(*reconciliator.getOutput().getUsers().begin())) + return failure(); + std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); + std::optional> stridesAttr = reconciliator.getFragmentStrides(); + auto resultType = dyn_cast(reconciliator.getOutput().getType()); + if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape()) + return failure(); + SmallVector fragmentOperands {reconciliator.getInput()}; + llvm::append_range(fragmentOperands, reconciliator.getFragments()); + if (failed(validateFragmentAssemblyMetadata(reconciliator, + resultType.getRank(), + fragmentOperands.size(), + *operandIndicesAttr, + *sourceOffsetsAttr, + reconciliator.getFragmentOffsets(), + reconciliator.getFragmentSizes(), + *stridesAttr))) + return failure(); + uses.emplace_back(reconciliator, use.getOperandNumber()); + } + return uses; +} + static std::optional analyzeConcatReturnUse(Value value) { auto getConcatResult = [](Operation* op) -> Value { if (auto tensorConcat = dyn_cast(op)) @@ -559,6 +593,116 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low } } + FailureOr, 4>> fragmentAssemblyUses = + analyzeTopLevelFragmentAssemblyUses(producedValue); + if (succeeded(fragmentAssemblyUses)) { + auto sourceType = dyn_cast(storedValue.getType()); + if (!sourceType || !sourceType.hasStaticShape()) { + producerOp->emitOpError("fragment assembly publication requires a static ranked tensor source"); + return ReturnPathLoweringResult::Failure; + } + + size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType()); + for (auto [reconciliator, operandNumber] : *fragmentAssemblyUses) { + rewriter.setInsertionPointAfterValue(storedValue); + std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); + std::optional> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets(); + std::optional> stridesAttr = reconciliator.getFragmentStrides(); + if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) { + reconciliator.emitOpError( + "fragment assembly lowering requires explicit operand, source-offset, and stride metadata"); + return ReturnPathLoweringResult::Failure; + } + + size_t returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber(); + Value outputTensor = outputTensors[returnIndex](rewriter, loc); + auto outputType = dyn_cast(outputTensor.getType()); + auto resultType = dyn_cast(reconciliator.getOutput().getType()); + if (!outputType || !resultType || !resultType.hasStaticShape()) { + reconciliator.emitOpError("fragment assembly lowering requires static ranked host outputs"); + return ReturnPathLoweringResult::Failure; + } + + ArrayRef operandIndices = *operandIndicesAttr; + ArrayRef sourceOffsets = *sourceOffsetsAttr; + ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); + ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatStrides = *stridesAttr; + int64_t rank = resultType.getRank(); + if (failed(validateFragmentAssemblyMetadata(reconciliator, + rank, + 1 + reconciliator.getFragments().size(), + operandIndices, + sourceOffsets, + flatOffsets, + flatSizes, + flatStrides))) + return ReturnPathLoweringResult::Failure; + for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { + if (operandIndices[fragmentIndex] != static_cast(operandNumber)) + continue; + + SmallVector fragmentOffsets; + SmallVector fragmentSizes; + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t flatIndex = fragmentIndex * rank + dim; + if (flatStrides[flatIndex] != 1) { + reconciliator.emitOpError("fragment assembly lowering only supports unit strides"); + return ReturnPathLoweringResult::Failure; + } + fragmentOffsets.push_back(flatOffsets[flatIndex]); + fragmentSizes.push_back(flatSizes[flatIndex]); + } + + bool failedChunk = false; + if (failed(forEachContiguousDestinationChunk( + outputType.getShape(), + fragmentOffsets, + fragmentSizes, + [&](ArrayRef chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult { + auto hostOffset = + getCheckedByteOffset(computeFlatElementIndex(chunkOffsets, outputType.getShape()), + elementSize, + producerOp, + "fragment assembly host offset"); + auto sourceOffset = getCheckedByteOffset(sourceOffsets[fragmentIndex] + relativeSourceOffset, + elementSize, + producerOp, + "fragment assembly source offset"); + auto fragmentBytes = + getCheckedByteOffset(chunkElements, elementSize, producerOp, "fragment assembly host copy byte size"); + if (failed(hostOffset) || failed(sourceOffset) || failed(fragmentBytes)) { + failedChunk = true; + return failure(); + } + auto sizeAttr = + pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size"); + if (failed(sizeAttr)) { + failedChunk = true; + return failure(); + } + + outputTensor = + pim::PimMemCopyDevToHostOp::create(rewriter, + reconciliator.getLoc(), + outputTensor.getType(), + getOrCreateIndexConstant(rewriter, producerOp, *hostOffset), + getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset), + outputTensor, + storedValue, + *sizeAttr) + .getOutput(); + return success(); + }))) + failedChunk = true; + if (failedChunk) + return ReturnPathLoweringResult::Failure; + } + markOpToRemove(reconciliator.getOperation()); + } + return ReturnPathLoweringResult::Handled; + } + if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType()); auto storedByteSize = @@ -669,6 +813,16 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret return; } + if (auto reconciliator = dyn_cast(op)) { + std::optional mode = reconciliator.getMode(); + if (mode && *mode == "fragment_assembly") { + markOpToRemove(reconciliator.getOperation()); + for (Value operand : reconciliator->getOperands()) + markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + return; + } + } + if (auto computeOp = dyn_cast(op)) { markOpToRemove(computeOp); if (!computeOp.getInputs().empty()) diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 1b7e60d..2bd92c0 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -245,6 +245,7 @@ def SpatReconciliatorOp : SpatOp<"reconciliator", []> { StrAttr:$indexMap, OptionalAttr:$mode, OptionalAttr:$fragmentOperandIndices, + OptionalAttr:$fragmentSourceOffsets, OptionalAttr:$fragmentStrides, OptionalAttr:$conflictPolicy, OptionalAttr:$coveragePolicy diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 7ea4dde..d4cbbfe 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -483,21 +483,28 @@ LogicalResult SpatReconciliatorOp::verify() { return failure(); if (!getFragments().empty()) return emitError("legacy reconciliator does not accept extra fragment operands"); - if (getFragmentStridesAttr() || getConflictPolicyAttr() || getCoveragePolicyAttr()) + if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr() + || getCoveragePolicyAttr()) return emitError("legacy reconciliator does not accept fragment assembly attributes"); return success(); } auto stridesAttr = getFragmentStridesAttr(); auto operandIndicesAttr = getFragmentOperandIndicesAttr(); + auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr(); if (!operandIndicesAttr) return emitError("fragment assembly reconciliator requires fragment operand indices"); + if (!sourceOffsetsAttr) + return emitError("fragment assembly reconciliator requires fragment source offsets"); if (!stridesAttr) return emitError("fragment assembly reconciliator requires fragment strides"); ArrayRef operandIndices = operandIndicesAttr.asArrayRef(); + ArrayRef sourceOffsets = sourceOffsetsAttr.asArrayRef(); ArrayRef strides = stridesAttr.asArrayRef(); if (strides.size() != offsets.size()) return emitError("fragment stride and offset arrays must have the same length"); + if (sourceOffsets.size() != operandIndices.size()) + return emitError("fragment source offset count must match fragment operand index count"); if (!getConflictPolicyAttr() || !getCoveragePolicyAttr()) return emitError("fragment assembly reconciliator requires conflict and coverage policies"); if (getConflictPolicy() != "disjoint") @@ -519,11 +526,21 @@ LogicalResult SpatReconciliatorOp::verify() { SmallVector, SmallVector>, 8> slices; slices.reserve(static_cast(fragmentCount)); - SmallVector, 4>, 8> sizesByOperand(static_cast(operandCount)); + SmallVector fragmentCountsByOperand(static_cast(operandCount), 0); + auto expandFlatElementIndex = [](int64_t flatIndex, ArrayRef shape) { + SmallVector indices(shape.size(), 0); + for (int64_t dim = static_cast(shape.size()) - 1; dim >= 0; --dim) { + indices[dim] = flatIndex % shape[dim]; + flatIndex /= shape[dim]; + } + return indices; + }; for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) { int64_t operandIndex = operandIndices[fragmentIndex]; if (operandIndex < 0 || operandIndex >= operandCount) return emitError("fragment assembly operand index is out of range"); + if (sourceOffsets[fragmentIndex] < 0) + return emitError("fragment assembly source offsets must be nonnegative"); auto operandType = dyn_cast(operands[operandIndex].getType()); if (!operandType || !operandType.hasStaticShape()) @@ -541,7 +558,17 @@ LogicalResult SpatReconciliatorOp::verify() { fragmentSizes.push_back(sizes[flatIndex]); } - sizesByOperand[static_cast(operandIndex)].push_back(fragmentSizes); + ++fragmentCountsByOperand[static_cast(operandIndex)]; + int64_t fragmentElements = 1; + for (int64_t dim = 0; dim < rank; ++dim) + fragmentElements *= fragmentSizes[dim]; + if (sourceOffsets[fragmentIndex] + fragmentElements > operandType.getNumElements()) + return emitError("fragment assembly source offset exceeds the operand bounds"); + SmallVector sourceSliceOffsets = + expandFlatElementIndex(sourceOffsets[fragmentIndex], operandType.getShape()); + for (int64_t dim = 0; dim < rank; ++dim) + if (sourceSliceOffsets[dim] + fragmentSizes[dim] > operandType.getDimSize(dim)) + return emitError("fragment assembly source offset must describe a valid unit-stride slice"); for (const auto& [existingOffsets, existingSizes] : slices) { bool overlaps = true; @@ -562,28 +589,8 @@ LogicalResult SpatReconciliatorOp::verify() { } for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) { - if (sizesByOperand[static_cast(operandIndex)].empty()) + if (fragmentCountsByOperand[static_cast(operandIndex)] == 0) return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment"); - - auto operandType = cast(operands[operandIndex].getType()); - ArrayRef operandShape = operandType.getShape(); - auto& fragmentShapes = sizesByOperand[static_cast(operandIndex)]; - if (fragmentShapes.size() == 1) { - if (!llvm::equal(operandShape, fragmentShapes.front())) - return emitError("single-fragment reconciliator operand shape must match declared fragment size"); - continue; - } - - ArrayRef fragmentShape = fragmentShapes.front(); - for (ArrayRef otherShape : fragmentShapes) - if (!llvm::equal(fragmentShape, otherShape)) - return emitError("packed reconciliator operand requires equal fragment sizes per operand"); - if (llvm::equal(operandShape, fragmentShape)) - continue; - if (!llvm::equal(operandShape.drop_front(), fragmentShape.drop_front())) - return emitError("packed reconciliator operand must match fragment shape on non-packed dimensions"); - if (operandShape.front() != static_cast(fragmentShapes.size()) * fragmentShape.front()) - return emitError("packed reconciliator operand first dimension must equal fragment_count * fragment_size"); } if (getCoveragePolicy() == "complete") { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index dffb185..45b57b8 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -194,6 +194,7 @@ struct MaterializedClass { SmallVector weights; SmallVector inputs; SmallVector hostOutputs; + DenseMap publicationOutputToResultIndex; DenseMap weightArgs; DenseMap inputArgs; DenseMap hostOutputToResultIndex; @@ -307,11 +308,9 @@ struct PendingProjectedHostOutputFragment { Value originalOutput; ClassId sourceClass = 0; ProducerKey producerKey; - Value operand; - RankedTensorType operandType; - RankedTensorType fragmentType; - int64_t packedFragmentIndex = -1; - int64_t currentLane = -1; + Value publicationValue; + int64_t sourceFragmentOrdinal = 0; + int64_t sourceElementOffset = 0; SmallVector offsets; SmallVector sizes; SmallVector strides; @@ -379,6 +378,9 @@ LogicalResult localizeCapturesInClonedOp(MaterializerState& state, Operation& clonedOp, IRMapping* mapper = nullptr); LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass); +void createDim0ParallelInsertSlice( + MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset); +Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc); bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, const AffineProjectedInputSliceMatch& match, ProducerKey producer, @@ -627,6 +629,31 @@ ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, Co return logicalInstance; } +FailureOr +getPublicationLaneForProducerKey(MaterializerState& state, const MaterializedClass& sourceClass, ProducerKey key) { + if (!sourceClass.isBatch) + return 0; + + ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, key.instance); + auto cpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); + if (cpuIt == state.schedule.computeToCpuMap.end()) { + sourceClass.op->emitError("projected packed host publication could not resolve the producer CPU for a publication lane") + << " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount + << " resultIndex=" << key.resultIndex; + return failure(); + } + + auto laneIt = sourceClass.cpuToLane.find(cpuIt->second); + if (laneIt == sourceClass.cpuToLane.end()) { + sourceClass.op->emitError("projected packed host publication could not map a producer key to a publication lane") + << " cpu=" << cpuIt->second << " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount + << " resultIndex=" << key.resultIndex; + return failure(); + } + + return laneIt->second; +} + SmallVector collectProducerKeysForDestinations(Value value, std::optional logicalConsumer = std::nullopt) { // Destination collection works in the materializer's logical one-lane key domain. @@ -1043,6 +1070,7 @@ LogicalResult collectHostOutputs(MaterializerState& state) { for (MaterializedClass& materializedClass : state.classes) { materializedClass.hostOutputs.clear(); materializedClass.hostOutputToResultIndex.clear(); + materializedClass.publicationOutputToResultIndex.clear(); } state.hostOutputOwners.clear(); @@ -1150,48 +1178,6 @@ void setInsertionPointForNewMaterializedOp(MaterializerState& state) { state.rewriter.setInsertionPointToEnd(&funcBlock); } -FailureOr createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) { - DenseSet usedCpus; - for (const auto& [cpu, _] : state.cpuToClass) - usedCpus.insert(cpu); - - CpuId assemblyCpu = 0; - while (usedCpus.contains(assemblyCpu)) - ++assemblyCpu; - - setInsertionPointForNewMaterializedOp(state); - - auto resultType = dyn_cast(originalOutput.getType()); - if (!resultType || !resultType.hasStaticShape()) - return state.func.emitError("projected host assembly class requires a static ranked tensor output"); - - auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {}); - compute.getProperties().setOperandSegmentSizes({0, 0}); - auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id"); - if (failed(coreIdAttr)) - return failure(); - compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr); - - Block* body = state.rewriter.createBlock(&compute.getBody()); - state.rewriter.setInsertionPointToEnd(body); - Value placeholder = - tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult(); - SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder}); - state.rewriter.setInsertionPointAfter(compute.getOperation()); - - MaterializedClass materializedClass; - materializedClass.id = state.classes.size(); - materializedClass.cpus.push_back(assemblyCpu); - materializedClass.op = compute.getOperation(); - materializedClass.body = body; - materializedClass.hostOutputToResultIndex[originalOutput] = 0; - materializedClass.hostOutputs.push_back(originalOutput); - state.cpuToClass[assemblyCpu] = materializedClass.id; - state.hostOutputOwners[originalOutput] = materializedClass.id; - state.classes.push_back(std::move(materializedClass)); - return state.classes.back().id; -} - BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { auto it = materializedClass.weightArgs.find(weight); if (it != materializedClass.weightArgs.end()) @@ -1226,13 +1212,125 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali materializedClass.inputArgs[input] = std::get<1>(*arg); return std::get<1>(*arg); } - if (auto compute = dyn_cast(materializedClass.op)) { - auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); - assert(arg && "expected compute_batch body while inserting an input argument"); - materializedClass.inputArgs[input] = std::get<1>(*arg); - return std::get<1>(*arg); + + auto compute = cast(materializedClass.op); + auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); + assert(arg && "expected compute_batch body while inserting an input argument"); + materializedClass.inputArgs[input] = std::get<1>(*arg); + return std::get<1>(*arg); +} + +void refreshPendingProjectedHostOutputPublicationValues(MaterializerState& state, + Operation* oldOwner, + Operation* newOwner) { + if (!oldOwner || oldOwner == newOwner) + return; + + for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) { + auto publicationResult = dyn_cast_or_null(fragment.publicationValue); + if (!publicationResult || publicationResult.getOwner() != oldOwner) + publicationResult = OpResult(); + else + fragment.publicationValue = newOwner->getResult(publicationResult.getResultNumber()); + + if (auto originalResult = dyn_cast_or_null(fragment.originalOutput); originalResult + && originalResult.getOwner() == oldOwner) { + fragment.originalOutput = newOwner->getResult(originalResult.getResultNumber()); + } + + if (fragment.producerKey.instance.op == oldOwner) + fragment.producerKey.instance.op = newOwner; } - llvm_unreachable("Cannot reach here"); +} + +FailureOr appendScalarPublicationResult(MaterializerState& state, + MaterializedClass& materializedClass, + Value payload, + Location loc) { + auto existing = materializedClass.publicationOutputToResultIndex.find(payload); + if (existing != materializedClass.publicationOutputToResultIndex.end()) + return materializedClass.op->getResult(existing->second); + + auto compute = dyn_cast(materializedClass.op); + if (!compute) + return materializedClass.op->emitError("scalar publication result requires spat.scheduled_compute owner"); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape()) + return materializedClass.op->emitError("scalar publication result requires static ranked tensor payload"); + + FailureOr> inserted = + compute.insertOutput(state.rewriter, compute.getNumResults(), payloadType, loc); + if (failed(inserted)) + return materializedClass.op->emitError("failed to append scalar publication result"); + + Operation* oldOp = materializedClass.op; + auto [result, newCompute] = *inserted; + materializedClass.op = newCompute.getOperation(); + materializedClass.body = &newCompute.getBody().front(); + refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op); + materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber(); + + auto yieldOp = dyn_cast(materializedClass.body->getTerminator()); + if (!yieldOp) + return materializedClass.op->emitError("expected spat.yield terminator while appending scalar publication result"); + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->insertOperands(yieldOp.getNumOperands(), payload); }); + return result; +} + +FailureOr appendBatchPublicationResult(MaterializerState& state, + MaterializedClass& materializedClass, + Value payload, + Location loc) { + auto existing = materializedClass.publicationOutputToResultIndex.find(payload); + if (existing != materializedClass.publicationOutputToResultIndex.end()) + return materializedClass.op->getResult(existing->second); + + auto batch = dyn_cast(materializedClass.op); + if (!batch) + return materializedClass.op->emitError("batch publication result requires spat.scheduled_compute_batch owner"); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) + return materializedClass.op->emitError( + "batch publication result requires a static ranked tensor payload with rank > 0"); + + SmallVector publishedShape(payloadType.getShape()); + publishedShape[0] *= static_cast(materializedClass.cpus.size()); + auto publishedType = + RankedTensorType::get(publishedShape, payloadType.getElementType(), payloadType.getEncoding()); + + FailureOr> inserted = + batch.insertOutput(state.rewriter, batch.getNumResults(), publishedType, loc); + if (failed(inserted)) + return materializedClass.op->emitError("failed to append batch publication result"); + + Operation* oldOp = materializedClass.op; + auto [result, outputArg, newBatch] = *inserted; + materializedClass.op = newBatch.getOperation(); + materializedClass.body = &newBatch.getBody().front(); + refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op); + materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber(); + + auto inParallelOp = dyn_cast(materializedClass.body->getTerminator()); + auto laneArg = newBatch.getLaneArgument(); + if (!laneArg) + return materializedClass.op->emitError("batch publication result requires a lane argument"); + if (!inParallelOp) { + auto yieldOp = dyn_cast(materializedClass.body->getTerminator()); + if (!yieldOp || yieldOp.getNumOperands() != 0) + return materializedClass.op->emitError( + "batch publication result requires either spat.in_parallel or an empty spat.yield terminator"); + state.rewriter.setInsertionPoint(yieldOp); + inParallelOp = SpatInParallelOp::create(state.rewriter, loc); + state.rewriter.eraseOp(yieldOp); + } + + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + Value firstOffset = + scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc); + createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset); + return result; } // ----------------------------------------------------------------------------- @@ -5520,6 +5618,12 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat return failure(); } + FailureOr publicationResult = appendScalarPublicationResult(state, sourceClass, packed, loc); + if (failed(publicationResult)) + return failure(); + + int64_t fragmentElementCount = fragmentType.getNumElements(); + for (auto [runIndex, slot] : llvm::enumerate(run)) { if (slot.peers.size() != 1) { sourceClass.op->emitError("projected scalar host output publication expects scalar one-peer run slots"); @@ -5553,11 +5657,9 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat originalOutput, sourceClass.id, ProducerKey {peer, resultIndex}, - packed, - cast(packed.getType()), - fragmentType, - static_cast(runIndex), + *publicationResult, static_cast(runIndex), + static_cast(runIndex) * fragmentElementCount, SmallVector(*offsets), SmallVector(*sizes), SmallVector(*strides), @@ -5609,7 +5711,10 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt if (fragmentType == originalOutput.getType()) return false; - bool operandIsDim0Packed = false; + FailureOr publicationResult = appendBatchPublicationResult(state, sourceClass, packed, loc); + if (failed(publicationResult)) + return failure(); + if (packedType != fragmentType) { if (packedType.getRank() == 0 || packedType.getDimSize(0) % static_cast(keys.size()) != 0) return sourceClass.op->emitError( @@ -5622,9 +5727,16 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt return sourceClass.op->emitError( "projected packed host publication fragment shape does not match projected slice size") << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); - operandIsDim0Packed = true; } + int64_t payloadElementCount = packedType.getNumElements(); + int64_t fragmentElementCount = fragmentType.getNumElements(); + int64_t fragmentsPerPublishedPayload = payloadElementCount / fragmentElementCount; + if (fragmentsPerPublishedPayload <= 0 || static_cast(keys.size()) % fragmentsPerPublishedPayload != 0) + return sourceClass.op->emitOpError( + "projected packed host publication requires a deterministic publication packing layout") + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); + for (auto [fragmentIndex, key] : llvm::enumerate(keys)) { if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != keys.front().resultIndex || key.instance.laneCount != 1) return sourceClass.op->emitError("projected packed host publication requires one-lane keys from one producer result"); @@ -5642,15 +5754,19 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt return sourceClass.op->emitError( "projected packed host publication requires one operand to map to a consistent fragment shape"); + FailureOr publishedLaneIndex = getPublicationLaneForProducerKey(state, sourceClass, key); + if (failed(publishedLaneIndex)) + return failure(); + int64_t localFragmentOffsetWithinPublishedPayload = + (static_cast(fragmentIndex) % fragmentsPerPublishedPayload) * fragmentElementCount; + state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { originalOutput, sourceClass.id, key, - packed, - packedType, - fragmentType, - operandIsDim0Packed ? static_cast(fragmentIndex) : -1, + *publicationResult, static_cast(fragmentIndex), + static_cast(*publishedLaneIndex) * payloadElementCount + localFragmentOffsetWithinPublishedPayload, SmallVector(*offsets), SmallVector(*sizes), SmallVector(*strides), @@ -5678,24 +5794,23 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { < reinterpret_cast(rhs.getAsOpaquePointer()); }); + auto returnOp = dyn_cast(state.func.getBody().front().getTerminator()); + if (!returnOp) + return state.func.emitError("expected func.return terminator while finalizing projected host output fragments"); + for (Value originalOutput : outputs) { - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) { - Operation* anchor = originalOutput.getDefiningOp() ? originalOutput.getDefiningOp() : state.func.getOperation(); - return anchor->emitError("missing host owner for projected host output fragments"); - } - - MaterializedClass* ownerClass = &state.classes[ownerIt->second]; - auto resultType = dyn_cast(originalOutput.getType()); if (!resultType || !resultType.hasStaticShape()) - return ownerClass->op->emitError("projected host output must have static ranked tensor type"); + return state.func.emitError("projected host output must have static ranked tensor type"); SmallVector& fragments = byOutput[originalOutput]; llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, const PendingProjectedHostOutputFragment* rhs) { - if (lhs->sourceLane != rhs->sourceLane) - return lhs->sourceLane < rhs->sourceLane; + if (lhs->publicationValue != rhs->publicationValue) + return reinterpret_cast(lhs->publicationValue.getAsOpaquePointer()) + < reinterpret_cast(rhs->publicationValue.getAsOpaquePointer()); + if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal) + return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal; if (lhs->sourceClass != rhs->sourceClass) return lhs->sourceClass < rhs->sourceClass; return std::lexicographical_compare(lhs->offsets.begin(), @@ -5704,240 +5819,36 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { rhs->offsets.end()); }); - bool allFromSameSourceClass = - llvm::all_of(fragments, [&](const PendingProjectedHostOutputFragment* fragment) { - return fragment->sourceClass == fragments.front()->sourceClass; - }); - if (allFromSameSourceClass) { - ownerClass = &state.classes[fragments.front()->sourceClass]; - state.hostOutputOwners[originalOutput] = ownerClass->id; - } else { - if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput)) - goto owner_selected; - FailureOr createdOwner = - createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc); - if (failed(createdOwner)) - return failure(); - ownerClass = &state.classes[*createdOwner]; - } -owner_selected: - - if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) { - auto sourceBatch = dyn_cast(fragments.front()->producerKey.instance.op); - auto batch = dyn_cast(ownerClass->op); - auto inParallelOp = dyn_cast_or_null(ownerClass->body->getTerminator()); - auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput); - if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end()) - return ownerClass->op->emitError("missing batch host assembly state for projected host output"); - FailureOr sourceProjection = - getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex); - std::optional sourceLaneArg = sourceBatch.getLaneArgument(); - if (failed(sourceProjection) || !sourceLaneArg) - return ownerClass->op->emitError( - "direct batch host output assembly requires the source batch projection metadata"); - - auto outputArg = batch.getOutputArgument(resultIt->second); - auto laneArg = batch.getLaneArgument(); - if (!outputArg || !laneArg) - return ownerClass->op->emitError("missing compute_batch output block argument for projected host output"); - - if (fragments.size() != ownerClass->cpus.size()) - return ownerClass->op->emitError( - "direct batch host output assembly expects exactly one fragment per materialized lane"); - - SmallVector fragmentsByLane(ownerClass->cpus.size(), nullptr); - for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { - int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane; - if (currentLane < 0 || currentLane >= static_cast(fragmentsByLane.size())) - return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds"); - if (fragmentsByLane[currentLane]) - return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane"); - fragmentsByLane[currentLane] = fragmentRecord; - } - - if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; })) - return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes"); - - FailureOr> firstSizes = - evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane); - FailureOr> firstStrides = - evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane); - if (failed(firstSizes) || failed(firstStrides)) - return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape"); - SmallVector referenceSizes(*firstSizes); - SmallVector referenceStrides(*firstStrides); - Value laneOperand; - for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) { - FailureOr> fragmentSizes = - evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane); - FailureOr> fragmentStrides = - evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane); - if (failed(fragmentSizes) || failed(fragmentStrides)) - return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape"); - if (SmallVector(*fragmentSizes) != referenceSizes - || SmallVector(*fragmentStrides) != referenceStrides) - return ownerClass->op->emitError( - "direct batch host output assembly expects a uniform fragment shape and strides"); - - MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; - Value operand; - if (std::optional availableValue = - state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) { - operand = *availableValue; - } else { - operand = fragmentRecord->operand; - } - if (!isValueLegalInMaterializedClassBody(operand, *ownerClass)) - return ownerClass->op->emitError( - "projected batch host output assembly requires source-local fragment operands"); - if (laneOperand && laneOperand != operand) - return ownerClass->op->emitError( - "direct batch host output assembly expects one shared lane-local fragment producer"); - laneOperand = operand; - } - - SmallVector mixedOffsets; - mixedOffsets.reserve(referenceSizes.size()); - for (size_t dim = 0; dim < referenceSizes.size(); ++dim) { - SmallVector offsetsByLane; - offsetsByLane.reserve(fragmentsByLane.size()); - for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) { - FailureOr> fragmentOffsets = - evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane); - if (failed(fragmentOffsets)) - return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets"); - offsetsByLane.push_back((*fragmentOffsets)[dim]); - } - mixedOffsets.push_back(allEqual(offsetsByLane) - ? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front())) - : OpFoldResult(createLaneIndexedIndexValue( - state, *ownerClass, ArrayRef(offsetsByLane), fragments.front()->loc))); - } - - state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second); - state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - tensor::ParallelInsertSliceOp::create(state.rewriter, - fragments.front()->loc, - laneOperand, - *outputArg, - mixedOffsets, - getStaticIndexAttrs(state.rewriter, referenceSizes), - getStaticIndexAttrs(state.rewriter, referenceStrides)); - continue; - } - - state.rewriter.setInsertionPoint(ownerClass->body->getTerminator()); + state.rewriter.setInsertionPoint(returnOp); Location loc = fragments.front()->loc; SmallVector reconciliatorOperands; SmallVector fragmentOperandIndices; + SmallVector fragmentSourceOffsets; SmallVector flatOffsets; SmallVector flatSizes; SmallVector flatStrides; DenseMap operandIndicesByValue; - DenseSet emittedBatchForwarding; for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { - MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; - Value operand; - - if (std::optional availableValue = - state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) { - operand = *availableValue; - } else if (fragmentRecord->sourceClass == sourceClass.id) { - operand = fragmentRecord->operand; - } else { - return sourceClass.op->emitError( - "projected host output fragment assembly is missing source-visible fragment operands before finalization"); - } - - if (fragmentRecord->sourceClass != ownerClass->id) { - if (sourceClass.isBatch && !ownerClass->isBatch) { - if (!emittedBatchForwarding.insert(sourceClass.id).second) { - std::optional localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id); - if (!localized) - return ownerClass->op->emitError( - "projected host output fragment assembly is missing forwarded batch fragments"); - operand = *localized; - } else { - SmallVector forwardedKeys; - forwardedKeys.reserve(sourceClass.cpus.size()); - Value forwardedPayload = fragmentRecord->operand; - for (PendingProjectedHostOutputFragment* candidate : fragments) { - if (candidate->sourceClass != sourceClass.id) - continue; - if (candidate->operand != forwardedPayload) - return ownerClass->op->emitError( - "projected host output batch forwarding expects one shared batch payload per source class"); - forwardedKeys.push_back(candidate->producerKey); - } - llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) { - return lhs.instance.laneStart < rhs.instance.laneStart; - }); - if (failed(emitClassToClassCommunication( - state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc))) - return failure(); - std::optional localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id); - if (!localized) - return ownerClass->op->emitError( - "projected host output fragment assembly failed to recover forwarded batch fragment"); - operand = *localized; - } - } else { - MessageVector messages; - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, - sourceClass.cpus.front(), - "projected host output source core id"); - auto checkedTargetCpu = getCheckedCoreId(ownerClass->op, - ownerClass->cpus.front(), - "projected host output target core id"); - if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc))) - return failure(); - operand = appendReceive(state, - *ownerClass, - cast(operand.getType()), - messages, - fragmentRecord->loc); - } - } else if (!ownerClass->isBatch) { - FailureOr localOperand = materializeTensorValueForMaterializedClassUse( - state, - *ownerClass, - operand, - ownerClass->op, - "projected host output assembly tried to reuse a non-local fragment tensor"); - if (failed(localOperand)) - return failure(); - operand = *localOperand; - } + Value operand = fragmentRecord->publicationValue; auto [operandIt, inserted] = operandIndicesByValue.try_emplace(operand, static_cast(reconciliatorOperands.size())); if (inserted) reconciliatorOperands.push_back(operand); fragmentOperandIndices.push_back(operandIt->second); + fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset); llvm::append_range(flatOffsets, fragmentRecord->offsets); llvm::append_range(flatSizes, fragmentRecord->sizes); llvm::append_range(flatStrides, fragmentRecord->strides); auto operandType = dyn_cast(operand.getType()); if (!operandType || !operandType.hasStaticShape()) - return ownerClass->op->emitError("projected host output assembly requires static ranked tensor operands"); - if (fragmentRecord->packedFragmentIndex >= 0) { - int64_t fragmentSize0 = fragmentRecord->fragmentType.getDimSize(0); - if (fragmentSize0 <= 0 || operandType.getRank() == 0) - return ownerClass->op->emitError("packed projected host output assembly requires ranked fragment operands"); - int64_t start = fragmentRecord->packedFragmentIndex * fragmentSize0; - int64_t end = start + fragmentSize0; - if (start < 0 || end > operandType.getDimSize(0)) - return ownerClass->op->emitError("packed projected host output fragment index is out of bounds"); - } + return state.func.emitError("projected host output assembly requires static ranked tensor operands"); } if (reconciliatorOperands.empty()) - return ownerClass->op->emitError("missing projected host output fragments"); + return state.func.emitError("missing projected host output fragments"); Value input = reconciliatorOperands.front(); ValueRange extraFragments = ValueRange(reconciliatorOperands).drop_front(); @@ -5954,12 +5865,12 @@ owner_selected: state.rewriter.getStringAttr("identity"), state.rewriter.getStringAttr("fragment_assembly"), state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices), + state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets), state.rewriter.getDenseI64ArrayAttr(flatStrides), state.rewriter.getStringAttr("disjoint"), state.rewriter.getStringAttr("complete")); - if (failed(setHostOutputValue(state, *ownerClass, originalOutput, reconciliator.getOutput()))) - return failure(); + state.hostReplacements[originalOutput] = reconciliator.getOutput(); } return success();