much faster MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-05 18:22:59 +02:00
parent 8ddbbcecfa
commit aec80529ca
2 changed files with 338 additions and 484 deletions
@@ -11,6 +11,7 @@
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@@ -117,6 +118,66 @@ struct ProducerKeyInfo {
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
}; };
/// Key of `MaterializerState::sameClassConsumerIndex`.
///
/// Buckets the producer keys recorded for same-class consumers by the
/// operation that produced the value, the index of the produced result and
/// the class involved. Member order matters: call sites brace-initialise the
/// key positionally as `{op, resultIndex, classId}`.
struct SameClassConsumerLookupKey {
  Operation* sourceOp = nullptr;
  size_t resultIndex = 0;
  ClassId classId = 0;

  /// Member-wise equality over all three fields.
  bool operator==(const SameClassConsumerLookupKey& other) const {
    if (sourceOp != other.sourceOp)
      return false;
    if (resultIndex != other.resultIndex)
      return false;
    return classId == other.classId;
  }
};
/// DenseMap key traits for `SameClassConsumerLookupKey`.
///
/// The empty and tombstone sentinels share the maximum `resultIndex` and
/// `classId` values and differ only in the `Operation*` sentinel borrowed
/// from `llvm::DenseMapInfo<Operation*>`.
struct SameClassConsumerLookupKeyInfo {
  /// Sentinel marking a never-used bucket.
  static SameClassConsumerLookupKey getEmptyKey() {
    SameClassConsumerLookupKey emptyKey;
    emptyKey.sourceOp = llvm::DenseMapInfo<Operation*>::getEmptyKey();
    emptyKey.resultIndex = std::numeric_limits<size_t>::max();
    emptyKey.classId = std::numeric_limits<ClassId>::max();
    return emptyKey;
  }

  /// Sentinel marking a bucket whose entry was erased.
  static SameClassConsumerLookupKey getTombstoneKey() {
    SameClassConsumerLookupKey tombstoneKey;
    tombstoneKey.sourceOp = llvm::DenseMapInfo<Operation*>::getTombstoneKey();
    tombstoneKey.resultIndex = std::numeric_limits<size_t>::max();
    tombstoneKey.classId = std::numeric_limits<ClassId>::max();
    return tombstoneKey;
  }

  /// Combines the pointer hash of `sourceOp` with `resultIndex` and `classId`.
  static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
    unsigned sourceOpHash = llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp);
    return llvm::hash_combine(sourceOpHash, key.resultIndex, key.classId);
  }

  /// Delegates to the key's own `operator==`.
  static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
    const bool same = (lhs == rhs);
    return same;
  }
};
/// Key of the per-(producer, result, class) indices kept by
/// `AvailableValueStore` (`exactBatchFragmentsByProducerResultClass` and
/// `packedRunsByProducerResultClass`).
///
/// Member order matters: call sites brace-initialise the key positionally as
/// `{sourceOp, resultIndex, classId}`.
struct WholeBatchAssemblyLookupKey {
  Operation* sourceOp = nullptr;
  size_t resultIndex = 0;
  ClassId classId = 0;

  /// Member-wise equality over all three fields.
  bool operator==(const WholeBatchAssemblyLookupKey& other) const {
    const bool sameSource = sourceOp == other.sourceOp;
    const bool sameResult = resultIndex == other.resultIndex;
    const bool sameClass = classId == other.classId;
    return sameSource && sameResult && sameClass;
  }
};
/// DenseMap key traits for `WholeBatchAssemblyLookupKey`.
///
/// The empty and tombstone sentinels share the maximum `resultIndex` and
/// `classId` values and differ only in the `Operation*` sentinel borrowed
/// from `llvm::DenseMapInfo<Operation*>`.
struct WholeBatchAssemblyLookupKeyInfo {
  /// Sentinel marking a never-used bucket.
  static WholeBatchAssemblyLookupKey getEmptyKey() {
    Operation* emptyOp = llvm::DenseMapInfo<Operation*>::getEmptyKey();
    return WholeBatchAssemblyLookupKey {
      emptyOp, std::numeric_limits<size_t>::max(), std::numeric_limits<ClassId>::max()};
  }

  /// Sentinel marking a bucket whose entry was erased.
  static WholeBatchAssemblyLookupKey getTombstoneKey() {
    Operation* tombstoneOp = llvm::DenseMapInfo<Operation*>::getTombstoneKey();
    return WholeBatchAssemblyLookupKey {
      tombstoneOp, std::numeric_limits<size_t>::max(), std::numeric_limits<ClassId>::max()};
  }

