much faster MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-05 18:22:59 +02:00
parent 8ddbbcecfa
commit aec80529ca
2 changed files with 338 additions and 484 deletions
@@ -11,6 +11,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -117,6 +118,66 @@ struct ProducerKeyInfo {
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
};
/// Hash-map key made of a source operation, one of its result indices and a
/// class id. It is the key type of `MaterializerState::sameClassConsumerIndex`.
struct SameClassConsumerLookupKey {
  Operation* sourceOp = nullptr;
  size_t resultIndex = 0;
  ClassId classId = 0;

  /// Keys compare equal only when all three fields match.
  bool operator==(const SameClassConsumerLookupKey& other) const {
    if (sourceOp != other.sourceOp)
      return false;
    if (resultIndex != other.resultIndex)
      return false;
    return classId == other.classId;
  }
};
/// Key traits that let `SameClassConsumerLookupKey` be used as a `DenseMap` key.
struct SameClassConsumerLookupKeyInfo {
  /// Sentinel for an unused bucket: the pointer empty key plus the maximum
  /// representable result index and class id.
  static SameClassConsumerLookupKey getEmptyKey() {
    SameClassConsumerLookupKey sentinel;
    sentinel.sourceOp = llvm::DenseMapInfo<Operation*>::getEmptyKey();
    sentinel.resultIndex = std::numeric_limits<size_t>::max();
    sentinel.classId = std::numeric_limits<ClassId>::max();
    return sentinel;
  }

  /// Sentinel for an erased bucket; differs from the empty key only in the
  /// pointer component.
  static SameClassConsumerLookupKey getTombstoneKey() {
    SameClassConsumerLookupKey sentinel;
    sentinel.sourceOp = llvm::DenseMapInfo<Operation*>::getTombstoneKey();
    sentinel.resultIndex = std::numeric_limits<size_t>::max();
    sentinel.classId = std::numeric_limits<ClassId>::max();
    return sentinel;
  }

  /// Combines the pointer hash with the result index and the class id.
  static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
    unsigned opHash = llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp);
    return llvm::hash_combine(opHash, key.resultIndex, key.classId);
  }

  /// Delegates to the key's own `operator==`.
  static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
    return lhs == rhs;
  }
};
/// Hash-map key made of a source operation, one of its result indices and a
/// class id. `AvailableValueStore` uses it to bucket packed runs and exact
/// batch fragments.
struct WholeBatchAssemblyLookupKey {
  Operation* sourceOp = nullptr;
  size_t resultIndex = 0;
  ClassId classId = 0;

  /// Keys compare equal only when all three fields match.
  bool operator==(const WholeBatchAssemblyLookupKey& other) const {
    const bool sameOp = sourceOp == other.sourceOp;
    const bool sameResult = resultIndex == other.resultIndex;
    const bool sameClass = classId == other.classId;
    return sameOp && sameResult && sameClass;
  }
};
/// Key traits that let `WholeBatchAssemblyLookupKey` be used as a `DenseMap` key.
struct WholeBatchAssemblyLookupKeyInfo {
  /// Sentinel for an unused bucket: the pointer empty key plus the maximum
  /// representable result index and class id.
  static WholeBatchAssemblyLookupKey getEmptyKey() {
    WholeBatchAssemblyLookupKey sentinel;
    sentinel.sourceOp = llvm::DenseMapInfo<Operation*>::getEmptyKey();
    sentinel.resultIndex = std::numeric_limits<size_t>::max();
    sentinel.classId = std::numeric_limits<ClassId>::max();
    return sentinel;
  }

  /// Sentinel for an erased bucket; differs from the empty key only in the
  /// pointer component.
  static WholeBatchAssemblyLookupKey getTombstoneKey() {
    WholeBatchAssemblyLookupKey sentinel;
    sentinel.sourceOp = llvm::DenseMapInfo<Operation*>::getTombstoneKey();
    sentinel.resultIndex = std::numeric_limits<size_t>::max();
    sentinel.classId = std::numeric_limits<ClassId>::max();
    return sentinel;
  }

