diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 3996742..0c03134 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" @@ -96,11 +97,35 @@ std::optional getProjectedWholeBatchReplacementProducer(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex); std::optional getProjectedWholeBatchReplacementProducer(MaterializerState& state, tensor::ExtractSliceOp extract); +LogicalResult recordLocalPackedSlotValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Value packed); FailureOr materializeProjectedWholeBatchExtractReplacement(MaterializerState& state, MaterializedClass& targetClass, tensor::ExtractSliceOp extract, ProducerKey producer, IRMapping* mapper = nullptr); +enum class MaterializedLayoutKind { + DenseLogical, + RowPackedNCHWFromRows, + Unknown +}; + +struct MaterializedTensorView { + Value value; + RankedTensorType logicalType; + RankedTensorType physicalType; + MaterializedLayoutKind layout = MaterializedLayoutKind::Unknown; +}; + +FailureOr materializeLogicalTensorView(MaterializerState& state, + MaterializedClass& targetClass, + MaterializedTensorView view, + Operation* anchor, + StringRef context, + std::optional producer = std::nullopt, + IRMapping* mapper = nullptr); bool isConstantLike(Value value) { Operation* definingOp = value.getDefiningOp(); return definingOp && definingOp->hasTrait(); @@ -1484,6 +1509,66 @@ FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState return appendInput(state, targetClass, value); } +MaterializedLayoutKind classifyMaterializedTensorLayout(RankedTensorType logicalType, RankedTensorType physicalType) { + if (logicalType == physicalType) + return MaterializedLayoutKind::DenseLogical; + + if (!logicalType || !physicalType || !logicalType.hasStaticShape() || !physicalType.hasStaticShape()) + return MaterializedLayoutKind::Unknown; + + if (logicalType.getRank() != 4 || physicalType.getRank() != 4) + return MaterializedLayoutKind::Unknown; + if (logicalType.getElementType() != physicalType.getElementType() + || logicalType.getEncoding() != physicalType.getEncoding()) + return MaterializedLayoutKind::Unknown; + + if (logicalType.getDimSize(0) != 1 || physicalType.getDimSize(2) != 1) + return MaterializedLayoutKind::Unknown; + if (logicalType.getDimSize(1) != physicalType.getDimSize(1) + || logicalType.getDimSize(2) != physicalType.getDimSize(0) + || logicalType.getDimSize(3) != physicalType.getDimSize(3)) + return MaterializedLayoutKind::Unknown; + + return MaterializedLayoutKind::RowPackedNCHWFromRows; +} + +FailureOr materializeLogicalTensorView(MaterializerState& state, + MaterializedClass& targetClass, + MaterializedTensorView view, + Operation* anchor, + StringRef context, + std::optional producer, + IRMapping* mapper) { + FailureOr localizedValue = materializeTensorValueForMaterializedClassUse( + state, targetClass, view.value, anchor, context, producer, mapper); + if (failed(localizedValue)) + return failure(); + + if (view.layout == MaterializedLayoutKind::DenseLogical) + return *localizedValue; + + if (view.layout == MaterializedLayoutKind::RowPackedNCHWFromRows) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value empty = tensor::EmptyOp::create(state.rewriter, + anchor->getLoc(), + view.logicalType.getShape(), + view.logicalType.getElementType()) + .getResult(); + auto permutation = DenseI64ArrayAttr::get(anchor->getContext(), {2, 1, 0, 3}); + auto transpose = + linalg::TransposeOp::create(state.rewriter, anchor->getLoc(), *localizedValue, empty, permutation); + return transpose.getResult().front(); + } + + InFlightDiagnostic diagnostic = anchor->emitError("materialized tensor replacement changed physical layout") + << " logicalType=" << view.logicalType << " physicalType=" << view.physicalType; + if (producer) { + diagnostic << " producerOp='" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; + } + return failure(); +} + std::optional mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass, Operation* anchor, BlockArgument externalArg) { @@ -1870,6 +1955,58 @@ Value getPackedSliceForDynamicRunIndex( return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } +FailureOr materializeIndexedBatchRunPackedValue(MaterializerState& state, + MaterializedClass& targetClass, + IndexedBatchRunValue& run, + ProducerKey key, + Location loc) { + if (!run.packed) + return failure(); + + size_t flattenedIndexBase = 0; + for (const PackedScalarRunSlot& slot : run.slots) { + std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); + auto keyIt = llvm::find(slot.keys, key); + if ((!contiguousKey && keyIt == slot.keys.end()) || (contiguousKey && !containsProducerKey(*contiguousKey, key))) { + flattenedIndexBase += slot.keys.size(); + continue; + } + + FailureOr packed = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + run.packed, + targetClass.op, + "indexed batch run packed resolution tried to reuse a tensor from another materialized class"); + if (failed(packed)) + return failure(); + + if (!contiguousKey) { + size_t globalFragmentIndex = flattenedIndexBase + static_cast(std::distance(slot.keys.begin(), keyIt)); + return getPackedSliceForRunIndex(state, targetClass.op, *packed, run.fragmentType, globalFragmentIndex, loc); + } + + FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return failure(); + + int64_t rowOffset = static_cast(flattenedIndexBase) * run.fragmentType.getDimSize(0); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset); + Value slotPacked = + createDim0ExtractSlice(state, loc, *packed, firstOffset, (*slotPackedType).getDimSize(0)); + + if (*contiguousKey == key) + return slotPacked; + + std::optional sliced = extractPackedProducerSlice(state, targetClass, *contiguousKey, slotPacked, key); + if (!sliced) + return failure(); + return *sliced; + } + + return failure(); +} + using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; @@ -2017,11 +2154,15 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) continue; + size_t flattenedIndexBase = 0; for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); auto keyIt = llvm::find(slot.keys, key); if ((!contiguousKey && keyIt == slot.keys.end()) || (contiguousKey && !containsProducerKey(*contiguousKey, key))) + { + flattenedIndexBase += slot.keys.size(); continue; + } MaterializedClass& materializedClass = state.classes[classId]; state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); @@ -2032,11 +2173,12 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta return std::nullopt; if (!contiguousKey) { + size_t globalFragmentIndex = flattenedIndexBase + static_cast(std::distance(slot.keys.begin(), keyIt)); Value sliced = getPackedSliceForRunIndex(state, materializedClass.op, *packed, run.fragmentType, - static_cast(std::distance(slot.keys.begin(), keyIt)), + globalFragmentIndex, (*packed).getLoc()); record(key, classId, sliced); return sliced; @@ -2046,8 +2188,10 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta if (failed(slotPackedType)) return std::nullopt; + int64_t rowOffset = static_cast(flattenedIndexBase) * run.fragmentType.getDimSize(0); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); Value slotPacked = - getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); + createDim0ExtractSlice(state, (*packed).getLoc(), *packed, firstOffset, (*slotPackedType).getDimSize(0)); if (*contiguousKey == key) { record(key, classId, slotPacked); @@ -3862,8 +4006,10 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, Value payload, Location loc) { if (sourceClass.id == targetClass.id) { - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, payload); + if (keys.size() == 1) + state.availableValues.record(keys.front(), targetClass.id, payload); + else if (failed(recordLocalPackedSlotValue(state, targetClass, keys, payload))) + return failure(); return success(); } @@ -4080,8 +4226,10 @@ LogicalResult emitOutputFanout(MaterializerState& state, if (!recordedProjectedHostFragments && failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) return failure(); - for (ProducerKey key : keys) - state.availableValues.record(key, sourceClass.id, payload); + if (keys.size() == 1) + state.availableValues.record(keys.front(), sourceClass.id, payload); + else if (failed(recordLocalPackedSlotValue(state, sourceClass, keys, payload))) + return failure(); return success(); } @@ -4108,7 +4256,7 @@ struct WholeBatchFragmentGroup { SmallVector sourceLanes; Value packed; RankedTensorType slotPackedType; - SmallVector slotIndices; + SmallVector packedRowOffsets; SmallVector, 16> directFragments; SmallVector redundantReceives; }; @@ -4232,13 +4380,10 @@ FailureOr extractPackedSlotForIndex(MaterializerState& state, MaterializedClass& targetClass, Value packed, RankedTensorType slotPackedType, - Value slotIndex, + Value packedRowOffset, Location loc) { - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0)); + return createDim0ExtractSliceInClass( + state, targetClass, loc, packed, packedRowOffset, slotPackedType.getDimSize(0)); } SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run) { @@ -4655,7 +4800,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, if (failed(slotPackedType)) return failure(); WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType); - group.slotIndices.push_back(slotIndex); + group.packedRowOffsets.push_back(static_cast(flattenedIndexBase) * plan.rowsPerLane); group.outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); flattenedIndexBase += slot.keys.size(); continue; @@ -4663,7 +4808,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(run->fragmentType); for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) { - group.slotIndices.push_back(flattenedIndexBase + keyIndex); + group.packedRowOffsets.push_back(static_cast(flattenedIndexBase + keyIndex) * plan.rowsPerLane); group.outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); } flattenedIndexBase += slot.keys.size(); @@ -4795,9 +4940,10 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, state, targetClass, destination, - static_cast(group.slotIndices.size()), + static_cast(group.packedRowOffsets.size()), [&](Value flatIndex) -> FailureOr { - Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc); + Value packedRowOffset = + createIndexedIndexValue(state, targetClass.op, group.packedRowOffsets, flatIndex, loc); FailureOr packed = materializeTensorValueForMaterializedClassUse( state, targetClass, @@ -4806,7 +4952,7 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, "whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); if (failed(packed)) return failure(); - return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc); + return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedRowOffset, loc); }, [&](Value flatIndex) -> FailureOr { return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); @@ -5520,28 +5666,25 @@ FailureOr resolveInputValue(MaterializerState& state, return rejectNonLocalResolvedValue(*value); if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { - size_t laneCount = targetClass.cpus.size(); for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) { if (!llvm::is_contained(slot.keys, *producer)) continue; Value received = Value(); if (indexedRun->packed) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - received = getPackedSliceForRunIndex(state, - targetClass.op, - indexedRun->packed, - indexedRun->fragmentType, - slotIndex, - consumerInstance.op->getLoc()); + FailureOr packed = + materializeIndexedBatchRunPackedValue(state, targetClass, *indexedRun, *producer, consumerInstance.op->getLoc()); + if (failed(packed)) + return failure(); + received = *packed; } else { + size_t laneCount = targetClass.cpus.size(); MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); received = appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); } - for (ProducerKey slotKey : slot.keys) - state.availableValues.record(slotKey, targetClass.id, received); + state.availableValues.record(*producer, targetClass.id, received); return rejectNonLocalResolvedValue(received); } } @@ -5661,13 +5804,31 @@ LogicalResult mapInputs(MaterializerState& state, const ComputeInstance& instance, IRMapping& mapper, CloneIndexingContext indexing) { - auto mapResolvedInput = [&](Value resolved) -> FailureOr { - return materializeTensorValueForMaterializedClassUse( - state, - targetClass, - resolved, - targetClass.op, - "input mapping tried to reuse a tensor from another materialized class"); + auto mapResolvedInput = [&](Value resolved, + Value logicalInput, + Operation* anchor, + std::optional producer) -> FailureOr { + auto logicalType = dyn_cast(logicalInput.getType()); + auto physicalType = dyn_cast(resolved.getType()); + if (!logicalType || !physicalType || !logicalType.hasStaticShape() || !physicalType.hasStaticShape()) + return materializeTensorValueForMaterializedClassUse( + state, + targetClass, + resolved, + anchor, + "input mapping tried to reuse a tensor from another materialized class", + producer); + + return materializeLogicalTensorView(state, + targetClass, + MaterializedTensorView { + .value = resolved, + .logicalType = logicalType, + .physicalType = physicalType, + .layout = classifyMaterializedTensorLayout(logicalType, physicalType)}, + anchor, + "input mapping tried to reuse a tensor from another materialized class", + producer); }; Operation* op = instance.op; @@ -5686,14 +5847,16 @@ LogicalResult mapInputs(MaterializerState& state, auto inputArg = compute.getInputArgument(index); if (!inputArg) return compute.emitOpError("expected compute input block argument while materializing inputs"); - FailureOr remapped = mapResolvedInput(*mapped); + std::optional producer = getInputRequestProducerKey(input, instance); + FailureOr remapped = mapResolvedInput(*mapped, input, compute, producer); if (failed(remapped)) { - emitNonLocalMaterializedClassValueDiagnostic( - compute, - targetClass, - "mapInputs tried to append a tensor from another materialized class", - *mapped, - getInputRequestProducerKey(input, instance)); + if (isTensorValueDefinedInDifferentMaterializedClass(*mapped, targetClass)) + emitNonLocalMaterializedClassValueDiagnostic( + compute, + targetClass, + "mapInputs tried to append a tensor from another materialized class", + *mapped, + producer); return failure(); } mapper.map(*inputArg, *remapped); @@ -5733,13 +5896,37 @@ LogicalResult mapInputs(MaterializerState& state, auto inputArg = batch.getInputArgument(index); if (!inputArg) return batch.emitOpError("expected compute_batch input block argument while materializing inputs"); - FailureOr remapped = mapResolvedInput(*mapped); + std::optional producer = getInputRequestProducerKey(input, instance); + if (demand->kind == BatchInputDemandKind::ProjectedFragment) { + std::optional match = + getProjectedInputSliceMatch(state, batch, static_cast(index)); + auto mappedType = dyn_cast((*mapped).getType()); + if (!match || !mappedType || mappedType != match->fragmentType) { + InFlightDiagnostic diagnostic = batch.emitOpError( + "projected compute_batch input resolved to an incompatible fragment") + << " #" << index << " resolvedType=" << (*mapped).getType(); + if (match) + diagnostic << " expectedFragmentType=" << match->fragmentType; + if (producer) { + diagnostic << " from '" << producer->instance.op->getName() + << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount + << " resultIndex=" << producer->resultIndex; + } + return failure(); + } + + mapper.map(*inputArg, *mapped); + continue; + } + FailureOr remapped = mapResolvedInput(*mapped, input, batch, producer); if (failed(remapped)) { - emitNonLocalMaterializedClassValueDiagnostic(batch, - targetClass, - "mapInputs tried to append a tensor from another materialized class", - *mapped, - getInputRequestProducerKey(input, instance)); + if (isTensorValueDefinedInDifferentMaterializedClass(*mapped, targetClass)) + emitNonLocalMaterializedClassValueDiagnostic(batch, + targetClass, + "mapInputs tried to append a tensor from another materialized class", + *mapped, + producer); return failure(); } mapper.map(*inputArg, *remapped); @@ -6605,6 +6792,59 @@ LogicalResult registerPackedRunValue(MaterializerState& state, return success(); } +FailureOr getPackedSlotFragmentType(Value packed, size_t keyCount) { + auto packedType = dyn_cast(packed.getType()); + if (!packedType || !packedType.hasStaticShape() || packedType.getRank() == 0 || keyCount == 0) + return failure(); + + int64_t packedRows = packedType.getDimSize(0); + if (packedRows % static_cast(keyCount) != 0) + return failure(); + + SmallVector fragmentShape(packedType.getShape()); + fragmentShape[0] = packedRows / static_cast(keyCount); + return RankedTensorType::get(fragmentShape, packedType.getElementType(), packedType.getEncoding()); +} + +LogicalResult recordLocalPackedSlotValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Value packed) { + if (keys.empty()) + return success(); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return materializedClass.op->emitError("local packed slot registration expects one producer result"); + if (key.instance.laneCount != 1) + return materializedClass.op->emitError("local packed slot registration expects one lane per packed fragment"); + } + + FailureOr fragmentType = getPackedSlotFragmentType(packed, keys.size()); + if (failed(fragmentType)) { + for (ProducerKey key : keys) + state.availableValues.record(key, materializedClass.id, packed); + return success(); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = materializedClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.packed = packed; + packedRun.kind = PackedScalarRunKind::Materialized; + packedRun.fragmentType = *fragmentType; + + PackedScalarRunSlot slot; + llvm::append_range(slot.keys, keys); + packedRun.slots.push_back(std::move(slot)); + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + LogicalResult emitPackedRunFanout(MaterializerState& state, MaterializedClass& sourceClass, ArrayRef destinationClasses,