  /// Combines the pointer hash of `sourceOp` with `resultIndex` and `classId`.
  static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
    unsigned sourceOpHash = llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp);
    return llvm::hash_combine(sourceOpHash, key.resultIndex, key.classId);
  }

  /// Delegates to the key's own `operator==`.
  static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
    const bool same = (lhs == rhs);
    return same;
  }
};
using ClassSlotKey = std::pair<ClassId, SlotId>; using ClassSlotKey = std::pair<ClassId, SlotId>;
struct MaterializedClass { struct MaterializedClass {
@@ -270,9 +331,36 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
class AvailableValueStore { class AvailableValueStore {
public: public:
void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; } struct ExactBatchFragmentRecord {
ProducerKey key;
Value value;
};
void recordPackedRun(PackedScalarRunValue run) { packedScalarRuns.push_back(std::move(run)); } void record(ProducerKey key, ClassId classId, Value value) {
exactValues[key][classId] = value;
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
if (!batch || key.instance.laneCount == 0)
return;
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
for (ExactBatchFragmentRecord& record : bucket) {
if (!(record.key == key))
continue;
record.value = value;
return;
}
bucket.push_back({key, value});
}
void recordPackedRun(PackedScalarRunValue run) {
size_t runIndex = packedScalarRuns.size();
packedScalarRuns.push_back(std::move(run));
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
}
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
std::optional<Value> lookupExact(ProducerKey key, ClassId classId) const; std::optional<Value> lookupExact(ProducerKey key, ClassId classId) const;
@@ -280,7 +368,21 @@ public:
std::optional<Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId); std::optional<Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
SmallVectorImpl<PackedScalarRunValue>& getPackedScalarRuns() { return packedScalarRuns; } ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
auto it = packedRunsByProducerResultClass.find(key);
if (it == packedRunsByProducerResultClass.end())
return {};
return it->second;
}
ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
auto it = exactBatchFragmentsByProducerResultClass.find(key);
if (it == exactBatchFragmentsByProducerResultClass.end())
return {};
return it->second;
}
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
private: private:
std::optional<Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); std::optional<Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
@@ -288,6 +390,10 @@ private:
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> exactValues; DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> exactValues;
SmallVector<PackedScalarRunValue, 8> packedScalarRuns; SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns; SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<ExactBatchFragmentRecord, 16>, WholeBatchAssemblyLookupKeyInfo>
exactBatchFragmentsByProducerResultClass;
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
packedRunsByProducerResultClass;
}; };
struct MaterializerState { struct MaterializerState {
@@ -296,7 +402,6 @@ struct MaterializerState {
IRRewriter rewriter; IRRewriter rewriter;
OperationFolder constantFolder; OperationFolder constantFolder;
int64_t& nextChannelId; int64_t& nextChannelId;
SmallVector<MaterializedClass, 8> classes; SmallVector<MaterializedClass, 8> classes;
DenseMap<CpuId, ClassId> cpuToClass; DenseMap<CpuId, ClassId> cpuToClass;
DenseMap<CpuId, SmallVector<ComputeInstance, 32>> logicalInstancesByCpu; DenseMap<CpuId, SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
@@ -305,7 +410,8 @@ struct MaterializerState {
DenseSet<ClassSlotKey> materializedLogicalSlots; DenseSet<ClassSlotKey> materializedLogicalSlots;
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses; DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> sameClassConsumers; DenseMap<SameClassConsumerLookupKey, SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
sameClassConsumerIndex;
DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo> projectedInputMatches; DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo> projectedInputMatches;
DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs; DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
DenseMap<Value, bool> liveExternalUseCache; DenseMap<Value, bool> liveExternalUseCache;
@@ -317,7 +423,9 @@ struct MaterializerState {
DenseMap<Value, Value> hostReplacements; DenseMap<Value, Value> hostReplacements;
DenseSet<Operation*> oldComputeOps; DenseSet<Operation*> oldComputeOps;
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) MaterializerState(func::FuncOp func,
const MergeScheduleResult& schedule,
int64_t& nextChannelId)
: func(func), : func(func),
schedule(schedule), schedule(schedule),
rewriter(func.getContext()), rewriter(func.getContext()),
@@ -428,6 +536,14 @@ std::optional<ProducerKey> getContiguousProducerRangeForKeys(ArrayRef<ProducerKe
return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex);
} }
/// Builds the whole-batch assembly lookup key for result `resultIndex` of
/// `sourceOp` as seen by class `classId`.
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) {
  WholeBatchAssemblyLookupKey lookupKey {sourceOp, resultIndex, classId};
  return lookupKey;
}
/// Convenience overload: derives the source operation and result index from
/// `key` (its lane range is not part of the lookup key) and forwards to the
/// three-argument overload.
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) {
  Operation* sourceOp = key.instance.op;
  size_t resultIndex = key.resultIndex;
  return makeWholeBatchAssemblyLookupKey(sourceOp, resultIndex, classId);
}
FailureOr<RankedTensorType> getPackedBatchTensorType(Type laneType, size_t laneCount) { FailureOr<RankedTensorType> getPackedBatchTensorType(Type laneType, size_t laneCount) {
auto tensorType = dyn_cast<RankedTensorType>(laneType); auto tensorType = dyn_cast<RankedTensorType>(laneType);
if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
@@ -1172,14 +1288,12 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) {
for (PackedScalarRunValue& run : packedScalarRuns) { for (PackedScalarRunValue& run : packedScalarRuns) {
if (run.targetClass != classId) if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue;
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue; continue;
for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) {
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys); std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
if (!slotKey || !containsProducerKey(*slotKey, key)) if (!contiguousKey || !containsProducerKey(*contiguousKey, key))
continue; continue;
FailureOr<RankedTensorType> slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); FailureOr<RankedTensorType> slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size());
@@ -1197,12 +1311,13 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
Value slotPacked = Value slotPacked =
getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc());
if (*slotKey == key) { if (*contiguousKey == key) {
record(key, classId, slotPacked); record(key, classId, slotPacked);
return slotPacked; return slotPacked;
} }
std::optional<Value> sliced = extractPackedProducerSlice(state, materializedClass, *slotKey, slotPacked, key); std::optional<Value> sliced =
extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key);
if (!sliced) if (!sliced)
return std::nullopt; return std::nullopt;
@@ -1216,57 +1331,45 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) { IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) {
for (IndexedBatchRunValue& run : indexedBatchRuns) { for (IndexedBatchRunValue& run : indexedBatchRuns) {
if (run.targetClass != classId) if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue; continue;
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) for (const PackedScalarRunSlot& slot : run.slots) {
continue; if (!llvm::is_contained(slot.keys, key))
continue;
for (const PackedScalarRunSlot& slot : run.slots) return &run;
if (llvm::is_contained(slot.keys, key)) }
return &run;
} }
return nullptr; return nullptr;
} }
std::optional<Value> AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { std::optional<Value> AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) {
if (std::optional<Value> exact = lookupExact(key, classId))
if (std::optional<Value> exact = lookupExact(key, classId)) {
return exact; return exact;
}
if (std::optional<Value> packedRunValue = lookupPackedRun(state, key, classId)) if (std::optional<Value> packedRunValue = lookupPackedRun(state, key, classId))
return packedRunValue; return packedRunValue;
MaterializedClass& materializedClass = state.classes[classId]; MaterializedClass& materializedClass = state.classes[classId];
ProducerKey containingKey; for (const auto& [candidateKey, classValues] : exactValues) {
Value containingValue; if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key))
bool foundContainingValue = false;
for (auto& entry : exactValues) {
ProducerKey candidateKey = entry.first;
if (!containsProducerKey(candidateKey, key))
continue; continue;
auto valueIt = entry.second.find(classId); auto valueIt = classValues.find(classId);
if (valueIt == entry.second.end()) if (valueIt == classValues.end())
continue; continue;
containingKey = candidateKey; std::optional<Value> slice =
containingValue = valueIt->second; extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key);
foundContainingValue = true; if (!slice)
break; return std::nullopt;
record(key, classId, *slice);
return *slice;
} }
return std::nullopt;
if (!foundContainingValue)
return std::nullopt;
std::optional<Value> slice =
extractPackedProducerSlice(state, materializedClass, containingKey, containingValue, key);
if (!slice)
return std::nullopt;
record(key, classId, *slice);
return *slice;
} }
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) { Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
@@ -1389,13 +1492,13 @@ Value createIndexedIndexValue(MaterializerState& state,
bool allowExhaustiveTiledSearch) { bool allowExhaustiveTiledSearch) {
assert(!values.empty() && "expected at least one indexed value"); assert(!values.empty() && "expected at least one indexed value");
if (allEqual(values)) if (allEqual(values)) {
return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); return getOrCreateIndexConstant(state.constantFolder, anchor, values.front());
}
if (std::optional<IndexedIndexPattern> pattern = if (std::optional<IndexedIndexPattern> pattern =
getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch)) getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch))
return createAffineIndexValue(state, *pattern, index, loc); return createAffineIndexValue(state, *pattern, index, loc);
Value table = createIndexTensorConstant(state, anchor, values); Value table = createIndexTensorConstant(state, anchor, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
} }
@@ -1578,7 +1681,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
if (sourceClass == targetClass) { if (sourceClass == targetClass) {
state.sameClassConsumers[producerKey].insert(targetClass); SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass};
SmallVector<ProducerKey, 4>& bucket = state.sameClassConsumerIndex[lookupKey];
if (!llvm::is_contained(bucket, producerKey))
bucket.push_back(producerKey);
continue; continue;
} }
@@ -2899,11 +3005,6 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return success(); return success();
} }
struct WholeBatchAssemblyRange {
uint32_t laneStart = 0;
uint32_t laneCount = 0;
};
struct DirectWholeBatchFragment { struct DirectWholeBatchFragment {
ProducerKey key; ProducerKey key;
Value fragment; Value fragment;
@@ -2933,31 +3034,60 @@ struct WholeBatchFragmentGroup {
struct WholeBatchAssemblyPlan { struct WholeBatchAssemblyPlan {
RankedTensorType resultType; RankedTensorType resultType;
int64_t rowsPerLane = 0; int64_t rowsPerLane = 0;
uint32_t batchLaneCount = 0;
uint32_t coveredLaneCount = 0;
SmallVector<WholeBatchAssemblyRange, 16> coveredRanges; SmallVector<uint8_t, 64> coveredLanes;
SmallVector<PackedScalarRunValue*, 8> packedRuns; SmallVector<PackedScalarRunValue*, 8> packedRuns;
SmallVector<DirectWholeBatchFragment, 16> directFragments; SmallVector<DirectWholeBatchFragment, 16> directFragments;
}; };
bool wholeBatchRangeOverlaps(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t laneStart, uint32_t laneCount) { bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) {
uint32_t laneEnd = laneStart + laneCount; return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0;
for (WholeBatchAssemblyRange range : ranges) {
uint32_t rangeEnd = range.laneStart + range.laneCount;
if (laneStart < rangeEnd && range.laneStart < laneEnd)
return true;
}
return false;
} }
bool wholeBatchLaneCovered(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t lane) { bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
for (WholeBatchAssemblyRange range : ranges) if (laneCount == 0)
if (range.laneStart <= lane && lane < range.laneStart + range.laneCount) return false;
if (laneStart >= plan.coveredLanes.size())
return false;
uint32_t laneEnd = std::min<uint32_t>(laneStart + laneCount, plan.coveredLanes.size());
for (uint32_t lane = laneStart; lane < laneEnd; ++lane)
if (plan.coveredLanes[lane] != 0)
return true; return true;
return false; return false;
} }
void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
plan.coveredRanges.push_back({laneStart, laneCount}); assert(laneCount != 0 && "cannot cover an empty whole-batch range");
assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds");
for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) {
if (plan.coveredLanes[lane] != 0)
continue;
plan.coveredLanes[lane] = 1;
++plan.coveredLaneCount;
}
}
/// Returns true when any lane in `[laneStart, laneStart + laneCount)` is
/// already flagged (non-zero) in `covered`.
///
/// An empty range, or a range starting at or past `covered.size()`, never
/// overlaps. A range extending past the end of `covered` is clamped to it.
///
/// The end of the range is computed in 64 bits: with 32-bit arithmetic
/// `laneStart + laneCount` can wrap around, which would make the scanned
/// range empty and report "no overlap" for a range that does overlap.
bool localLaneRangeOverlaps(ArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
  const uint64_t coveredSize = covered.size();
  const uint64_t rangeBegin = laneStart;
  if (laneCount == 0 || rangeBegin >= coveredSize)
    return false;

  // Cannot overflow: both operands are at most 2^32 - 1.
  const uint64_t rangeEnd = std::min<uint64_t>(rangeBegin + laneCount, coveredSize);
  for (uint64_t lane = rangeBegin; lane < rangeEnd; ++lane)
    if (covered[lane] != 0)
      return true;
  return false;
}
/// Flags every lane in `[laneStart, laneStart + laneCount)` as covered by
/// writing 1 into `covered`. The range must lie inside `covered`; this is
/// only checked by an assertion.
void markLocalLaneRangeCovered(MutableArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
  const uint32_t laneEnd = laneStart + laneCount;
  assert(laneEnd <= covered.size() && "local coverage out of bounds");
  uint32_t lane = laneStart;
  while (lane < laneEnd) {
    covered[lane] = 1;
    ++lane;
  }
}
LogicalResult LogicalResult
@@ -3118,13 +3248,14 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
MaterializedClass& targetClass, MaterializedClass& targetClass,
ProducerKey key, ProducerKey key,
WholeBatchAssemblyPlan& plan) { WholeBatchAssemblyPlan& plan) {
for (PackedScalarRunValue& run : state.availableValues.getPackedScalarRuns()) { WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
if (run.targetClass != targetClass.id) ArrayRef<size_t> runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey);
continue;
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue;
SmallVector<WholeBatchAssemblyRange, 16> runRanges; for (size_t runIndex : runIndices) {
PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex);
SmallVector<ProducerKey, 16> runKeys;
SmallVector<uint8_t, 64> runCoveredLanes(plan.batchLaneCount, 0);
for (const PackedScalarRunSlot& slot : run.slots) { for (const PackedScalarRunSlot& slot : run.slots) {
for (ProducerKey fragmentKey : slot.keys) { for (ProducerKey fragmentKey : slot.keys) {
@@ -3134,23 +3265,24 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
if (fragmentKey.instance.laneCount == 0) if (fragmentKey.instance.laneCount == 0)
return failure(); return failure();
if (wholeBatchRangeOverlaps(plan.coveredRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
return failure(); return failure();
if (wholeBatchRangeOverlaps(runRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
return failure(); return failure();
runRanges.push_back({fragmentKey.instance.laneStart, fragmentKey.instance.laneCount}); markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount);
runKeys.push_back(fragmentKey);
} }
} }
if (runRanges.empty()) if (runKeys.empty())
continue; continue;
plan.packedRuns.push_back(&run); plan.packedRuns.push_back(&run);
for (WholeBatchAssemblyRange range : runRanges) for (ProducerKey runKey : runKeys)
recordWholeBatchCoverage(plan, range.laneStart, range.laneCount); recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount);
} }
return success(); return success();
@@ -3161,44 +3293,77 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state,
SpatComputeBatch batch, SpatComputeBatch batch,
ProducerKey key, ProducerKey key,
WholeBatchAssemblyPlan& plan) { WholeBatchAssemblyPlan& plan) {
struct CandidateFragment {
ProducerKey key;
Value value;
};
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount()); uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
uint32_t lane = 0; if (plan.coveredLaneCount == plan.batchLaneCount) {
return success();
}
while (lane < batchLaneCount) { WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
if (wholeBatchLaneCovered(plan.coveredRanges, lane)) { ArrayRef<AvailableValueStore::ExactBatchFragmentRecord> indexedFragments =
++lane; state.availableValues.getExactFragmentsForWholeBatch(lookupKey);
SmallVector<CandidateFragment, 16> candidates;
candidates.reserve(indexedFragments.size());
for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) {
ProducerKey candidateKey = record.key;
if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex
|| candidateKey.instance.laneCount == 0)
continue; continue;
if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount))
continue;
auto fragmentType = dyn_cast<RankedTensorType>(record.value.getType());
if (!fragmentType)
continue;
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(candidateKey.instance.laneCount);
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
continue;
candidates.push_back({candidateKey, record.value});
}
llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) {
if (lhs.key.instance.laneStart != rhs.key.instance.laneStart)
return lhs.key.instance.laneStart < rhs.key.instance.laneStart;
return lhs.key.instance.laneCount > rhs.key.instance.laneCount;
});
size_t candidateCursor = 0;
uint32_t lane = 0;
while (lane < batchLaneCount) {
while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) {
++lane;
} }
bool foundFragment = false; if (lane >= batchLaneCount)
for (uint32_t laneCount = batchLaneCount - lane; laneCount != 0; --laneCount) {
if (wholeBatchRangeOverlaps(plan.coveredRanges, lane, laneCount))
continue;
ProducerKey candidate = getBatchLaneProducerKey(batch, lane, laneCount, key.resultIndex);
std::optional<Value> fragment = state.availableValues.lookupExact(candidate, targetClass.id);
if (!fragment)
continue;
auto fragmentType = dyn_cast<RankedTensorType>(fragment->getType());
if (!fragmentType)
return failure();
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(laneCount);
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
return failure();
plan.directFragments.push_back({candidate, *fragment});
recordWholeBatchCoverage(plan, lane, laneCount);
lane += laneCount;
foundFragment = true;
break; break;
while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane)
++candidateCursor;
size_t candidateIndex = candidateCursor;
const CandidateFragment* best = nullptr;
while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) {
const CandidateFragment& candidate = candidates[candidateIndex];
if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) {
best = &candidate;
break;
}
++candidateIndex;
} }
if (!foundFragment) if (!best)
return failure(); return failure();
plan.directFragments.push_back({best->key, best->value});
recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount);
lane += best->key.instance.laneCount;
} }
return success(); return success();
@@ -3291,11 +3456,11 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
} }
for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) {
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys); std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
if (!slotKey) if (!contiguousKey)
return failure(); return failure();
groupIt->slotIndices.push_back(slotIndex); groupIt->slotIndices.push_back(slotIndex);
groupIt->outputOffsets.push_back(static_cast<int64_t>(slotKey->instance.laneStart) * plan.rowsPerLane); groupIt->outputOffsets.push_back(static_cast<int64_t>(contiguousKey->instance.laneStart) * plan.rowsPerLane);
} }
} }
@@ -3409,10 +3574,15 @@ FailureOr<WholeBatchAssemblyPlan> buildWholeBatchAssemblyPlan(MaterializerState&
WholeBatchAssemblyPlan plan; WholeBatchAssemblyPlan plan;
plan.resultType = resultTensorType; plan.resultType = resultTensorType;
plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast<int64_t>(batchLaneCount); plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast<int64_t>(batchLaneCount);
plan.batchLaneCount = batchLaneCount;
plan.coveredLanes.assign(batchLaneCount, 0);
if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan))) if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan)))
return failure(); return failure();
if (plan.coveredLaneCount == plan.batchLaneCount)
return plan;
if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan))) if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan)))
return failure(); return failure();
@@ -4181,7 +4351,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
auto sourceBatch = cast<SpatComputeBatch>(sourceOp); auto sourceBatch = cast<SpatComputeBatch>(sourceOp);
SmallVector<Type, 4>& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); SmallVector<Type, 4>& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch);
SmallVector<Value, 4> initValues; SmallVector<Value, 4> initValues;
for (size_t resultIndex : group.resultIndices) { for (size_t resultIndex : group.resultIndices) {
if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex])
return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run");
@@ -4197,7 +4366,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
SmallVector<int64_t, 8> logicalLanes; SmallVector<int64_t, 8> logicalLanes;
logicalLanes.reserve(run.size()); logicalLanes.reserve(run.size());
for (const MaterializationRunSlot& slot : run) { for (const MaterializationRunSlot& slot : run) {
if (slot.peers.size() != 1) if (slot.peers.size() != 1)
return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot");
@@ -4215,34 +4383,34 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = buildNormalizedScfFor( auto loop = buildNormalizedScfFor(
state.rewriter, state.rewriter,
loc, loc,
lowerBound, lowerBound,
upperBound, upperBound,
step, step,
ValueRange(initValues), ValueRange(initValues),
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) { [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced = FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state, cloneBatchBodyForLane(state,
targetClass, targetClass,
run.front().peers.front(), run.front().peers.front(),
sourceLane, sourceLane,
group.resultIndices, group.resultIndices,
CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex});
if (failed(produced)) if (failed(produced))
return failure(); return failure();
yielded.reserve(produced->size()); yielded.reserve(produced->size());
for (auto [outputIndex, output] : llvm::enumerate(*produced)) { for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
auto fragmentType = cast<RankedTensorType>(output.getType()); auto fragmentType = cast<RankedTensorType>(output.getType());
Value acc = iterArgs[outputIndex]; Value acc = iterArgs[outputIndex];
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
} }
return success(); return success();
}); });
if (failed(loop)) if (failed(loop))
return failure(); return failure();
@@ -4466,14 +4634,14 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state,
} }
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
for (const auto& [key, consumers] : state.sameClassConsumers) { SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId};
if (!consumers.contains(classId)) auto it = state.sameClassConsumerIndex.find(lookupKey);
continue; if (it == state.sameClassConsumerIndex.end())
if (!sameProducerResult(key, producerKey)) return false;
continue;
if (containsProducerKey(key, producerKey) || containsProducerKey(producerKey, key)) for (ProducerKey existing : it->second)
if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing))
return true; return true;
}
return false; return false;
} }
@@ -4488,6 +4656,7 @@ bool canCompactBatchClassRun(MaterializerState& state,
ArrayRef<Value> outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); ArrayRef<Value> outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front());
for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) {
(void) ignored;
for (const MaterializationRunSlot& slot : run) { for (const MaterializationRunSlot& slot : run) {
if (slot.peers.empty()) if (slot.peers.empty())
return false; return false;
@@ -4533,7 +4702,8 @@ Value createBatchClassRunSourceLane(MaterializerState& state,
SmallVector<int64_t, 16> sourceLanes; SmallVector<int64_t, 16> sourceLanes;
sourceLanes.reserve(run.size() * targetClass.cpus.size()); sourceLanes.reserve(run.size() * targetClass.cpus.size());
for (const MaterializationRunSlot& slot : run) { for (auto [runSlotIndex, slot] : llvm::enumerate(run)) {
(void) runSlotIndex;
assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane"); assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane");
for (const ComputeInstance& peer : slot.peers) for (const ComputeInstance& peer : slot.peers)
sourceLanes.push_back(peer.laneStart); sourceLanes.push_back(peer.laneStart);
@@ -4577,7 +4747,6 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
plan.messages.targetCoreIds.reserve(messageCount); plan.messages.targetCoreIds.reserve(messageCount);
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) { for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
(void) slotIndex;
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
if (failed(checkedSourceCpu)) if (failed(checkedSourceCpu))
@@ -4590,6 +4759,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
return failure(); return failure();
plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
} }
(void) slotIndex;
} }
plans.push_back(std::move(plan)); plans.push_back(std::move(plan));
@@ -4773,7 +4943,8 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
return success(); return success();
} }
LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) { LogicalResult materializeInstanceSlot(MaterializerState& state,
const ComputeInstance& instance) {
auto cpuIt = state.schedule.computeToCpuMap.find(instance); auto cpuIt = state.schedule.computeToCpuMap.find(instance);
if (cpuIt == state.schedule.computeToCpuMap.end()) if (cpuIt == state.schedule.computeToCpuMap.end())
return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance");
@@ -4794,8 +4965,7 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
return success(); return success();
if (isa<SpatComputeBatch>(instance.op)) { if (isa<SpatComputeBatch>(instance.op)) {
FailureOr<MaterializationRun> run = FailureOr<MaterializationRun> run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
if (succeeded(run)) { if (succeeded(run)) {
if (!targetClass.isBatch) if (!targetClass.isBatch)
@@ -4924,6 +5094,7 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
return failure(); return failure();
LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody()); LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody());
(void) _;
return success(); return success();
} }
@@ -1,6 +1,5 @@
#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
@@ -14,20 +13,14 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm> #include <algorithm>
#include <chrono>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <cstdlib>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <tuple>
#include <utility> #include <utility>
#include <vector> #include <vector>
@@ -51,83 +44,6 @@ using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch; using SpatComputeBatch = spatial::SpatComputeBatch;
using spatial::getProducerValueRef; using spatial::getProducerValueRef;
/// Reports whether merge-phase profiling output was requested.
///
/// Profiling is switched on by the mere presence of the RAPTOR_PROFILE_MERGE
/// environment variable; its value (even an empty string) is not inspected.
bool isMergeProfilingEnabled() {
  const char* requested = std::getenv("RAPTOR_PROFILE_MERGE");
  return requested != nullptr;
}
/// Scope guard that reports how long a named merge phase took.
///
/// Whether profiling is active is sampled once, at construction, through
/// isMergeProfilingEnabled(). When it is active the destructor writes
/// "[merge-profile] <phase>: <elapsed> ms" to llvm::errs(), where the elapsed
/// time is measured with std::chrono::steady_clock and formatted with
/// "{0:F3}". When it is not active the object neither reads the clock nor
/// prints anything.
class ScopedMergePhaseTimer {
public:
  /// Starts timing the phase called `phaseName` (the name is copied).
  explicit ScopedMergePhaseTimer(StringRef phaseName)
      : enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
    if (enabled)
      start = std::chrono::steady_clock::now();
  }

  // The destructor has an observable side effect, so a copy would report the
  // same phase a second time. Forbid copying (which also suppresses the
  // implicit move operations) so each timed scope prints exactly once.
  ScopedMergePhaseTimer(const ScopedMergePhaseTimer&) = delete;
  ScopedMergePhaseTimer& operator=(const ScopedMergePhaseTimer&) = delete;

  /// Prints the elapsed time for the phase if profiling was enabled.
  ~ScopedMergePhaseTimer() {
    if (!enabled)
      return;
    auto elapsed = std::chrono::steady_clock::now() - start;
    double millis = std::chrono::duration<double, std::milli>(elapsed).count();
    llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
  }

private:
  // Snapshot of the profiling switch taken in the constructor.
  bool enabled = false;
  // Owned copy of the phase name, so the caller's StringRef need not outlive us.
  std::string phase;
  // Only assigned when `enabled` is true.
  std::chrono::steady_clock::time_point start;
};
/// Operation tallies for one function, as filled in by collectMergeIrCounts.
/// All counters start at zero.
struct MergeIrCounts {
  // SpatCompute ops found directly in the function.
  uint64_t topLevelComputeCount{0};
  // SpatComputeBatch ops found directly in the function.
  uint64_t topLevelComputeBatchCount{0};
  // SpatChannelSendOp occurrences seen while walking those ops.
  uint64_t scalarChannelSendCount{0};
  // SpatChannelReceiveOp occurrences seen while walking those ops.
  uint64_t scalarChannelReceiveCount{0};
  // SpatVMMOp occurrences seen while walking those ops.
  uint64_t wvmmCount{0};
  // SpatVAddOp occurrences seen while walking those ops.
  uint64_t vaddCount{0};
  // scf::ForOp occurrences seen while walking those ops.
  uint64_t scfForCount{0};
};
/// Counts the top-level SpatCompute / SpatComputeBatch ops of `funcOp` and,
/// by walking each of them, the channel send/receive, VMM, VAdd and scf.for
/// operations reached by the walk.
MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
  MergeIrCounts tally;

  // Classifies every operation visited by a walk rooted at `root`. An
  // operation bumps at most one counter; the checks run in a fixed order.
  auto tallyWalkedOps = [&](Operation* root) {
    root->walk([&](Operation* visited) {
      if (isa<spatial::SpatChannelSendOp>(visited)) {
        ++tally.scalarChannelSendCount;
        return;
      }
      if (isa<spatial::SpatChannelReceiveOp>(visited)) {
        ++tally.scalarChannelReceiveCount;
        return;
      }
      if (isa<spatial::SpatVMMOp>(visited)) {
        ++tally.wvmmCount;
        return;
      }
      if (isa<spatial::SpatVAddOp>(visited)) {
        ++tally.vaddCount;
        return;
      }
      if (isa<scf::ForOp>(visited))
        ++tally.scfForCount;
    });
  };

  for (SpatCompute scalarCompute : funcOp.getOps<SpatCompute>()) {
    ++tally.topLevelComputeCount;
    tallyWalkedOps(scalarCompute.getOperation());
  }

  for (SpatComputeBatch batchCompute : funcOp.getOps<SpatComputeBatch>()) {
    ++tally.topLevelComputeBatchCount;
    tallyWalkedOps(batchCompute.getOperation());
  }

  return tally;
}
/// Prints one "[merge-profile] <phaseName> counts: ..." line to llvm::errs()
/// with the tallies gathered by collectMergeIrCounts for `funcOp`.
/// Does nothing unless isMergeProfilingEnabled() returns true.
void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
  if (!isMergeProfilingEnabled())
    return;

  const MergeIrCounts tally = collectMergeIrCounts(funcOp);
  auto& out = llvm::errs();
  out << "[merge-profile] " << phaseName << " counts:";
  out << " compute=" << tally.topLevelComputeCount;
  out << " compute_batch=" << tally.topLevelComputeBatchCount;
  out << " scalar_send=" << tally.scalarChannelSendCount;
  out << " scalar_recv=" << tally.scalarChannelReceiveCount;
  out << " wvmm=" << tally.wvmmCount;
  out << " vadd=" << tally.vaddCount;
  out << " scf_for=" << tally.scfForCount;
  out << "\n";
}
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) { static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) { if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id"); auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
@@ -138,16 +54,6 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
return std::nullopt; return std::nullopt;
} }
/// Per-compute figures gathered by emitMotifProfile. Both start at zero.
struct ComputeMotifInfo {
  // Result of spatial::countComputeBodyInstructions on the compute body.
  uint64_t instructionCount{0};
  // Number of SpatVMMOp operations found by walking the compute body.
  uint64_t weightedVmmCount{0};
};
/// Appends `value` to `values` unless an equal element is already present.
/// Existing elements and their order are left untouched.
void appendUnique(SmallVector<size_t>& values, size_t value) {
  for (size_t existing : values)
    if (existing == value)
      return;
  values.push_back(value);
}
bool isTrivialSerialMergeCandidate(SpatCompute compute) { bool isTrivialSerialMergeCandidate(SpatCompute compute) {
if (!compute->hasOneUse()) if (!compute->hasOneUse())
return false; return false;
@@ -266,212 +172,6 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
} }
} }
/// Prints "[DCP-MOTIF]" statistics about the compute graph of `funcOp` to
/// llvm::errs(). Does nothing unless the DCP_MOTIF_PROFILE environment
/// variable is set.
///
/// Nodes are the SpatCompute ops directly inside `funcOp`; there is one edge
/// per distinct (producer compute, consumer compute) pair, derived from the
/// consumer's inputs. The report covers node/edge totals, VMM usage, fan-in
/// and fan-out, serial chains, per-level VMM width, and the five most common
/// "shapes" of VMM-carrying computes.
void emitMotifProfile(func::FuncOp funcOp) {
  if (!std::getenv("DCP_MOTIF_PROFILE"))
    return;

  // Give every top-level compute a dense index.
  SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
  DenseMap<SpatCompute, size_t> computeToIndex;
  computeToIndex.reserve(computes.size());
  for (auto [index, compute] : llvm::enumerate(computes))
    computeToIndex[compute] = index;

  SmallVector<ComputeMotifInfo> computeInfos(computes.size());
  SmallVector<SmallVector<size_t>> parents(computes.size());
  SmallVector<SmallVector<size_t>> children(computes.size());
  uint64_t weightedVmmNodeCount = 0;
  uint64_t weightedVmmOpCount = 0;
  uint64_t edgeCount = 0;

  // Per-compute figures plus deduplicated parent/child adjacency lists.
  for (auto [index, compute] : llvm::enumerate(computes)) {
    ComputeMotifInfo& info = computeInfos[index];
    info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
    compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
    if (info.weightedVmmCount > 0) {
      weightedVmmNodeCount++;
      weightedVmmOpCount += info.weightedVmmCount;
    }

    for (Value input : compute.getInputs()) {
      // An input without a defining operation (e.g. a block argument) has a
      // null getDefiningOp(); getDefiningOp<OpTy>() handles that case, while a
      // plain dyn_cast on the null pointer is not allowed.
      auto parent = input.getDefiningOp<SpatCompute>();
      if (!parent || parent == compute)
        continue;

      auto parentIt = computeToIndex.find(parent);
      if (parentIt == computeToIndex.end())
        continue;

      // Only count the edge the first time this parent is seen for `compute`.
      size_t parentIndex = parentIt->second;
      size_t oldParentCount = parents[index].size();
      appendUnique(parents[index], parentIndex);
      if (parents[index].size() != oldParentCount) {
        appendUnique(children[parentIndex], index);
        edgeCount++;
      }
    }
  }

  // Fan-in / fan-out maxima and threshold histograms.
  uint64_t maxFanIn = 0;
  uint64_t maxFanOut = 0;
  uint64_t fanIn16 = 0;
  uint64_t fanIn64 = 0;
  uint64_t fanIn256 = 0;
  uint64_t fanOut16 = 0;
  uint64_t fanOut64 = 0;
  uint64_t fanOut256 = 0;
  for (size_t index = 0; index < computes.size(); ++index) {
    uint64_t fanIn = parents[index].size();
    uint64_t fanOut = children[index].size();
    maxFanIn = std::max(maxFanIn, fanIn);
    maxFanOut = std::max(maxFanOut, fanOut);
    fanIn16 += fanIn >= 16;
    fanIn64 += fanIn >= 64;
    fanIn256 += fanIn >= 256;
    fanOut16 += fanOut >= 16;
    fanOut64 += fanOut >= 64;
    fanOut256 += fanOut >= 256;
  }

  // Serial chains: maximal runs linked by single-child / single-parent edges.
  uint64_t serialChainCount = 0;
  uint64_t serialChainNodeCount = 0;
  uint64_t maxSerialChain = 0;
  for (size_t index = 0; index < computes.size(); ++index) {
    // A node whose only parent has it as its only child lies inside a chain
    // that is counted from that chain's head, so skip it here.
    if (parents[index].size() == 1 && children[parents[index][0]].size() == 1)
      continue;

    uint64_t chainLength = 1;
    size_t current = index;
    while (children[current].size() == 1) {
      size_t child = children[current][0];
      if (parents[child].size() != 1)
        break;
      chainLength++;
      current = child;
    }

    if (chainLength >= 2) {
      serialChainCount++;
      serialChainNodeCount += chainLength;
      maxSerialChain = std::max(maxSerialChain, chainLength);
    }
  }

  // Level of a node = 1 + the largest level among its parents, computed with a
  // worklist seeded by the nodes that have no parents. `readyNodes` doubles as
  // the visit order; its final size is reported as "topoVisited".
  SmallVector<size_t> incomingEdgeCount;
  incomingEdgeCount.reserve(parents.size());
  for (ArrayRef<size_t> parentList : parents)
    incomingEdgeCount.push_back(parentList.size());

  SmallVector<uint64_t> level(computes.size(), 0);
  SmallVector<size_t> readyNodes;
  readyNodes.reserve(computes.size());
  for (size_t index = 0; index < computes.size(); ++index)
    if (incomingEdgeCount[index] == 0)
      readyNodes.push_back(index);

  size_t readyIndex = 0;
  while (readyIndex != readyNodes.size()) {
    size_t current = readyNodes[readyIndex++];
    for (size_t child : children[current]) {
      level[child] = std::max(level[child], level[current] + 1);
      assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
      incomingEdgeCount[child]--;
      if (incomingEdgeCount[child] == 0)
        readyNodes.push_back(child);
    }
  }

  // How many VMM-carrying computes sit on each level.
  SmallVector<uint64_t> weightedVmmNodesByLevel;
  for (size_t index = 0; index < computes.size(); ++index) {
    if (computeInfos[index].weightedVmmCount == 0)
      continue;
    if (weightedVmmNodesByLevel.size() <= level[index])
      weightedVmmNodesByLevel.resize(level[index] + 1, 0);
    weightedVmmNodesByLevel[level[index]]++;
  }

  uint64_t maxWeightedVmmLevel = 0;
  uint64_t wideWeightedVmmLevels64 = 0;
  uint64_t wideWeightedVmmLevels256 = 0;
  for (uint64_t count : weightedVmmNodesByLevel) {
    maxWeightedVmmLevel = std::max(maxWeightedVmmLevel, count);
    wideWeightedVmmLevels64 += count >= 64;
    wideWeightedVmmLevels256 += count >= 256;
  }

  // Group VMM-carrying computes by shape:
  // (instructions, VMM ops, weights, inputs, fan-in, fan-out).
  using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
  SmallVector<ShapeKey> weightedVmmShapeKeys;
  for (auto [index, compute] : llvm::enumerate(computes)) {
    const ComputeMotifInfo& info = computeInfos[index];
    if (info.weightedVmmCount == 0)
      continue;
    weightedVmmShapeKeys.push_back({info.instructionCount,
                                    info.weightedVmmCount,
                                    static_cast<uint64_t>(compute.getWeights().size()),
                                    static_cast<uint64_t>(compute.getInputs().size()),
                                    static_cast<uint64_t>(parents[index].size()),
                                    static_cast<uint64_t>(children[index].size())});
  }

  // Sorting makes equal shapes adjacent so they can be run-length counted.
  llvm::sort(weightedVmmShapeKeys);
  SmallVector<std::pair<uint64_t, ShapeKey>> weightedVmmShapeCounts;
  for (size_t index = 0; index < weightedVmmShapeKeys.size();) {
    size_t next = index + 1;
    while (next < weightedVmmShapeKeys.size() && weightedVmmShapeKeys[next] == weightedVmmShapeKeys[index])
      next++;
    weightedVmmShapeCounts.push_back({next - index, weightedVmmShapeKeys[index]});
    index = next;
  }

  // Most frequent shape first; ties are broken by ascending shape key.
  llvm::sort(weightedVmmShapeCounts, [](const auto& lhs, const auto& rhs) {
    if (lhs.first != rhs.first)
      return lhs.first > rhs.first;
    return lhs.second < rhs.second;
  });

  llvm::errs() << llvm::formatv("[DCP-MOTIF] computes={0} edges={1} wvmmNodes={2} wvmmOps={3} "
                                "serialChains={4} serialChainNodes={5} maxSerialChain={6} "
                                "maxFanIn={7} maxFanOut={8} fanIn>=16/64/256={9}/{10}/{11} "
                                "fanOut>=16/64/256={12}/{13}/{14} topoVisited={15}\n",
                                computes.size(),
                                edgeCount,
                                weightedVmmNodeCount,
                                weightedVmmOpCount,
                                serialChainCount,
                                serialChainNodeCount,
                                maxSerialChain,
                                maxFanIn,
                                maxFanOut,
                                fanIn16,
                                fanIn64,
                                fanIn256,
                                fanOut16,
                                fanOut64,
                                fanOut256,
                                readyNodes.size());
  llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmLevels={0} maxWvmmLevel={1} wideWvmmLevels>=64/256={2}/{3} "
                                "shapeGroups={4}\n",
                                weightedVmmNodesByLevel.size(),
                                maxWeightedVmmLevel,
                                wideWeightedVmmLevels64,
                                wideWeightedVmmLevels256,
                                weightedVmmShapeCounts.size());

  // Report at most the five most common shapes.
  for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
    auto [count, shape] = weightedVmmShapeCounts[rank];
    auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape;
    llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
                                  "weights={4} inputs={5} fanIn={6} fanOut={7}\n",
                                  rank,
                                  count,
                                  insts,
                                  vmmOps,
                                  weights,
                                  inputs,
                                  fanIn,
                                  fanOut);
  }
}
void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) { void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) {
std::fstream file = openReportFile(name); std::fstream file = openReportFile(name);
if (!file.is_open()) if (!file.is_open())
@@ -628,44 +328,27 @@ public:
void runOnOperation() override { void runOnOperation() override {
func::FuncOp func = getOperation(); func::FuncOp func = getOperation();
{ mergeTriviallyConnectedComputes(func);
ScopedMergePhaseTimer timer("trivial-serial-merge");
mergeTriviallyConnectedComputes(func);
}
if (std::getenv("DCP_MOTIF_PROFILE"))
emitMotifProfile(func);
const spatial::MergeScheduleResult* analysisResult = nullptr; const spatial::MergeScheduleResult* analysisResult = nullptr;
{ analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
ScopedMergePhaseTimer timer("scheduling-analysis"); if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult(); signalPassFailure();
} return;
{
ScopedMergePhaseTimer timer("schedule-materialization");
if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
signalPassFailure();
return;
}
} }
emitMergeIrCounts("after-materialization", func); if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR");
{ signalPassFailure();
ScopedMergePhaseTimer timer("cleanup-topological-sort-report"); return;
if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR");
signalPassFailure();
return;
}
if (failed(verifySpatialCommunicationInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed");
signalPassFailure();
return;
}
emitMergeIrCounts("final-post-merge", func);
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
} }
if (failed(verifySpatialCommunicationInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed");
signalPassFailure();
return;
}
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
} }
}; };