  /// Combines the pointer hash with the result index and the class id.
  static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
    unsigned opHash = llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp);
    return llvm::hash_combine(opHash, key.resultIndex, key.classId);
  }

  /// Delegates to the key's own `operator==`.
  static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
    return lhs == rhs;
  }
};
using ClassSlotKey = std::pair<ClassId, SlotId>;
struct MaterializedClass {
@@ -270,9 +331,36 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
class AvailableValueStore {
public:
void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; }
struct ExactBatchFragmentRecord {
ProducerKey key;
Value value;
};
void recordPackedRun(PackedScalarRunValue run) { packedScalarRuns.push_back(std::move(run)); }
void record(ProducerKey key, ClassId classId, Value value) {
exactValues[key][classId] = value;
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
if (!batch || key.instance.laneCount == 0)
return;
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
for (ExactBatchFragmentRecord& record : bucket) {
if (!(record.key == key))
continue;
record.value = value;
return;
}
bucket.push_back({key, value});
}
void recordPackedRun(PackedScalarRunValue run) {
size_t runIndex = packedScalarRuns.size();
packedScalarRuns.push_back(std::move(run));
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
}
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
std::optional<Value> lookupExact(ProducerKey key, ClassId classId) const;
@@ -280,7 +368,21 @@ public:
std::optional<Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
SmallVectorImpl<PackedScalarRunValue>& getPackedScalarRuns() { return packedScalarRuns; }
ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
auto it = packedRunsByProducerResultClass.find(key);
if (it == packedRunsByProducerResultClass.end())
return {};
return it->second;
}
ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
auto it = exactBatchFragmentsByProducerResultClass.find(key);
if (it == exactBatchFragmentsByProducerResultClass.end())
return {};
return it->second;
}
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
private:
std::optional<Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
@@ -288,6 +390,10 @@ private:
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> exactValues;
SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<ExactBatchFragmentRecord, 16>, WholeBatchAssemblyLookupKeyInfo>
exactBatchFragmentsByProducerResultClass;
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
packedRunsByProducerResultClass;
};
struct MaterializerState {
@@ -296,7 +402,6 @@ struct MaterializerState {
IRRewriter rewriter;
OperationFolder constantFolder;
int64_t& nextChannelId;
SmallVector<MaterializedClass, 8> classes;
DenseMap<CpuId, ClassId> cpuToClass;
DenseMap<CpuId, SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
@@ -305,7 +410,8 @@ struct MaterializerState {
DenseSet<ClassSlotKey> materializedLogicalSlots;
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> sameClassConsumers;
DenseMap<SameClassConsumerLookupKey, SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
sameClassConsumerIndex;
DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo> projectedInputMatches;
DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
DenseMap<Value, bool> liveExternalUseCache;
@@ -317,7 +423,9 @@ struct MaterializerState {
DenseMap<Value, Value> hostReplacements;
DenseSet<Operation*> oldComputeOps;
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
MaterializerState(func::FuncOp func,
const MergeScheduleResult& schedule,
int64_t& nextChannelId)
: func(func),
schedule(schedule),
rewriter(func.getContext()),
@@ -428,6 +536,14 @@ std::optional<ProducerKey> getContiguousProducerRangeForKeys(ArrayRef<ProducerKe
return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex);
}
/// Builds a whole-batch lookup key from its three components.
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) {
  WholeBatchAssemblyLookupKey lookupKey;
  lookupKey.sourceOp = sourceOp;
  lookupKey.resultIndex = resultIndex;
  lookupKey.classId = classId;
  return lookupKey;
}
/// Builds a whole-batch lookup key from a producer key, taking the producing
/// operation and result index from `key`; the lane range is not part of the
/// lookup key.
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) {
  Operation* producerOp = key.instance.op;
  return WholeBatchAssemblyLookupKey {producerOp, key.resultIndex, classId};
}
FailureOr<RankedTensorType> getPackedBatchTensorType(Type laneType, size_t laneCount) {
auto tensorType = dyn_cast<RankedTensorType>(laneType);
if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
@@ -1172,14 +1288,12 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) {
for (PackedScalarRunValue& run : packedScalarRuns) {
if (run.targetClass != classId)
continue;
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue;
for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) {
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys);
if (!slotKey || !containsProducerKey(*slotKey, key))
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
if (!contiguousKey || !containsProducerKey(*contiguousKey, key))
continue;
FailureOr<RankedTensorType> slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size());
@@ -1197,12 +1311,13 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
Value slotPacked =
getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc());
if (*slotKey == key) {
if (*contiguousKey == key) {
record(key, classId, slotPacked);
return slotPacked;
}
std::optional<Value> sliced = extractPackedProducerSlice(state, materializedClass, *slotKey, slotPacked, key);
std::optional<Value> sliced =
extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key);
if (!sliced)
return std::nullopt;
@@ -1216,57 +1331,45 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) {
for (IndexedBatchRunValue& run : indexedBatchRuns) {
if (run.targetClass != classId)
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue;
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue;
for (const PackedScalarRunSlot& slot : run.slots)
if (llvm::is_contained(slot.keys, key))
return &run;
for (const PackedScalarRunSlot& slot : run.slots) {
if (!llvm::is_contained(slot.keys, key))
continue;
return &run;
}
}
return nullptr;
}
std::optional<Value> AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) {
if (std::optional<Value> exact = lookupExact(key, classId))
if (std::optional<Value> exact = lookupExact(key, classId)) {
return exact;
}
if (std::optional<Value> packedRunValue = lookupPackedRun(state, key, classId))
return packedRunValue;
MaterializedClass& materializedClass = state.classes[classId];
ProducerKey containingKey;
Value containingValue;
bool foundContainingValue = false;
for (auto& entry : exactValues) {
ProducerKey candidateKey = entry.first;
if (!containsProducerKey(candidateKey, key))
for (const auto& [candidateKey, classValues] : exactValues) {
if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key))
continue;
auto valueIt = entry.second.find(classId);
if (valueIt == entry.second.end())
auto valueIt = classValues.find(classId);
if (valueIt == classValues.end())
continue;
containingKey = candidateKey;
containingValue = valueIt->second;
foundContainingValue = true;
break;
std::optional<Value> slice =
extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key);
if (!slice)
return std::nullopt;
record(key, classId, *slice);
return *slice;
}
if (!foundContainingValue)
return std::nullopt;
std::optional<Value> slice =
extractPackedProducerSlice(state, materializedClass, containingKey, containingValue, key);
if (!slice)
return std::nullopt;
record(key, classId, *slice);
return *slice;
return std::nullopt;
}
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
@@ -1389,13 +1492,13 @@ Value createIndexedIndexValue(MaterializerState& state,
bool allowExhaustiveTiledSearch) {
assert(!values.empty() && "expected at least one indexed value");
if (allEqual(values))
if (allEqual(values)) {
return getOrCreateIndexConstant(state.constantFolder, anchor, values.front());
}
if (std::optional<IndexedIndexPattern> pattern =
getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch))
return createAffineIndexValue(state, *pattern, index, loc);
Value table = createIndexTensorConstant(state, anchor, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
}
@@ -1578,7 +1681,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
if (sourceClass == targetClass) {
state.sameClassConsumers[producerKey].insert(targetClass);
SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass};
SmallVector<ProducerKey, 4>& bucket = state.sameClassConsumerIndex[lookupKey];
if (!llvm::is_contained(bucket, producerKey))
bucket.push_back(producerKey);
continue;
}
@@ -2899,11 +3005,6 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return success();
}
/// Half-open lane interval [laneStart, laneStart + laneCount).
struct WholeBatchAssemblyRange {
  uint32_t laneStart {0}; // first lane of the interval
  uint32_t laneCount {0}; // number of lanes in the interval
};
struct DirectWholeBatchFragment {
ProducerKey key;
Value fragment;
@@ -2933,31 +3034,60 @@ struct WholeBatchFragmentGroup {
struct WholeBatchAssemblyPlan {
RankedTensorType resultType;
int64_t rowsPerLane = 0;
uint32_t batchLaneCount = 0;
uint32_t coveredLaneCount = 0;
SmallVector<WholeBatchAssemblyRange, 16> coveredRanges;
SmallVector<uint8_t, 64> coveredLanes;
SmallVector<PackedScalarRunValue*, 8> packedRuns;
SmallVector<DirectWholeBatchFragment, 16> directFragments;
};
bool wholeBatchRangeOverlaps(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t laneStart, uint32_t laneCount) {
uint32_t laneEnd = laneStart + laneCount;
for (WholeBatchAssemblyRange range : ranges) {
uint32_t rangeEnd = range.laneStart + range.laneCount;
if (laneStart < rangeEnd && range.laneStart < laneEnd)
return true;
}
return false;
/// Returns true when `lane` is inside the plan's coverage table and its entry
/// is non-zero. Lanes beyond the table are reported as not covered.
bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) {
  if (lane >= plan.coveredLanes.size())
    return false;
  return plan.coveredLanes[lane] != 0;
}
bool wholeBatchLaneCovered(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t lane) {
for (WholeBatchAssemblyRange range : ranges)
if (range.laneStart <= lane && lane < range.laneStart + range.laneCount)
bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
if (laneCount == 0)
return false;
if (laneStart >= plan.coveredLanes.size())
return false;
uint32_t laneEnd = std::min<uint32_t>(laneStart + laneCount, plan.coveredLanes.size());
for (uint32_t lane = laneStart; lane < laneEnd; ++lane)
if (plan.coveredLanes[lane] != 0)
return true;
return false;
}
void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
plan.coveredRanges.push_back({laneStart, laneCount});
assert(laneCount != 0 && "cannot cover an empty whole-batch range");
assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds");
for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) {
if (plan.coveredLanes[lane] != 0)
continue;
plan.coveredLanes[lane] = 1;
++plan.coveredLaneCount;
}
}
/// Returns true when any lane in [laneStart, laneStart + laneCount) has a
/// non-zero entry in `covered`.
///
/// An empty range, or a range starting at or beyond `covered.size()`, never
/// overlaps. A range extending past the end of `covered` is clamped to it.
bool localLaneRangeOverlaps(ArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
  if (laneCount == 0 || laneStart >= covered.size())
    return false;

  // Widen before adding: a 32-bit `laneStart + laneCount` can wrap around and
  // produce an end below `laneStart`, which would skip the scan entirely and
  // wrongly report "no overlap". Clamping in 64 bits also avoids truncating
  // `covered.size()`.
  const uint64_t requestedEnd = static_cast<uint64_t>(laneStart) + laneCount;
  const uint64_t laneEnd = std::min<uint64_t>(requestedEnd, covered.size());

  for (uint64_t lane = laneStart; lane < laneEnd; ++lane)
    if (covered[lane] != 0)
      return true;

  return false;
}
/// Sets the entries for lanes [laneStart, laneStart + laneCount) in `covered`
/// to 1.
///
/// Precondition: the whole range lies inside `covered`.
void markLocalLaneRangeCovered(MutableArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
  // Compute the end in 64 bits so a range whose 32-bit sum wraps around cannot
  // slip past the bounds check below.
  const uint64_t laneEnd = static_cast<uint64_t>(laneStart) + laneCount;
  assert(laneEnd <= covered.size() && "local coverage out of bounds");

  for (uint64_t lane = laneStart; lane < laneEnd; ++lane)
    covered[lane] = 1;
}
LogicalResult
@@ -3118,13 +3248,14 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
MaterializedClass& targetClass,
ProducerKey key,
WholeBatchAssemblyPlan& plan) {
for (PackedScalarRunValue& run : state.availableValues.getPackedScalarRuns()) {
if (run.targetClass != targetClass.id)
continue;
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue;
WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
ArrayRef<size_t> runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey);
SmallVector<WholeBatchAssemblyRange, 16> runRanges;
for (size_t runIndex : runIndices) {
PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex);
SmallVector<ProducerKey, 16> runKeys;
SmallVector<uint8_t, 64> runCoveredLanes(plan.batchLaneCount, 0);
for (const PackedScalarRunSlot& slot : run.slots) {
for (ProducerKey fragmentKey : slot.keys) {
@@ -3134,23 +3265,24 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
if (fragmentKey.instance.laneCount == 0)
return failure();
if (wholeBatchRangeOverlaps(plan.coveredRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
return failure();
if (wholeBatchRangeOverlaps(runRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
return failure();
runRanges.push_back({fragmentKey.instance.laneStart, fragmentKey.instance.laneCount});
markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount);
runKeys.push_back(fragmentKey);
}
}
if (runRanges.empty())
if (runKeys.empty())
continue;
plan.packedRuns.push_back(&run);
for (WholeBatchAssemblyRange range : runRanges)
recordWholeBatchCoverage(plan, range.laneStart, range.laneCount);
for (ProducerKey runKey : runKeys)
recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount);
}
return success();
@@ -3161,44 +3293,77 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state,
SpatComputeBatch batch,
ProducerKey key,
WholeBatchAssemblyPlan& plan) {
struct CandidateFragment {
ProducerKey key;
Value value;
};
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
uint32_t lane = 0;
if (plan.coveredLaneCount == plan.batchLaneCount) {
return success();
}
while (lane < batchLaneCount) {
if (wholeBatchLaneCovered(plan.coveredRanges, lane)) {
++lane;
WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
ArrayRef<AvailableValueStore::ExactBatchFragmentRecord> indexedFragments =
state.availableValues.getExactFragmentsForWholeBatch(lookupKey);
SmallVector<CandidateFragment, 16> candidates;
candidates.reserve(indexedFragments.size());
for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) {
ProducerKey candidateKey = record.key;
if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex
|| candidateKey.instance.laneCount == 0)
continue;
if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount))
continue;
auto fragmentType = dyn_cast<RankedTensorType>(record.value.getType());
if (!fragmentType)
continue;
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(candidateKey.instance.laneCount);
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
continue;
candidates.push_back({candidateKey, record.value});
}
llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) {
if (lhs.key.instance.laneStart != rhs.key.instance.laneStart)
return lhs.key.instance.laneStart < rhs.key.instance.laneStart;
return lhs.key.instance.laneCount > rhs.key.instance.laneCount;
});
size_t candidateCursor = 0;
uint32_t lane = 0;
while (lane < batchLaneCount) {
while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) {
++lane;
}
bool foundFragment = false;
for (uint32_t laneCount = batchLaneCount - lane; laneCount != 0; --laneCount) {
if (wholeBatchRangeOverlaps(plan.coveredRanges, lane, laneCount))
continue;
ProducerKey candidate = getBatchLaneProducerKey(batch, lane, laneCount, key.resultIndex);
std::optional<Value> fragment = state.availableValues.lookupExact(candidate, targetClass.id);
if (!fragment)
continue;
auto fragmentType = dyn_cast<RankedTensorType>(fragment->getType());
if (!fragmentType)
return failure();
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(laneCount);
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
return failure();
plan.directFragments.push_back({candidate, *fragment});
recordWholeBatchCoverage(plan, lane, laneCount);
lane += laneCount;
foundFragment = true;
if (lane >= batchLaneCount)
break;
while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane)
++candidateCursor;
size_t candidateIndex = candidateCursor;
const CandidateFragment* best = nullptr;
while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) {
const CandidateFragment& candidate = candidates[candidateIndex];
if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) {
best = &candidate;
break;
}
++candidateIndex;
}
if (!foundFragment)
if (!best)
return failure();
plan.directFragments.push_back({best->key, best->value});
recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount);
lane += best->key.instance.laneCount;
}
return success();
@@ -3291,11 +3456,11 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
}
for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) {
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys);
if (!slotKey)
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
if (!contiguousKey)
return failure();
groupIt->slotIndices.push_back(slotIndex);
groupIt->outputOffsets.push_back(static_cast<int64_t>(slotKey->instance.laneStart) * plan.rowsPerLane);
groupIt->outputOffsets.push_back(static_cast<int64_t>(contiguousKey->instance.laneStart) * plan.rowsPerLane);
}
}
@@ -3409,10 +3574,15 @@ FailureOr<WholeBatchAssemblyPlan> buildWholeBatchAssemblyPlan(MaterializerState&
WholeBatchAssemblyPlan plan;
plan.resultType = resultTensorType;
plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast<int64_t>(batchLaneCount);
plan.batchLaneCount = batchLaneCount;
plan.coveredLanes.assign(batchLaneCount, 0);
if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan)))
return failure();
if (plan.coveredLaneCount == plan.batchLaneCount)
return plan;
if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan)))
return failure();
@@ -4181,7 +4351,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
auto sourceBatch = cast<SpatComputeBatch>(sourceOp);
SmallVector<Type, 4>& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch);
SmallVector<Value, 4> initValues;
for (size_t resultIndex : group.resultIndices) {
if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex])
return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run");
@@ -4197,7 +4366,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
SmallVector<int64_t, 8> logicalLanes;
logicalLanes.reserve(run.size());
for (const MaterializationRunSlot& slot : run) {
if (slot.peers.size() != 1)
return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot");
@@ -4215,34 +4383,34 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = buildNormalizedScfFor(
state.rewriter,
loc,
lowerBound,
upperBound,
step,
ValueRange(initValues),
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc);
state.rewriter,
loc,
lowerBound,
upperBound,
step,
ValueRange(initValues),
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state,
targetClass,
run.front().peers.front(),
sourceLane,
group.resultIndices,
CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex});
if (failed(produced))
return failure();
FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state,
targetClass,
run.front().peers.front(),
sourceLane,
group.resultIndices,
CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex});
if (failed(produced))
return failure();
yielded.reserve(produced->size());
for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
auto fragmentType = cast<RankedTensorType>(output.getType());
Value acc = iterArgs[outputIndex];
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
}
return success();
});
yielded.reserve(produced->size());
for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
auto fragmentType = cast<RankedTensorType>(output.getType());
Value acc = iterArgs[outputIndex];
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
}
return success();
});
if (failed(loop))
return failure();
@@ -4466,14 +4634,14 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state,
}
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
for (const auto& [key, consumers] : state.sameClassConsumers) {
if (!consumers.contains(classId))
continue;
if (!sameProducerResult(key, producerKey))
continue;
if (containsProducerKey(key, producerKey) || containsProducerKey(producerKey, key))
SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId};
auto it = state.sameClassConsumerIndex.find(lookupKey);
if (it == state.sameClassConsumerIndex.end())
return false;
for (ProducerKey existing : it->second)
if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing))
return true;
}
return false;
}
@@ -4488,6 +4656,7 @@ bool canCompactBatchClassRun(MaterializerState& state,
ArrayRef<Value> outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front());
for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) {
(void) ignored;
for (const MaterializationRunSlot& slot : run) {
if (slot.peers.empty())
return false;
@@ -4533,7 +4702,8 @@ Value createBatchClassRunSourceLane(MaterializerState& state,
SmallVector<int64_t, 16> sourceLanes;
sourceLanes.reserve(run.size() * targetClass.cpus.size());
for (const MaterializationRunSlot& slot : run) {
for (auto [runSlotIndex, slot] : llvm::enumerate(run)) {
(void) runSlotIndex;
assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane");
for (const ComputeInstance& peer : slot.peers)
sourceLanes.push_back(peer.laneStart);
@@ -4577,7 +4747,6 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
plan.messages.targetCoreIds.reserve(messageCount);
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
(void) slotIndex;
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
if (failed(checkedSourceCpu))
@@ -4590,6 +4759,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
return failure();
plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
}
(void) slotIndex;
}
plans.push_back(std::move(plan));
@@ -4773,7 +4943,8 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
return success();
}
LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) {
LogicalResult materializeInstanceSlot(MaterializerState& state,
const ComputeInstance& instance) {
auto cpuIt = state.schedule.computeToCpuMap.find(instance);
if (cpuIt == state.schedule.computeToCpuMap.end())
return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance");
@@ -4794,8 +4965,7 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
return success();
if (isa<SpatComputeBatch>(instance.op)) {
FailureOr<MaterializationRun> run =
collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
FailureOr<MaterializationRun> run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
if (succeeded(run)) {
if (!targetClass.isBatch)
@@ -4924,6 +5094,7 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
return failure();
LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody());
(void) _;
return success();
}
@@ -1,6 +1,5 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
@@ -14,20 +13,14 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
@@ -51,83 +44,6 @@ using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
using spatial::getProducerValueRef;
/// Reports whether the RAPTOR_PROFILE_MERGE environment variable is set.
/// Only presence is checked; the value is ignored (an empty value still counts).
bool isMergeProfilingEnabled() {
  const char* profileFlag = std::getenv("RAPTOR_PROFILE_MERGE");
  return profileFlag != nullptr;
}
/// Scope guard that measures how long a named merge phase takes.
///
/// Whether profiling is active is sampled once, at construction, from
/// isMergeProfilingEnabled(). When active, the destructor writes
/// "[merge-profile] <phase>: <elapsed> ms" to llvm::errs(). When inactive, the
/// object does nothing.
class ScopedMergePhaseTimer {
public:
  /// Starts timing `phaseName` if profiling is enabled.
  explicit ScopedMergePhaseTimer(StringRef phaseName)
      // `enabled` is declared before `phase`, so it is already initialized
      // here; skip copying the name when nothing will ever be reported.
      : enabled(isMergeProfilingEnabled()), phase(enabled ? phaseName.str() : std::string()) {
    if (enabled)
      start = std::chrono::steady_clock::now();
  }

  // A copy would report the same phase a second time when it is destroyed, so
  // the guard is tied to exactly one scope.
  ScopedMergePhaseTimer(const ScopedMergePhaseTimer&) = delete;
  ScopedMergePhaseTimer& operator=(const ScopedMergePhaseTimer&) = delete;

  /// Reports the elapsed wall time, in milliseconds, if profiling was enabled
  /// at construction.
  ~ScopedMergePhaseTimer() {
    if (!enabled)
      return;
    auto elapsed = std::chrono::steady_clock::now() - start;
    double millis = std::chrono::duration<double, std::milli>(elapsed).count();
    llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
  }

private:
  bool enabled = false;
  std::string phase;
  std::chrono::steady_clock::time_point start;
};
/// Operation tallies gathered by collectMergeIrCounts() for profiling output.
struct MergeIrCounts {
  // Top-level ops found directly in the function.
  uint64_t topLevelComputeCount {0};
  uint64_t topLevelComputeBatchCount {0};
  // Ops found while walking the top-level compute / compute-batch ops.
  uint64_t scalarChannelSendCount {0};
  uint64_t scalarChannelReceiveCount {0};
  uint64_t wvmmCount {0};
  uint64_t vaddCount {0};
  uint64_t scfForCount {0};
};
/// Tallies the top-level SpatCompute / SpatComputeBatch ops of `funcOp` and,
/// by walking each of them, the channel send/receive, VMM, VAdd and scf.for
/// operations visited by the walk.
MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
  MergeIrCounts tally;

  // Walks `root` and bumps the counter matching each visited operation; an
  // operation that matches none of the tracked kinds is ignored.
  auto tallyVisitedOps = [&](Operation* root) {
    root->walk([&](Operation* visited) {
      if (isa<spatial::SpatChannelSendOp>(visited)) {
        ++tally.scalarChannelSendCount;
        return;
      }
      if (isa<spatial::SpatChannelReceiveOp>(visited)) {
        ++tally.scalarChannelReceiveCount;
        return;
      }
      if (isa<spatial::SpatVMMOp>(visited)) {
        ++tally.wvmmCount;
        return;
      }
      if (isa<spatial::SpatVAddOp>(visited)) {
        ++tally.vaddCount;
        return;
      }
      if (isa<scf::ForOp>(visited))
        ++tally.scfForCount;
    });
  };

  for (auto computeOp : funcOp.getOps<SpatCompute>()) {
    ++tally.topLevelComputeCount;
    tallyVisitedOps(computeOp.getOperation());
  }

  for (auto batchOp : funcOp.getOps<SpatComputeBatch>()) {
    ++tally.topLevelComputeBatchCount;
    tallyVisitedOps(batchOp.getOperation());
  }

  return tally;
}
/// Prints one "[merge-profile] <phaseName> counts: ..." line with the tallies
/// from collectMergeIrCounts(funcOp) to llvm::errs(). Does nothing, and does
/// not inspect the IR, unless merge profiling is enabled.
void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
  if (isMergeProfilingEnabled()) {
    const MergeIrCounts tally = collectMergeIrCounts(funcOp);
    llvm::raw_ostream& out = llvm::errs();
    out << "[merge-profile] " << phaseName << " counts:";
    out << " compute=" << tally.topLevelComputeCount;
    out << " compute_batch=" << tally.topLevelComputeBatchCount;
    out << " scalar_send=" << tally.scalarChannelSendCount;
    out << " scalar_recv=" << tally.scalarChannelReceiveCount;
    out << " wvmm=" << tally.wvmmCount;
    out << " vadd=" << tally.vaddCount;
    out << " scf_for=" << tally.scfForCount;
    out << "\n";
  }
}
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
@@ -138,16 +54,6 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
return std::nullopt;
}
/// Per-compute figures recorded by emitMotifProfile(); both start at zero.
struct ComputeMotifInfo {
  uint64_t instructionCount{0};  // result of spatial::countComputeBodyInstructions on the body
  uint64_t weightedVmmCount{0};  // number of spatial::SpatVMMOp ops found in the body
};
/// Appends `value` to `values` unless an equal element is already present.
/// Existing elements and their order are left untouched.
void appendUnique(SmallVector<size_t>& values, size_t value) {
  for (size_t existing : values)
    if (existing == value)
      return;
  values.push_back(value);
}
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
if (!compute->hasOneUse())
return false;
@@ -266,212 +172,6 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
}
}
/// Prints structural statistics of the SpatCompute dataflow graph of `funcOp`
/// to llvm::errs(), each line prefixed with "[DCP-MOTIF]".
///
/// Does nothing unless the DCP_MOTIF_PROFILE environment variable is set.
///
/// Graph nodes are the SpatCompute ops directly inside `funcOp`; there is one
/// (deduplicated) edge from a compute to every other compute that consumes one
/// of its results as an input. The report covers:
///   * node/edge totals and how many computes contain SpatVMMOp ops,
///   * fan-in / fan-out maxima and threshold counts (>= 16, 64, 256),
///   * serial chains (runs of computes linked by single-parent/single-child edges),
///   * per-level counts of VMM-carrying computes, with levels taken from a
///     topological traversal,
///   * the five most frequent "shapes" of VMM-carrying computes.
void emitMotifProfile(func::FuncOp funcOp) {
  if (!std::getenv("DCP_MOTIF_PROFILE"))
    return;

  // Index every top-level compute so edges can be stored as plain indices.
  SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
  DenseMap<SpatCompute, size_t> computeToIndex;
  computeToIndex.reserve(computes.size());
  for (auto [index, compute] : llvm::enumerate(computes))
    computeToIndex[compute] = index;

  SmallVector<ComputeMotifInfo> computeInfos(computes.size());
  SmallVector<SmallVector<size_t>> parents(computes.size());
  SmallVector<SmallVector<size_t>> children(computes.size());
  uint64_t weightedVmmNodeCount = 0;
  uint64_t weightedVmmOpCount = 0;
  uint64_t edgeCount = 0;

  // Gather per-compute figures and build the deduplicated parent/child lists.
  for (auto [index, compute] : llvm::enumerate(computes)) {
    ComputeMotifInfo& info = computeInfos[index];
    info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
    compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
    if (info.weightedVmmCount > 0) {
      weightedVmmNodeCount++;
      weightedVmmOpCount += info.weightedVmmCount;
    }

    for (Value input : compute.getInputs()) {
      // An input that is a block argument has no defining op. getDefiningOp<T>()
      // yields a null op in that case, whereas dyn_cast on the null Operation*
      // is not allowed.
      auto parent = input.getDefiningOp<SpatCompute>();
      if (!parent || parent == compute)
        continue;
      auto parentIt = computeToIndex.find(parent);
      if (parentIt == computeToIndex.end())
        continue;
      size_t parentIndex = parentIt->second;
      // Several inputs may come from the same parent; count the edge only once.
      size_t oldParentCount = parents[index].size();
      appendUnique(parents[index], parentIndex);
      if (parents[index].size() != oldParentCount) {
        appendUnique(children[parentIndex], index);
        edgeCount++;
      }
    }
  }

  // Fan-in / fan-out maxima and threshold histograms.
  uint64_t maxFanIn = 0;
  uint64_t maxFanOut = 0;
  uint64_t fanIn16 = 0;
  uint64_t fanIn64 = 0;
  uint64_t fanIn256 = 0;
  uint64_t fanOut16 = 0;
  uint64_t fanOut64 = 0;
  uint64_t fanOut256 = 0;
  for (size_t index = 0; index < computes.size(); ++index) {
    uint64_t fanIn = parents[index].size();
    uint64_t fanOut = children[index].size();
    maxFanIn = std::max(maxFanIn, fanIn);
    maxFanOut = std::max(maxFanOut, fanOut);
    fanIn16 += fanIn >= 16;
    fanIn64 += fanIn >= 64;
    fanIn256 += fanIn >= 256;
    fanOut16 += fanOut >= 16;
    fanOut64 += fanOut >= 64;
    fanOut256 += fanOut >= 256;
  }

  // Serial chains: start only at chain heads and follow single-child /
  // single-parent links for as long as they last.
  uint64_t serialChainCount = 0;
  uint64_t serialChainNodeCount = 0;
  uint64_t maxSerialChain = 0;
  for (size_t index = 0; index < computes.size(); ++index) {
    // A node whose only parent has it as its only child is in the middle of a
    // chain that is counted from that chain's head.
    if (parents[index].size() == 1 && children[parents[index][0]].size() == 1)
      continue;

    uint64_t chainLength = 1;
    size_t current = index;
    while (children[current].size() == 1) {
      size_t child = children[current][0];
      if (parents[child].size() != 1)
        break;
      chainLength++;
      current = child;
    }

    if (chainLength >= 2) {
      serialChainCount++;
      serialChainNodeCount += chainLength;
      maxSerialChain = std::max(maxSerialChain, chainLength);
    }
  }

  // Worklist topological traversal: level[n] ends up as the length of the
  // longest parent path reaching n; readyNodes doubles as the visit order.
  SmallVector<size_t> incomingEdgeCount;
  incomingEdgeCount.reserve(parents.size());
  for (ArrayRef<size_t> parentList : parents)
    incomingEdgeCount.push_back(parentList.size());

  SmallVector<uint64_t> level(computes.size(), 0);
  SmallVector<size_t> readyNodes;
  readyNodes.reserve(computes.size());
  for (size_t index = 0; index < computes.size(); ++index)
    if (incomingEdgeCount[index] == 0)
      readyNodes.push_back(index);

  size_t readyIndex = 0;
  while (readyIndex != readyNodes.size()) {
    size_t current = readyNodes[readyIndex++];
    for (size_t child : children[current]) {
      level[child] = std::max(level[child], level[current] + 1);
      assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
      incomingEdgeCount[child]--;
      if (incomingEdgeCount[child] == 0)
        readyNodes.push_back(child);
    }
  }

  // Histogram of VMM-carrying computes per level.
  SmallVector<uint64_t> weightedVmmNodesByLevel;
  for (size_t index = 0; index < computes.size(); ++index) {
    if (computeInfos[index].weightedVmmCount == 0)
      continue;
    if (weightedVmmNodesByLevel.size() <= level[index])
      weightedVmmNodesByLevel.resize(level[index] + 1, 0);
    weightedVmmNodesByLevel[level[index]]++;
  }

  uint64_t maxWeightedVmmLevel = 0;
  uint64_t wideWeightedVmmLevels64 = 0;
  uint64_t wideWeightedVmmLevels256 = 0;
  for (uint64_t count : weightedVmmNodesByLevel) {
    maxWeightedVmmLevel = std::max(maxWeightedVmmLevel, count);
    wideWeightedVmmLevels64 += count >= 64;
    wideWeightedVmmLevels256 += count >= 256;
  }

  // Shape of a VMM-carrying compute:
  // (instructions, VMM ops, weights, inputs, fan-in, fan-out).
  using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
  SmallVector<ShapeKey> weightedVmmShapeKeys;
  for (auto [index, compute] : llvm::enumerate(computes)) {
    const ComputeMotifInfo& info = computeInfos[index];
    if (info.weightedVmmCount == 0)
      continue;
    weightedVmmShapeKeys.push_back({info.instructionCount,
                                    info.weightedVmmCount,
                                    static_cast<uint64_t>(compute.getWeights().size()),
                                    static_cast<uint64_t>(compute.getInputs().size()),
                                    static_cast<uint64_t>(parents[index].size()),
                                    static_cast<uint64_t>(children[index].size())});
  }

  // Sort so equal shapes are adjacent, then run-length encode them.
  llvm::sort(weightedVmmShapeKeys);
  SmallVector<std::pair<uint64_t, ShapeKey>> weightedVmmShapeCounts;
  for (size_t index = 0; index < weightedVmmShapeKeys.size();) {
    size_t next = index + 1;
    while (next < weightedVmmShapeKeys.size() && weightedVmmShapeKeys[next] == weightedVmmShapeKeys[index])
      next++;
    weightedVmmShapeCounts.push_back({next - index, weightedVmmShapeKeys[index]});
    index = next;
  }
  // Most frequent shape first; ties broken by the shape tuple itself.
  llvm::sort(weightedVmmShapeCounts, [](const auto& lhs, const auto& rhs) {
    if (lhs.first != rhs.first)
      return lhs.first > rhs.first;
    return lhs.second < rhs.second;
  });

  llvm::errs() << llvm::formatv("[DCP-MOTIF] computes={0} edges={1} wvmmNodes={2} wvmmOps={3} "
                                "serialChains={4} serialChainNodes={5} maxSerialChain={6} "
                                "maxFanIn={7} maxFanOut={8} fanIn>=16/64/256={9}/{10}/{11} "
                                "fanOut>=16/64/256={12}/{13}/{14} topoVisited={15}\n",
                                computes.size(),
                                edgeCount,
                                weightedVmmNodeCount,
                                weightedVmmOpCount,
                                serialChainCount,
                                serialChainNodeCount,
                                maxSerialChain,
                                maxFanIn,
                                maxFanOut,
                                fanIn16,
                                fanIn64,
                                fanIn256,
                                fanOut16,
                                fanOut64,
                                fanOut256,
                                readyNodes.size());
  llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmLevels={0} maxWvmmLevel={1} wideWvmmLevels>=64/256={2}/{3} "
                                "shapeGroups={4}\n",
                                weightedVmmNodesByLevel.size(),
                                maxWeightedVmmLevel,
                                wideWeightedVmmLevels64,
                                wideWeightedVmmLevels256,
                                weightedVmmShapeCounts.size());

  // Only the five most frequent shapes are listed.
  for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
    auto [count, shape] = weightedVmmShapeCounts[rank];
    auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape;
    llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
                                  "weights={4} inputs={5} fanIn={6} fanOut={7}\n",
                                  rank,
                                  count,
                                  insts,
                                  vmmOps,
                                  weights,
                                  inputs,
                                  fanIn,
                                  fanOut);
  }
}
void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) {
std::fstream file = openReportFile(name);
if (!file.is_open())
@@ -628,44 +328,27 @@ public:
// Pass entry point. In order: merges trivially connected computes, fetches the
// MergeSchedulingAnalysis result, materializes it with MergeScheduleMaterializer,
// topologically sorts the function's first block, verifies the Spatial
// communication invariants, then dumps the module and writes the merge report.
// Each failing step signals pass failure and returns early.
// NOTE(review): in this view every step appears twice (once inside a
// ScopedMergePhaseTimer scope, once bare) and the braces do not balance; it
// looks like both sides of a diff were captured together. Confirm against the
// actual file before relying on the exact statement order.
void runOnOperation() override {
func::FuncOp func = getOperation();
{
ScopedMergePhaseTimer timer("trivial-serial-merge");
mergeTriviallyConnectedComputes(func);
}
// Optional graph statistics, requested through the DCP_MOTIF_PROFILE environment variable.
if (std::getenv("DCP_MOTIF_PROFILE"))
emitMotifProfile(func);
mergeTriviallyConnectedComputes(func);
// Pointer into the result returned by getAnalysis<MergeSchedulingAnalysis>().
const spatial::MergeScheduleResult* analysisResult = nullptr;
{
ScopedMergePhaseTimer timer("scheduling-analysis");
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
}
{
ScopedMergePhaseTimer timer("schedule-materialization");
// Abort the pass if the schedule cannot be materialized.
if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
signalPassFailure();
return;
}
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
signalPassFailure();
return;
}
emitMergeIrCounts("after-materialization", func);
{
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
// sortTopologically returns false on failure, which is reported on the function op.
if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR");
signalPassFailure();
return;
}
if (failed(verifySpatialCommunicationInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed");
signalPassFailure();
return;
}
emitMergeIrCounts("final-post-merge", func);
// Dump the merged module and write the report; the last report argument is
// the number of entries in cpuToLastComputeMap.
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR");
signalPassFailure();
return;
}
if (failed(verifySpatialCommunicationInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed");
signalPassFailure();
return;
}
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
}
};