From aec80529ca06b1bde52dd2c1c5b4c6ac0816d2e9 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 5 Jun 2026 18:22:59 +0200 Subject: [PATCH] much faster MaterializeMergeSchedule.cpp --- .../MaterializeMergeSchedule.cpp | 473 ++++++++++++------ .../MergeComputeNodesPass.cpp | 349 +------------ 2 files changed, 338 insertions(+), 484 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 2856699..5100c2b 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -117,6 +118,66 @@ struct ProducerKeyInfo { static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } }; +struct SameClassConsumerLookupKey { + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const SameClassConsumerLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct SameClassConsumerLookupKeyInfo { + static SameClassConsumerLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static SameClassConsumerLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const SameClassConsumerLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); + } + + static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) { + return lhs == rhs; + } +}; + +struct WholeBatchAssemblyLookupKey { + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const WholeBatchAssemblyLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct WholeBatchAssemblyLookupKeyInfo { + static WholeBatchAssemblyLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static WholeBatchAssemblyLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); + } + + static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) { + return lhs == rhs; + } +}; + using ClassSlotKey = std::pair; struct MaterializedClass { @@ -270,9 +331,36 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state class AvailableValueStore { public: - void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; } + struct ExactBatchFragmentRecord { + ProducerKey key; + Value value; + }; - void recordPackedRun(PackedScalarRunValue run) { packedScalarRuns.push_back(std::move(run)); } + void record(ProducerKey key, ClassId classId, Value value) { + exactValues[key][classId] = value; + + auto batch = dyn_cast_or_null(key.instance.op); + if (!batch || key.instance.laneCount == 0) + return; + + WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId}; + SmallVector& bucket = exactBatchFragmentsByProducerResultClass[lookupKey]; + for (ExactBatchFragmentRecord& record : bucket) { + if (!(record.key == key)) + continue; + record.value = value; + return; + } + bucket.push_back({key, value}); + } + + void recordPackedRun(PackedScalarRunValue run) { + size_t runIndex = packedScalarRuns.size(); + packedScalarRuns.push_back(std::move(run)); + const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex]; + WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass}; + packedRunsByProducerResultClass[lookupKey].push_back(runIndex); + } void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } std::optional lookupExact(ProducerKey key, ClassId classId) const; @@ -280,7 +368,21 @@ public: std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); - SmallVectorImpl& getPackedScalarRuns() { return packedScalarRuns; } + ArrayRef getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = packedRunsByProducerResultClass.find(key); + if (it == packedRunsByProducerResultClass.end()) + return {}; + return it->second; + } + + ArrayRef getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = exactBatchFragmentsByProducerResultClass.find(key); + if (it == exactBatchFragmentsByProducerResultClass.end()) + return {}; + return it->second; + } + + PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; } private: std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); @@ -288,6 +390,10 @@ private: DenseMap, ProducerKeyInfo> exactValues; SmallVector packedScalarRuns; SmallVector indexedBatchRuns; + DenseMap, WholeBatchAssemblyLookupKeyInfo> + exactBatchFragmentsByProducerResultClass; + DenseMap, WholeBatchAssemblyLookupKeyInfo> + packedRunsByProducerResultClass; }; struct MaterializerState { @@ -296,7 +402,6 @@ struct MaterializerState { IRRewriter rewriter; OperationFolder constantFolder; int64_t& nextChannelId; - SmallVector classes; DenseMap cpuToClass; DenseMap> logicalInstancesByCpu; @@ -305,7 +410,8 @@ struct MaterializerState { DenseSet materializedLogicalSlots; DenseMap, ProducerKeyInfo> producerDestClasses; - DenseMap, ProducerKeyInfo> sameClassConsumers; + DenseMap, SameClassConsumerLookupKeyInfo> + sameClassConsumerIndex; DenseMap projectedInputMatches; DenseSet nonProjectedInputs; DenseMap liveExternalUseCache; @@ -317,7 +423,9 @@ struct MaterializerState { DenseMap hostReplacements; DenseSet oldComputeOps; - MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) + MaterializerState(func::FuncOp func, + const MergeScheduleResult& schedule, + int64_t& nextChannelId) : func(func), schedule(schedule), rewriter(func.getContext()), @@ -428,6 +536,14 @@ std::optional getContiguousProducerRangeForKeys(ArrayRef getPackedBatchTensorType(Type laneType, size_t laneCount) { auto tensorType = dyn_cast(laneType); if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) @@ -1172,14 +1288,12 @@ FailureOr materializePackedScalarRunValue(MaterializerState& state, std::optional AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { for (PackedScalarRunValue& run : packedScalarRuns) { - if (run.targetClass != classId) - continue; - if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) continue; for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { - std::optional slotKey = getContiguousProducerRangeForKeys(slot.keys); - if (!slotKey || !containsProducerKey(*slotKey, key)) + std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); + if (!contiguousKey || !containsProducerKey(*contiguousKey, key)) continue; FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); @@ -1197,12 +1311,13 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta Value slotPacked = getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); - if (*slotKey == key) { + if (*contiguousKey == key) { record(key, classId, slotPacked); return slotPacked; } - std::optional sliced = extractPackedProducerSlice(state, materializedClass, *slotKey, slotPacked, key); + std::optional sliced = + extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key); if (!sliced) return std::nullopt; @@ -1216,57 +1331,45 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) { for (IndexedBatchRunValue& run : indexedBatchRuns) { - if (run.targetClass != classId) + if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) continue; - if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) - continue; - - for (const PackedScalarRunSlot& slot : run.slots) - if (llvm::is_contained(slot.keys, key)) - return &run; + for (const PackedScalarRunSlot& slot : run.slots) { + if (!llvm::is_contained(slot.keys, key)) + continue; + return &run; + } } - return nullptr; } std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { - if (std::optional exact = lookupExact(key, classId)) + + if (std::optional exact = lookupExact(key, classId)) { return exact; + } if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) return packedRunValue; MaterializedClass& materializedClass = state.classes[classId]; - ProducerKey containingKey; - Value containingValue; - bool foundContainingValue = false; - - for (auto& entry : exactValues) { - ProducerKey candidateKey = entry.first; - if (!containsProducerKey(candidateKey, key)) + for (const auto& [candidateKey, classValues] : exactValues) { + if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key)) continue; - auto valueIt = entry.second.find(classId); - if (valueIt == entry.second.end()) + auto valueIt = classValues.find(classId); + if (valueIt == classValues.end()) continue; - containingKey = candidateKey; - containingValue = valueIt->second; - foundContainingValue = true; - break; + std::optional slice = + extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key); + if (!slice) + return std::nullopt; + + record(key, classId, *slice); + return *slice; } - - if (!foundContainingValue) - return std::nullopt; - - std::optional slice = - extractPackedProducerSlice(state, materializedClass, containingKey, containingValue, key); - if (!slice) - return std::nullopt; - - record(key, classId, *slice); - return *slice; + return std::nullopt; } Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { @@ -1389,13 +1492,13 @@ Value createIndexedIndexValue(MaterializerState& state, bool allowExhaustiveTiledSearch) { assert(!values.empty() && "expected at least one indexed value"); - if (allEqual(values)) + if (allEqual(values)) { return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); + } if (std::optional pattern = getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch)) return createAffineIndexValue(state, *pattern, index, loc); - Value table = createIndexTensorConstant(state, anchor, values); return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); } @@ -1578,7 +1681,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); if (sourceClass == targetClass) { - state.sameClassConsumers[producerKey].insert(targetClass); + SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass}; + SmallVector& bucket = state.sameClassConsumerIndex[lookupKey]; + if (!llvm::is_contained(bucket, producerKey)) + bucket.push_back(producerKey); continue; } @@ -2899,11 +3005,6 @@ LogicalResult emitOutputFanout(MaterializerState& state, return success(); } -struct WholeBatchAssemblyRange { - uint32_t laneStart = 0; - uint32_t laneCount = 0; -}; - struct DirectWholeBatchFragment { ProducerKey key; Value fragment; @@ -2933,31 +3034,60 @@ struct WholeBatchFragmentGroup { struct WholeBatchAssemblyPlan { RankedTensorType resultType; int64_t rowsPerLane = 0; + uint32_t batchLaneCount = 0; + uint32_t coveredLaneCount = 0; - SmallVector coveredRanges; + SmallVector coveredLanes; SmallVector packedRuns; SmallVector directFragments; }; -bool wholeBatchRangeOverlaps(ArrayRef ranges, uint32_t laneStart, uint32_t laneCount) { - uint32_t laneEnd = laneStart + laneCount; - for (WholeBatchAssemblyRange range : ranges) { - uint32_t rangeEnd = range.laneStart + range.laneCount; - if (laneStart < rangeEnd && range.laneStart < laneEnd) - return true; - } - return false; +bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) { + return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0; } -bool wholeBatchLaneCovered(ArrayRef ranges, uint32_t lane) { - for (WholeBatchAssemblyRange range : ranges) - if (range.laneStart <= lane && lane < range.laneStart + range.laneCount) +bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { + if (laneCount == 0) + return false; + if (laneStart >= plan.coveredLanes.size()) + return false; + + uint32_t laneEnd = std::min(laneStart + laneCount, plan.coveredLanes.size()); + for (uint32_t lane = laneStart; lane < laneEnd; ++lane) + if (plan.coveredLanes[lane] != 0) return true; return false; } void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { - plan.coveredRanges.push_back({laneStart, laneCount}); + assert(laneCount != 0 && "cannot cover an empty whole-batch range"); + assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds"); + + for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) { + if (plan.coveredLanes[lane] != 0) + continue; + plan.coveredLanes[lane] = 1; + ++plan.coveredLaneCount; + } +} + +bool localLaneRangeOverlaps(ArrayRef covered, uint32_t laneStart, uint32_t laneCount) { + if (laneCount == 0) + return false; + if (laneStart >= covered.size()) + return false; + + uint32_t laneEnd = std::min(laneStart + laneCount, covered.size()); + for (uint32_t lane = laneStart; lane < laneEnd; ++lane) + if (covered[lane] != 0) + return true; + return false; +} + +void markLocalLaneRangeCovered(MutableArrayRef covered, uint32_t laneStart, uint32_t laneCount) { + assert(laneStart + laneCount <= covered.size() && "local coverage out of bounds"); + for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) + covered[lane] = 1; } LogicalResult @@ -3118,13 +3248,14 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, WholeBatchAssemblyPlan& plan) { - for (PackedScalarRunValue& run : state.availableValues.getPackedScalarRuns()) { - if (run.targetClass != targetClass.id) - continue; - if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) - continue; + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); - SmallVector runRanges; + for (size_t runIndex : runIndices) { + PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); + + SmallVector runKeys; + SmallVector runCoveredLanes(plan.batchLaneCount, 0); for (const PackedScalarRunSlot& slot : run.slots) { for (ProducerKey fragmentKey : slot.keys) { @@ -3134,23 +3265,24 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, if (fragmentKey.instance.laneCount == 0) return failure(); - if (wholeBatchRangeOverlaps(plan.coveredRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) return failure(); - if (wholeBatchRangeOverlaps(runRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) return failure(); - runRanges.push_back({fragmentKey.instance.laneStart, fragmentKey.instance.laneCount}); + markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount); + runKeys.push_back(fragmentKey); } } - if (runRanges.empty()) + if (runKeys.empty()) continue; plan.packedRuns.push_back(&run); - for (WholeBatchAssemblyRange range : runRanges) - recordWholeBatchCoverage(plan, range.laneStart, range.laneCount); + for (ProducerKey runKey : runKeys) + recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount); } return success(); @@ -3161,44 +3293,77 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, SpatComputeBatch batch, ProducerKey key, WholeBatchAssemblyPlan& plan) { + struct CandidateFragment { + ProducerKey key; + Value value; + }; + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); - uint32_t lane = 0; + if (plan.coveredLaneCount == plan.batchLaneCount) { + return success(); + } - while (lane < batchLaneCount) { - if (wholeBatchLaneCovered(plan.coveredRanges, lane)) { - ++lane; + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); + ArrayRef indexedFragments = + state.availableValues.getExactFragmentsForWholeBatch(lookupKey); + + SmallVector candidates; + candidates.reserve(indexedFragments.size()); + for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) { + ProducerKey candidateKey = record.key; + if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex + || candidateKey.instance.laneCount == 0) continue; + if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount)) + continue; + + auto fragmentType = dyn_cast(record.value.getType()); + if (!fragmentType) + continue; + + int64_t expectedRows = plan.rowsPerLane * static_cast(candidateKey.instance.laneCount); + if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows))) + continue; + + candidates.push_back({candidateKey, record.value}); + } + + llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) { + if (lhs.key.instance.laneStart != rhs.key.instance.laneStart) + return lhs.key.instance.laneStart < rhs.key.instance.laneStart; + return lhs.key.instance.laneCount > rhs.key.instance.laneCount; + }); + + size_t candidateCursor = 0; + uint32_t lane = 0; + while (lane < batchLaneCount) { + while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) { + ++lane; } - bool foundFragment = false; - - for (uint32_t laneCount = batchLaneCount - lane; laneCount != 0; --laneCount) { - if (wholeBatchRangeOverlaps(plan.coveredRanges, lane, laneCount)) - continue; - - ProducerKey candidate = getBatchLaneProducerKey(batch, lane, laneCount, key.resultIndex); - std::optional fragment = state.availableValues.lookupExact(candidate, targetClass.id); - if (!fragment) - continue; - - auto fragmentType = dyn_cast(fragment->getType()); - if (!fragmentType) - return failure(); - - int64_t expectedRows = plan.rowsPerLane * static_cast(laneCount); - if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows))) - return failure(); - - plan.directFragments.push_back({candidate, *fragment}); - recordWholeBatchCoverage(plan, lane, laneCount); - - lane += laneCount; - foundFragment = true; + if (lane >= batchLaneCount) break; + + while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane) + ++candidateCursor; + + size_t candidateIndex = candidateCursor; + const CandidateFragment* best = nullptr; + while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) { + const CandidateFragment& candidate = candidates[candidateIndex]; + if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) { + best = &candidate; + break; + } + ++candidateIndex; } - if (!foundFragment) + if (!best) return failure(); + + plan.directFragments.push_back({best->key, best->value}); + recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount); + lane += best->key.instance.laneCount; } return success(); @@ -3291,11 +3456,11 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, } for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { - std::optional slotKey = getContiguousProducerRangeForKeys(slot.keys); - if (!slotKey) + std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); + if (!contiguousKey) return failure(); groupIt->slotIndices.push_back(slotIndex); - groupIt->outputOffsets.push_back(static_cast(slotKey->instance.laneStart) * plan.rowsPerLane); + groupIt->outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); } } @@ -3409,10 +3574,15 @@ FailureOr buildWholeBatchAssemblyPlan(MaterializerState& WholeBatchAssemblyPlan plan; plan.resultType = resultTensorType; plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast(batchLaneCount); + plan.batchLaneCount = batchLaneCount; + plan.coveredLanes.assign(batchLaneCount, 0); if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan))) return failure(); + if (plan.coveredLaneCount == plan.batchLaneCount) + return plan; + if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan))) return failure(); @@ -4181,7 +4351,6 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta auto sourceBatch = cast(sourceOp); SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); SmallVector initValues; - for (size_t resultIndex : group.resultIndices) { if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); @@ -4197,7 +4366,6 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta SmallVector logicalLanes; logicalLanes.reserve(run.size()); - for (const MaterializationRunSlot& slot : run) { if (slot.peers.size() != 1) return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); @@ -4215,34 +4383,34 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange(initValues), - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange(initValues), + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - run.front().peers.front(), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); - if (failed(produced)) - return failure(); + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + run.front().peers.front(), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); + if (failed(produced)) + return failure(); - yielded.reserve(produced->size()); - for (auto [outputIndex, output] : llvm::enumerate(*produced)) { - auto fragmentType = cast(output.getType()); - Value acc = iterArgs[outputIndex]; - Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); - yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); - } - return success(); - }); + yielded.reserve(produced->size()); + for (auto [outputIndex, output] : llvm::enumerate(*produced)) { + auto fragmentType = cast(output.getType()); + Value acc = iterArgs[outputIndex]; + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); + yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); + } + return success(); + }); if (failed(loop)) return failure(); @@ -4466,14 +4634,14 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, } bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { - for (const auto& [key, consumers] : state.sameClassConsumers) { - if (!consumers.contains(classId)) - continue; - if (!sameProducerResult(key, producerKey)) - continue; - if (containsProducerKey(key, producerKey) || containsProducerKey(producerKey, key)) + SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId}; + auto it = state.sameClassConsumerIndex.find(lookupKey); + if (it == state.sameClassConsumerIndex.end()) + return false; + + for (ProducerKey existing : it->second) + if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing)) return true; - } return false; } @@ -4488,6 +4656,7 @@ bool canCompactBatchClassRun(MaterializerState& state, ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { + (void) ignored; for (const MaterializationRunSlot& slot : run) { if (slot.peers.empty()) return false; @@ -4533,7 +4702,8 @@ Value createBatchClassRunSourceLane(MaterializerState& state, SmallVector sourceLanes; sourceLanes.reserve(run.size() * targetClass.cpus.size()); - for (const MaterializationRunSlot& slot : run) { + for (auto [runSlotIndex, slot] : llvm::enumerate(run)) { + (void) runSlotIndex; assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane"); for (const ComputeInstance& peer : slot.peers) sourceLanes.push_back(peer.laneStart); @@ -4577,7 +4747,6 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state, plan.messages.targetCoreIds.reserve(messageCount); for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) { - (void) slotIndex; for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); if (failed(checkedSourceCpu)) @@ -4590,6 +4759,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state, return failure(); plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); } + (void) slotIndex; } plans.push_back(std::move(plan)); @@ -4773,7 +4943,8 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, return success(); } -LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) { +LogicalResult materializeInstanceSlot(MaterializerState& state, + const ComputeInstance& instance) { auto cpuIt = state.schedule.computeToCpuMap.find(instance); if (cpuIt == state.schedule.computeToCpuMap.end()) return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); @@ -4794,8 +4965,7 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns return success(); if (isa(instance.op)) { - FailureOr run = - collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); + FailureOr run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); if (succeeded(run)) { if (!targetClass.isBatch) @@ -4924,6 +5094,7 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch return failure(); LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody()); + (void) _; return success(); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 53c85e0..d5015cc 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1,6 +1,5 @@ #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" @@ -14,20 +13,14 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_os_ostream.h" -#include "llvm/Support/raw_ostream.h" #include -#include #include #include -#include #include #include #include -#include #include #include @@ -51,83 +44,6 @@ using SpatCompute = spatial::SpatCompute; using SpatComputeBatch = spatial::SpatComputeBatch; using spatial::getProducerValueRef; -bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; } - -class ScopedMergePhaseTimer { -public: - explicit ScopedMergePhaseTimer(StringRef phaseName) - : enabled(isMergeProfilingEnabled()), phase(phaseName.str()) { - if (enabled) - start = std::chrono::steady_clock::now(); - } - - ~ScopedMergePhaseTimer() { - if (!enabled) - return; - auto elapsed = std::chrono::steady_clock::now() - start; - double millis = std::chrono::duration(elapsed).count(); - llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n"; - } - -private: - bool enabled = false; - std::string phase; - std::chrono::steady_clock::time_point start; -}; - -struct MergeIrCounts { - uint64_t topLevelComputeCount = 0; - uint64_t topLevelComputeBatchCount = 0; - uint64_t scalarChannelSendCount = 0; - uint64_t scalarChannelReceiveCount = 0; - uint64_t wvmmCount = 0; - uint64_t vaddCount = 0; - uint64_t scfForCount = 0; -}; - -MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) { - MergeIrCounts counts; - - auto countComputeBodyOps = [&](Operation* op) { - op->walk([&](Operation* nestedOp) { - if (isa(nestedOp)) - ++counts.scalarChannelSendCount; - else if (isa(nestedOp)) - ++counts.scalarChannelReceiveCount; - else if (isa(nestedOp)) - ++counts.wvmmCount; - else if (isa(nestedOp)) - ++counts.vaddCount; - else if (isa(nestedOp)) - ++counts.scfForCount; - }); - }; - - for (auto compute : funcOp.getOps()) { - ++counts.topLevelComputeCount; - countComputeBodyOps(compute.getOperation()); - } - - for (auto batch : funcOp.getOps()) { - ++counts.topLevelComputeBatchCount; - countComputeBodyOps(batch.getOperation()); - } - - return counts; -} - -void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) { - if (!isMergeProfilingEnabled()) - return; - - MergeIrCounts counts = collectMergeIrCounts(funcOp); - llvm::errs() << "[merge-profile] " << phaseName << " counts:" - << " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount - << " scalar_send=" << counts.scalarChannelSendCount - << " scalar_recv=" << counts.scalarChannelReceiveCount << " wvmm=" << counts.wvmmCount - << " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n"; -} - static std::optional getComputeCoreId(SpatCompute compute) { if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id"); @@ -138,16 +54,6 @@ static std::optional getComputeCoreId(SpatCompute compute) { return std::nullopt; } -struct ComputeMotifInfo { - uint64_t instructionCount = 0; - uint64_t weightedVmmCount = 0; -}; - -void appendUnique(SmallVector& values, size_t value) { - if (!llvm::is_contained(values, value)) - values.push_back(value); -} - bool isTrivialSerialMergeCandidate(SpatCompute compute) { if (!compute->hasOneUse()) return false; @@ -266,212 +172,6 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) { } } -void emitMotifProfile(func::FuncOp funcOp) { - if (!std::getenv("DCP_MOTIF_PROFILE")) - return; - - SmallVector computes(funcOp.getOps()); - DenseMap computeToIndex; - computeToIndex.reserve(computes.size()); - for (auto [index, compute] : llvm::enumerate(computes)) - computeToIndex[compute] = index; - - SmallVector computeInfos(computes.size()); - SmallVector> parents(computes.size()); - SmallVector> children(computes.size()); - - uint64_t weightedVmmNodeCount = 0; - uint64_t weightedVmmOpCount = 0; - uint64_t edgeCount = 0; - - for (auto [index, compute] : llvm::enumerate(computes)) { - ComputeMotifInfo& info = computeInfos[index]; - info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody()); - compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; }); - if (info.weightedVmmCount > 0) { - weightedVmmNodeCount++; - weightedVmmOpCount += info.weightedVmmCount; - } - - for (Value input : compute.getInputs()) { - auto parent = dyn_cast(input.getDefiningOp()); - if (!parent || parent == compute) - continue; - auto parentIt = computeToIndex.find(parent); - if (parentIt == computeToIndex.end()) - continue; - - size_t parentIndex = parentIt->second; - size_t oldParentCount = parents[index].size(); - appendUnique(parents[index], parentIndex); - if (parents[index].size() != oldParentCount) { - appendUnique(children[parentIndex], index); - edgeCount++; - } - } - } - - uint64_t maxFanIn = 0; - uint64_t maxFanOut = 0; - uint64_t fanIn16 = 0; - uint64_t fanIn64 = 0; - uint64_t fanIn256 = 0; - uint64_t fanOut16 = 0; - uint64_t fanOut64 = 0; - uint64_t fanOut256 = 0; - for (size_t index = 0; index < computes.size(); ++index) { - uint64_t fanIn = parents[index].size(); - uint64_t fanOut = children[index].size(); - maxFanIn = std::max(maxFanIn, fanIn); - maxFanOut = std::max(maxFanOut, fanOut); - fanIn16 += fanIn >= 16; - fanIn64 += fanIn >= 64; - fanIn256 += fanIn >= 256; - fanOut16 += fanOut >= 16; - fanOut64 += fanOut >= 64; - fanOut256 += fanOut >= 256; - } - - uint64_t serialChainCount = 0; - uint64_t serialChainNodeCount = 0; - uint64_t maxSerialChain = 0; - for (size_t index = 0; index < computes.size(); ++index) { - if (parents[index].size() == 1 && children[parents[index][0]].size() == 1) - continue; - - uint64_t chainLength = 1; - size_t current = index; - while (children[current].size() == 1) { - size_t child = children[current][0]; - if (parents[child].size() != 1) - break; - chainLength++; - current = child; - } - - if (chainLength >= 2) { - serialChainCount++; - serialChainNodeCount += chainLength; - maxSerialChain = std::max(maxSerialChain, chainLength); - } - } - - SmallVector incomingEdgeCount; - incomingEdgeCount.reserve(parents.size()); - for (ArrayRef parentList : parents) - incomingEdgeCount.push_back(parentList.size()); - - SmallVector level(computes.size(), 0); - SmallVector readyNodes; - readyNodes.reserve(computes.size()); - for (size_t index = 0; index < computes.size(); ++index) - if (incomingEdgeCount[index] == 0) - readyNodes.push_back(index); - - size_t readyIndex = 0; - while (readyIndex != readyNodes.size()) { - size_t current = readyNodes[readyIndex++]; - for (size_t child : children[current]) { - level[child] = std::max(level[child], level[current] + 1); - assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow"); - incomingEdgeCount[child]--; - if (incomingEdgeCount[child] == 0) - readyNodes.push_back(child); - } - } - - SmallVector weightedVmmNodesByLevel; - for (size_t index = 0; index < computes.size(); ++index) { - if (computeInfos[index].weightedVmmCount == 0) - continue; - if (weightedVmmNodesByLevel.size() <= level[index]) - weightedVmmNodesByLevel.resize(level[index] + 1, 0); - weightedVmmNodesByLevel[level[index]]++; - } - - uint64_t maxWeightedVmmLevel = 0; - uint64_t wideWeightedVmmLevels64 = 0; - uint64_t wideWeightedVmmLevels256 = 0; - for (uint64_t count : weightedVmmNodesByLevel) { - maxWeightedVmmLevel = std::max(maxWeightedVmmLevel, count); - wideWeightedVmmLevels64 += count >= 64; - wideWeightedVmmLevels256 += count >= 256; - } - - using ShapeKey = std::tuple; - SmallVector weightedVmmShapeKeys; - for (auto [index, compute] : llvm::enumerate(computes)) { - const ComputeMotifInfo& info = computeInfos[index]; - if (info.weightedVmmCount == 0) - continue; - weightedVmmShapeKeys.push_back({info.instructionCount, - info.weightedVmmCount, - static_cast(compute.getWeights().size()), - static_cast(compute.getInputs().size()), - static_cast(parents[index].size()), - static_cast(children[index].size())}); - } - - llvm::sort(weightedVmmShapeKeys); - SmallVector> weightedVmmShapeCounts; - for (size_t index = 0; index < weightedVmmShapeKeys.size();) { - size_t next = index + 1; - while (next < weightedVmmShapeKeys.size() && weightedVmmShapeKeys[next] == weightedVmmShapeKeys[index]) - next++; - weightedVmmShapeCounts.push_back({next - index, weightedVmmShapeKeys[index]}); - index = next; - } - llvm::sort(weightedVmmShapeCounts, [](const auto& lhs, const auto& rhs) { - if (lhs.first != rhs.first) - return lhs.first > rhs.first; - return lhs.second < rhs.second; - }); - - llvm::errs() << llvm::formatv("[DCP-MOTIF] computes={0} edges={1} wvmmNodes={2} wvmmOps={3} " - "serialChains={4} serialChainNodes={5} maxSerialChain={6} " - "maxFanIn={7} maxFanOut={8} fanIn>=16/64/256={9}/{10}/{11} " - "fanOut>=16/64/256={12}/{13}/{14} topoVisited={15}\n", - computes.size(), - edgeCount, - weightedVmmNodeCount, - weightedVmmOpCount, - serialChainCount, - serialChainNodeCount, - maxSerialChain, - maxFanIn, - maxFanOut, - fanIn16, - fanIn64, - fanIn256, - fanOut16, - fanOut64, - fanOut256, - readyNodes.size()); - - llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmLevels={0} maxWvmmLevel={1} wideWvmmLevels>=64/256={2}/{3} " - "shapeGroups={4}\n", - weightedVmmNodesByLevel.size(), - maxWeightedVmmLevel, - wideWeightedVmmLevels64, - wideWeightedVmmLevels256, - weightedVmmShapeCounts.size()); - - for (size_t rank = 0, end = std::min(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) { - auto [count, shape] = weightedVmmShapeCounts[rank]; - auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape; - llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} " - "weights={4} inputs={5} fanIn={6} fanOut={7}\n", - rank, - count, - insts, - vmmOps, - weights, - inputs, - fanIn, - fanOut); - } -} - void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) { std::fstream file = openReportFile(name); if (!file.is_open()) @@ -628,44 +328,27 @@ public: void runOnOperation() override { func::FuncOp func = getOperation(); - { - ScopedMergePhaseTimer timer("trivial-serial-merge"); - mergeTriviallyConnectedComputes(func); - } - if (std::getenv("DCP_MOTIF_PROFILE")) - emitMotifProfile(func); + mergeTriviallyConnectedComputes(func); const spatial::MergeScheduleResult* analysisResult = nullptr; - { - ScopedMergePhaseTimer timer("scheduling-analysis"); - analysisResult = &getAnalysis().getResult(); - } - { - ScopedMergePhaseTimer timer("schedule-materialization"); - if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) { - signalPassFailure(); - return; - } + analysisResult = &getAnalysis().getResult(); + if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) { + signalPassFailure(); + return; } - emitMergeIrCounts("after-materialization", func); - - { - ScopedMergePhaseTimer timer("cleanup-topological-sort-report"); - if (!sortTopologically(&func.getBody().front())) { - func.emitOpError("failed to topologically order merged Spatial IR"); - signalPassFailure(); - return; - } - if (failed(verifySpatialCommunicationInvariants(func))) { - func.emitOpError("merged Spatial communication invariant verification failed"); - signalPassFailure(); - return; - } - emitMergeIrCounts("final-post-merge", func); - dumpModule(cast(func->getParentOp()), "spatial1_merged"); - generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size()); + if (!sortTopologically(&func.getBody().front())) { + func.emitOpError("failed to topologically order merged Spatial IR"); + signalPassFailure(); + return; } + if (failed(verifySpatialCommunicationInvariants(func))) { + func.emitOpError("merged Spatial communication invariant verification failed"); + signalPassFailure(); + return; + } + dumpModule(cast(func->getParentOp()), "spatial1_merged"); + generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size()); } };