much faster MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
+322
-151
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
@@ -117,6 +118,66 @@ struct ProducerKeyInfo {
|
||||
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKey {
|
||||
Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const SameClassConsumerLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct SameClassConsumerLookupKeyInfo {
|
||||
static SameClassConsumerLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static SameClassConsumerLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp), key.resultIndex, key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKey {
|
||||
Operation* sourceOp = nullptr;
|
||||
size_t resultIndex = 0;
|
||||
ClassId classId = 0;
|
||||
|
||||
bool operator==(const WholeBatchAssemblyLookupKey& other) const {
|
||||
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||
}
|
||||
};
|
||||
|
||||
struct WholeBatchAssemblyLookupKeyInfo {
|
||||
static WholeBatchAssemblyLookupKey getEmptyKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static WholeBatchAssemblyLookupKey getTombstoneKey() {
|
||||
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<ClassId>::max()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
|
||||
return llvm::hash_combine(llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp), key.resultIndex, key.classId);
|
||||
}
|
||||
|
||||
static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
using ClassSlotKey = std::pair<ClassId, SlotId>;
|
||||
|
||||
struct MaterializedClass {
|
||||
@@ -270,9 +331,36 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
|
||||
|
||||
class AvailableValueStore {
|
||||
public:
|
||||
void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; }
|
||||
struct ExactBatchFragmentRecord {
|
||||
ProducerKey key;
|
||||
Value value;
|
||||
};
|
||||
|
||||
void recordPackedRun(PackedScalarRunValue run) { packedScalarRuns.push_back(std::move(run)); }
|
||||
void record(ProducerKey key, ClassId classId, Value value) {
|
||||
exactValues[key][classId] = value;
|
||||
|
||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||
if (!batch || key.instance.laneCount == 0)
|
||||
return;
|
||||
|
||||
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
|
||||
SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
|
||||
for (ExactBatchFragmentRecord& record : bucket) {
|
||||
if (!(record.key == key))
|
||||
continue;
|
||||
record.value = value;
|
||||
return;
|
||||
}
|
||||
bucket.push_back({key, value});
|
||||
}
|
||||
|
||||
void recordPackedRun(PackedScalarRunValue run) {
|
||||
size_t runIndex = packedScalarRuns.size();
|
||||
packedScalarRuns.push_back(std::move(run));
|
||||
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
|
||||
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
|
||||
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
|
||||
}
|
||||
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
|
||||
|
||||
std::optional<Value> lookupExact(ProducerKey key, ClassId classId) const;
|
||||
@@ -280,7 +368,21 @@ public:
|
||||
std::optional<Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
|
||||
|
||||
SmallVectorImpl<PackedScalarRunValue>& getPackedScalarRuns() { return packedScalarRuns; }
|
||||
ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = packedRunsByProducerResultClass.find(key);
|
||||
if (it == packedRunsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||
auto it = exactBatchFragmentsByProducerResultClass.find(key);
|
||||
if (it == exactBatchFragmentsByProducerResultClass.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
|
||||
|
||||
private:
|
||||
std::optional<Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||
@@ -288,6 +390,10 @@ private:
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> exactValues;
|
||||
SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
|
||||
SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
|
||||
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<ExactBatchFragmentRecord, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||
exactBatchFragmentsByProducerResultClass;
|
||||
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||
packedRunsByProducerResultClass;
|
||||
};
|
||||
|
||||
struct MaterializerState {
|
||||
@@ -296,7 +402,6 @@ struct MaterializerState {
|
||||
IRRewriter rewriter;
|
||||
OperationFolder constantFolder;
|
||||
int64_t& nextChannelId;
|
||||
|
||||
SmallVector<MaterializedClass, 8> classes;
|
||||
DenseMap<CpuId, ClassId> cpuToClass;
|
||||
DenseMap<CpuId, SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
|
||||
@@ -305,7 +410,8 @@ struct MaterializerState {
|
||||
DenseSet<ClassSlotKey> materializedLogicalSlots;
|
||||
|
||||
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> sameClassConsumers;
|
||||
DenseMap<SameClassConsumerLookupKey, SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
|
||||
sameClassConsumerIndex;
|
||||
DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo> projectedInputMatches;
|
||||
DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
|
||||
DenseMap<Value, bool> liveExternalUseCache;
|
||||
@@ -317,7 +423,9 @@ struct MaterializerState {
|
||||
DenseMap<Value, Value> hostReplacements;
|
||||
DenseSet<Operation*> oldComputeOps;
|
||||
|
||||
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
|
||||
MaterializerState(func::FuncOp func,
|
||||
const MergeScheduleResult& schedule,
|
||||
int64_t& nextChannelId)
|
||||
: func(func),
|
||||
schedule(schedule),
|
||||
rewriter(func.getContext()),
|
||||
@@ -428,6 +536,14 @@ std::optional<ProducerKey> getContiguousProducerRangeForKeys(ArrayRef<ProducerKe
|
||||
return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex);
|
||||
}
|
||||
|
||||
/// Builds a whole-batch assembly lookup key from its three components.
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) {
  WholeBatchAssemblyLookupKey lookupKey;
  lookupKey.sourceOp = sourceOp;
  lookupKey.resultIndex = resultIndex;
  lookupKey.classId = classId;
  return lookupKey;
}
|
||||
|
||||
/// Builds a whole-batch assembly lookup key from a producer key, using the
/// producer's operation and result index together with the given class id.
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) {
  WholeBatchAssemblyLookupKey lookupKey;
  lookupKey.sourceOp = key.instance.op;
  lookupKey.resultIndex = key.resultIndex;
  lookupKey.classId = classId;
  return lookupKey;
}
|
||||
|
||||
FailureOr<RankedTensorType> getPackedBatchTensorType(Type laneType, size_t laneCount) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(laneType);
|
||||
if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
|
||||
@@ -1172,14 +1288,12 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
|
||||
|
||||
std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) {
|
||||
for (PackedScalarRunValue& run : packedScalarRuns) {
|
||||
if (run.targetClass != classId)
|
||||
continue;
|
||||
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||
continue;
|
||||
|
||||
for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) {
|
||||
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys);
|
||||
if (!slotKey || !containsProducerKey(*slotKey, key))
|
||||
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
|
||||
if (!contiguousKey || !containsProducerKey(*contiguousKey, key))
|
||||
continue;
|
||||
|
||||
FailureOr<RankedTensorType> slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size());
|
||||
@@ -1197,12 +1311,13 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
|
||||
Value slotPacked =
|
||||
getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc());
|
||||
|
||||
if (*slotKey == key) {
|
||||
if (*contiguousKey == key) {
|
||||
record(key, classId, slotPacked);
|
||||
return slotPacked;
|
||||
}
|
||||
|
||||
std::optional<Value> sliced = extractPackedProducerSlice(state, materializedClass, *slotKey, slotPacked, key);
|
||||
std::optional<Value> sliced =
|
||||
extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key);
|
||||
if (!sliced)
|
||||
return std::nullopt;
|
||||
|
||||
@@ -1216,57 +1331,45 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
|
||||
|
||||
IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) {
|
||||
for (IndexedBatchRunValue& run : indexedBatchRuns) {
|
||||
if (run.targetClass != classId)
|
||||
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||
continue;
|
||||
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||
continue;
|
||||
|
||||
for (const PackedScalarRunSlot& slot : run.slots)
|
||||
if (llvm::is_contained(slot.keys, key))
|
||||
return &run;
|
||||
for (const PackedScalarRunSlot& slot : run.slots) {
|
||||
if (!llvm::is_contained(slot.keys, key))
|
||||
continue;
|
||||
return &run;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::optional<Value> AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) {
|
||||
if (std::optional<Value> exact = lookupExact(key, classId))
|
||||
|
||||
if (std::optional<Value> exact = lookupExact(key, classId)) {
|
||||
return exact;
|
||||
}
|
||||
|
||||
if (std::optional<Value> packedRunValue = lookupPackedRun(state, key, classId))
|
||||
return packedRunValue;
|
||||
|
||||
MaterializedClass& materializedClass = state.classes[classId];
|
||||
|
||||
ProducerKey containingKey;
|
||||
Value containingValue;
|
||||
bool foundContainingValue = false;
|
||||
|
||||
for (auto& entry : exactValues) {
|
||||
ProducerKey candidateKey = entry.first;
|
||||
if (!containsProducerKey(candidateKey, key))
|
||||
for (const auto& [candidateKey, classValues] : exactValues) {
|
||||
if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key))
|
||||
continue;
|
||||
|
||||
auto valueIt = entry.second.find(classId);
|
||||
if (valueIt == entry.second.end())
|
||||
auto valueIt = classValues.find(classId);
|
||||
if (valueIt == classValues.end())
|
||||
continue;
|
||||
|
||||
containingKey = candidateKey;
|
||||
containingValue = valueIt->second;
|
||||
foundContainingValue = true;
|
||||
break;
|
||||
std::optional<Value> slice =
|
||||
extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key);
|
||||
if (!slice)
|
||||
return std::nullopt;
|
||||
|
||||
record(key, classId, *slice);
|
||||
return *slice;
|
||||
}
|
||||
|
||||
if (!foundContainingValue)
|
||||
return std::nullopt;
|
||||
|
||||
std::optional<Value> slice =
|
||||
extractPackedProducerSlice(state, materializedClass, containingKey, containingValue, key);
|
||||
if (!slice)
|
||||
return std::nullopt;
|
||||
|
||||
record(key, classId, *slice);
|
||||
return *slice;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
||||
@@ -1389,13 +1492,13 @@ Value createIndexedIndexValue(MaterializerState& state,
|
||||
bool allowExhaustiveTiledSearch) {
|
||||
assert(!values.empty() && "expected at least one indexed value");
|
||||
|
||||
if (allEqual(values))
|
||||
if (allEqual(values)) {
|
||||
return getOrCreateIndexConstant(state.constantFolder, anchor, values.front());
|
||||
}
|
||||
|
||||
if (std::optional<IndexedIndexPattern> pattern =
|
||||
getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch))
|
||||
return createAffineIndexValue(state, *pattern, index, loc);
|
||||
|
||||
Value table = createIndexTensorConstant(state, anchor, values);
|
||||
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
|
||||
}
|
||||
@@ -1578,7 +1681,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
||||
|
||||
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
|
||||
if (sourceClass == targetClass) {
|
||||
state.sameClassConsumers[producerKey].insert(targetClass);
|
||||
SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass};
|
||||
SmallVector<ProducerKey, 4>& bucket = state.sameClassConsumerIndex[lookupKey];
|
||||
if (!llvm::is_contained(bucket, producerKey))
|
||||
bucket.push_back(producerKey);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -2899,11 +3005,6 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
||||
return success();
|
||||
}
|
||||
|
||||
struct WholeBatchAssemblyRange {
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 0;
|
||||
};
|
||||
|
||||
struct DirectWholeBatchFragment {
|
||||
ProducerKey key;
|
||||
Value fragment;
|
||||
@@ -2933,31 +3034,60 @@ struct WholeBatchFragmentGroup {
|
||||
struct WholeBatchAssemblyPlan {
|
||||
RankedTensorType resultType;
|
||||
int64_t rowsPerLane = 0;
|
||||
uint32_t batchLaneCount = 0;
|
||||
uint32_t coveredLaneCount = 0;
|
||||
|
||||
SmallVector<WholeBatchAssemblyRange, 16> coveredRanges;
|
||||
SmallVector<uint8_t, 64> coveredLanes;
|
||||
SmallVector<PackedScalarRunValue*, 8> packedRuns;
|
||||
SmallVector<DirectWholeBatchFragment, 16> directFragments;
|
||||
};
|
||||
|
||||
bool wholeBatchRangeOverlaps(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t laneStart, uint32_t laneCount) {
|
||||
uint32_t laneEnd = laneStart + laneCount;
|
||||
for (WholeBatchAssemblyRange range : ranges) {
|
||||
uint32_t rangeEnd = range.laneStart + range.laneCount;
|
||||
if (laneStart < rangeEnd && range.laneStart < laneEnd)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) {
|
||||
return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0;
|
||||
}
|
||||
|
||||
bool wholeBatchLaneCovered(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t lane) {
|
||||
for (WholeBatchAssemblyRange range : ranges)
|
||||
if (range.laneStart <= lane && lane < range.laneStart + range.laneCount)
|
||||
bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
|
||||
if (laneCount == 0)
|
||||
return false;
|
||||
if (laneStart >= plan.coveredLanes.size())
|
||||
return false;
|
||||
|
||||
uint32_t laneEnd = std::min<uint32_t>(laneStart + laneCount, plan.coveredLanes.size());
|
||||
for (uint32_t lane = laneStart; lane < laneEnd; ++lane)
|
||||
if (plan.coveredLanes[lane] != 0)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
|
||||
plan.coveredRanges.push_back({laneStart, laneCount});
|
||||
assert(laneCount != 0 && "cannot cover an empty whole-batch range");
|
||||
assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds");
|
||||
|
||||
for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) {
|
||||
if (plan.coveredLanes[lane] != 0)
|
||||
continue;
|
||||
plan.coveredLanes[lane] = 1;
|
||||
++plan.coveredLaneCount;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true when at least one lane in [laneStart, laneStart + laneCount) is
/// non-zero in `covered`.
///
/// An empty range, or a range starting past the end of `covered`, never
/// overlaps. A range that extends past the end is clamped to `covered.size()`.
bool localLaneRangeOverlaps(ArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
  const uint64_t trackedLanes = covered.size();
  if (laneCount == 0 || laneStart >= trackedLanes)
    return false;

  // Widen before adding: in 32-bit arithmetic laneStart + laneCount can wrap,
  // which would produce an empty scan and wrongly report "no overlap".
  const uint64_t laneEnd = std::min<uint64_t>(static_cast<uint64_t>(laneStart) + laneCount, trackedLanes);
  for (uint64_t lane = laneStart; lane < laneEnd; ++lane)
    if (covered[lane] != 0)
      return true;
  return false;
}
|
||||
|
||||
/// Flags every lane in [laneStart, laneStart + laneCount) as covered by
/// writing 1 into the corresponding entries of `covered`.
///
/// The range must lie entirely inside `covered`; this is asserted.
void markLocalLaneRangeCovered(MutableArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
  // Check the bound in 64-bit arithmetic so a wrapping 32-bit sum cannot make
  // an out-of-range request look valid.
  assert(static_cast<uint64_t>(laneStart) + laneCount <= covered.size() && "local coverage out of bounds");
  for (uint8_t& laneFlag : covered.slice(laneStart, laneCount))
    laneFlag = 1;
}
|
||||
|
||||
LogicalResult
|
||||
@@ -3118,13 +3248,14 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
ProducerKey key,
|
||||
WholeBatchAssemblyPlan& plan) {
|
||||
for (PackedScalarRunValue& run : state.availableValues.getPackedScalarRuns()) {
|
||||
if (run.targetClass != targetClass.id)
|
||||
continue;
|
||||
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||
continue;
|
||||
WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
|
||||
ArrayRef<size_t> runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey);
|
||||
|
||||
SmallVector<WholeBatchAssemblyRange, 16> runRanges;
|
||||
for (size_t runIndex : runIndices) {
|
||||
PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex);
|
||||
|
||||
SmallVector<ProducerKey, 16> runKeys;
|
||||
SmallVector<uint8_t, 64> runCoveredLanes(plan.batchLaneCount, 0);
|
||||
|
||||
for (const PackedScalarRunSlot& slot : run.slots) {
|
||||
for (ProducerKey fragmentKey : slot.keys) {
|
||||
@@ -3134,23 +3265,24 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
|
||||
if (fragmentKey.instance.laneCount == 0)
|
||||
return failure();
|
||||
|
||||
if (wholeBatchRangeOverlaps(plan.coveredRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
||||
if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
||||
return failure();
|
||||
|
||||
if (wholeBatchRangeOverlaps(runRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
||||
if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
||||
return failure();
|
||||
|
||||
runRanges.push_back({fragmentKey.instance.laneStart, fragmentKey.instance.laneCount});
|
||||
markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount);
|
||||
runKeys.push_back(fragmentKey);
|
||||
}
|
||||
}
|
||||
|
||||
if (runRanges.empty())
|
||||
if (runKeys.empty())
|
||||
continue;
|
||||
|
||||
plan.packedRuns.push_back(&run);
|
||||
|
||||
for (WholeBatchAssemblyRange range : runRanges)
|
||||
recordWholeBatchCoverage(plan, range.laneStart, range.laneCount);
|
||||
for (ProducerKey runKey : runKeys)
|
||||
recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount);
|
||||
}
|
||||
|
||||
return success();
|
||||
@@ -3161,44 +3293,77 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state,
|
||||
SpatComputeBatch batch,
|
||||
ProducerKey key,
|
||||
WholeBatchAssemblyPlan& plan) {
|
||||
struct CandidateFragment {
|
||||
ProducerKey key;
|
||||
Value value;
|
||||
};
|
||||
|
||||
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
|
||||
uint32_t lane = 0;
|
||||
if (plan.coveredLaneCount == plan.batchLaneCount) {
|
||||
return success();
|
||||
}
|
||||
|
||||
while (lane < batchLaneCount) {
|
||||
if (wholeBatchLaneCovered(plan.coveredRanges, lane)) {
|
||||
++lane;
|
||||
WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
|
||||
ArrayRef<AvailableValueStore::ExactBatchFragmentRecord> indexedFragments =
|
||||
state.availableValues.getExactFragmentsForWholeBatch(lookupKey);
|
||||
|
||||
SmallVector<CandidateFragment, 16> candidates;
|
||||
candidates.reserve(indexedFragments.size());
|
||||
for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) {
|
||||
ProducerKey candidateKey = record.key;
|
||||
if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex
|
||||
|| candidateKey.instance.laneCount == 0)
|
||||
continue;
|
||||
if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount))
|
||||
continue;
|
||||
|
||||
auto fragmentType = dyn_cast<RankedTensorType>(record.value.getType());
|
||||
if (!fragmentType)
|
||||
continue;
|
||||
|
||||
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(candidateKey.instance.laneCount);
|
||||
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
|
||||
continue;
|
||||
|
||||
candidates.push_back({candidateKey, record.value});
|
||||
}
|
||||
|
||||
llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) {
|
||||
if (lhs.key.instance.laneStart != rhs.key.instance.laneStart)
|
||||
return lhs.key.instance.laneStart < rhs.key.instance.laneStart;
|
||||
return lhs.key.instance.laneCount > rhs.key.instance.laneCount;
|
||||
});
|
||||
|
||||
size_t candidateCursor = 0;
|
||||
uint32_t lane = 0;
|
||||
while (lane < batchLaneCount) {
|
||||
while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) {
|
||||
++lane;
|
||||
}
|
||||
|
||||
bool foundFragment = false;
|
||||
|
||||
for (uint32_t laneCount = batchLaneCount - lane; laneCount != 0; --laneCount) {
|
||||
if (wholeBatchRangeOverlaps(plan.coveredRanges, lane, laneCount))
|
||||
continue;
|
||||
|
||||
ProducerKey candidate = getBatchLaneProducerKey(batch, lane, laneCount, key.resultIndex);
|
||||
std::optional<Value> fragment = state.availableValues.lookupExact(candidate, targetClass.id);
|
||||
if (!fragment)
|
||||
continue;
|
||||
|
||||
auto fragmentType = dyn_cast<RankedTensorType>(fragment->getType());
|
||||
if (!fragmentType)
|
||||
return failure();
|
||||
|
||||
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(laneCount);
|
||||
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
|
||||
return failure();
|
||||
|
||||
plan.directFragments.push_back({candidate, *fragment});
|
||||
recordWholeBatchCoverage(plan, lane, laneCount);
|
||||
|
||||
lane += laneCount;
|
||||
foundFragment = true;
|
||||
if (lane >= batchLaneCount)
|
||||
break;
|
||||
|
||||
while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane)
|
||||
++candidateCursor;
|
||||
|
||||
size_t candidateIndex = candidateCursor;
|
||||
const CandidateFragment* best = nullptr;
|
||||
while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) {
|
||||
const CandidateFragment& candidate = candidates[candidateIndex];
|
||||
if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) {
|
||||
best = &candidate;
|
||||
break;
|
||||
}
|
||||
++candidateIndex;
|
||||
}
|
||||
|
||||
if (!foundFragment)
|
||||
if (!best)
|
||||
return failure();
|
||||
|
||||
plan.directFragments.push_back({best->key, best->value});
|
||||
recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount);
|
||||
lane += best->key.instance.laneCount;
|
||||
}
|
||||
|
||||
return success();
|
||||
@@ -3291,11 +3456,11 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
|
||||
}
|
||||
|
||||
for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) {
|
||||
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys);
|
||||
if (!slotKey)
|
||||
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
|
||||
if (!contiguousKey)
|
||||
return failure();
|
||||
groupIt->slotIndices.push_back(slotIndex);
|
||||
groupIt->outputOffsets.push_back(static_cast<int64_t>(slotKey->instance.laneStart) * plan.rowsPerLane);
|
||||
groupIt->outputOffsets.push_back(static_cast<int64_t>(contiguousKey->instance.laneStart) * plan.rowsPerLane);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3409,10 +3574,15 @@ FailureOr<WholeBatchAssemblyPlan> buildWholeBatchAssemblyPlan(MaterializerState&
|
||||
WholeBatchAssemblyPlan plan;
|
||||
plan.resultType = resultTensorType;
|
||||
plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast<int64_t>(batchLaneCount);
|
||||
plan.batchLaneCount = batchLaneCount;
|
||||
plan.coveredLanes.assign(batchLaneCount, 0);
|
||||
|
||||
if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan)))
|
||||
return failure();
|
||||
|
||||
if (plan.coveredLaneCount == plan.batchLaneCount)
|
||||
return plan;
|
||||
|
||||
if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan)))
|
||||
return failure();
|
||||
|
||||
@@ -4181,7 +4351,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
||||
auto sourceBatch = cast<SpatComputeBatch>(sourceOp);
|
||||
SmallVector<Type, 4>& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch);
|
||||
SmallVector<Value, 4> initValues;
|
||||
|
||||
for (size_t resultIndex : group.resultIndices) {
|
||||
if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex])
|
||||
return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run");
|
||||
@@ -4197,7 +4366,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
||||
|
||||
SmallVector<int64_t, 8> logicalLanes;
|
||||
logicalLanes.reserve(run.size());
|
||||
|
||||
for (const MaterializationRunSlot& slot : run) {
|
||||
if (slot.peers.size() != 1)
|
||||
return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot");
|
||||
@@ -4215,34 +4383,34 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
||||
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
auto loop = buildNormalizedScfFor(
|
||||
state.rewriter,
|
||||
loc,
|
||||
lowerBound,
|
||||
upperBound,
|
||||
step,
|
||||
ValueRange(initValues),
|
||||
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc);
|
||||
state.rewriter,
|
||||
loc,
|
||||
lowerBound,
|
||||
upperBound,
|
||||
step,
|
||||
ValueRange(initValues),
|
||||
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc);
|
||||
|
||||
FailureOr<SmallVector<Value, 4>> produced =
|
||||
cloneBatchBodyForLane(state,
|
||||
targetClass,
|
||||
run.front().peers.front(),
|
||||
sourceLane,
|
||||
group.resultIndices,
|
||||
CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex});
|
||||
if (failed(produced))
|
||||
return failure();
|
||||
FailureOr<SmallVector<Value, 4>> produced =
|
||||
cloneBatchBodyForLane(state,
|
||||
targetClass,
|
||||
run.front().peers.front(),
|
||||
sourceLane,
|
||||
group.resultIndices,
|
||||
CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex});
|
||||
if (failed(produced))
|
||||
return failure();
|
||||
|
||||
yielded.reserve(produced->size());
|
||||
for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
|
||||
auto fragmentType = cast<RankedTensorType>(output.getType());
|
||||
Value acc = iterArgs[outputIndex];
|
||||
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
|
||||
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
|
||||
}
|
||||
return success();
|
||||
});
|
||||
yielded.reserve(produced->size());
|
||||
for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
|
||||
auto fragmentType = cast<RankedTensorType>(output.getType());
|
||||
Value acc = iterArgs[outputIndex];
|
||||
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
|
||||
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
|
||||
}
|
||||
return success();
|
||||
});
|
||||
if (failed(loop))
|
||||
return failure();
|
||||
|
||||
@@ -4466,14 +4634,14 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state,
|
||||
}
|
||||
|
||||
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
|
||||
for (const auto& [key, consumers] : state.sameClassConsumers) {
|
||||
if (!consumers.contains(classId))
|
||||
continue;
|
||||
if (!sameProducerResult(key, producerKey))
|
||||
continue;
|
||||
if (containsProducerKey(key, producerKey) || containsProducerKey(producerKey, key))
|
||||
SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId};
|
||||
auto it = state.sameClassConsumerIndex.find(lookupKey);
|
||||
if (it == state.sameClassConsumerIndex.end())
|
||||
return false;
|
||||
|
||||
for (ProducerKey existing : it->second)
|
||||
if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -4488,6 +4656,7 @@ bool canCompactBatchClassRun(MaterializerState& state,
|
||||
ArrayRef<Value> outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front());
|
||||
|
||||
for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) {
|
||||
(void) ignored;
|
||||
for (const MaterializationRunSlot& slot : run) {
|
||||
if (slot.peers.empty())
|
||||
return false;
|
||||
@@ -4533,7 +4702,8 @@ Value createBatchClassRunSourceLane(MaterializerState& state,
|
||||
SmallVector<int64_t, 16> sourceLanes;
|
||||
sourceLanes.reserve(run.size() * targetClass.cpus.size());
|
||||
|
||||
for (const MaterializationRunSlot& slot : run) {
|
||||
for (auto [runSlotIndex, slot] : llvm::enumerate(run)) {
|
||||
(void) runSlotIndex;
|
||||
assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane");
|
||||
for (const ComputeInstance& peer : slot.peers)
|
||||
sourceLanes.push_back(peer.laneStart);
|
||||
@@ -4577,7 +4747,6 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
|
||||
plan.messages.targetCoreIds.reserve(messageCount);
|
||||
|
||||
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
|
||||
(void) slotIndex;
|
||||
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
|
||||
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
|
||||
if (failed(checkedSourceCpu))
|
||||
@@ -4590,6 +4759,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
|
||||
return failure();
|
||||
plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
||||
}
|
||||
(void) slotIndex;
|
||||
}
|
||||
|
||||
plans.push_back(std::move(plan));
|
||||
@@ -4773,7 +4943,8 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) {
|
||||
LogicalResult materializeInstanceSlot(MaterializerState& state,
|
||||
const ComputeInstance& instance) {
|
||||
auto cpuIt = state.schedule.computeToCpuMap.find(instance);
|
||||
if (cpuIt == state.schedule.computeToCpuMap.end())
|
||||
return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance");
|
||||
@@ -4794,8 +4965,7 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
|
||||
return success();
|
||||
|
||||
if (isa<SpatComputeBatch>(instance.op)) {
|
||||
FailureOr<MaterializationRun> run =
|
||||
collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
|
||||
FailureOr<MaterializationRun> run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
|
||||
|
||||
if (succeeded(run)) {
|
||||
if (!targetClass.isBatch)
|
||||
@@ -4924,6 +5094,7 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
|
||||
return failure();
|
||||
|
||||
LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody());
|
||||
(void) _;
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "mlir/Analysis/TopologicalSortUtils.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@@ -14,20 +13,14 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -51,83 +44,6 @@ using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
using spatial::getProducerValueRef;
|
||||
|
||||
/// Reports whether merge profiling output was requested, i.e. whether the
/// RAPTOR_PROFILE_MERGE environment variable is present (any value counts).
bool isMergeProfilingEnabled() {
  const char* profileFlag = std::getenv("RAPTOR_PROFILE_MERGE");
  return profileFlag != nullptr;
}
|
||||
|
||||
class ScopedMergePhaseTimer {
|
||||
public:
|
||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||
if (enabled)
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
~ScopedMergePhaseTimer() {
|
||||
if (!enabled)
|
||||
return;
|
||||
auto elapsed = std::chrono::steady_clock::now() - start;
|
||||
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
||||
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
||||
}
|
||||
|
||||
private:
|
||||
bool enabled = false;
|
||||
std::string phase;
|
||||
std::chrono::steady_clock::time_point start;
|
||||
};
|
||||
|
||||
/// Operation tallies gathered by collectMergeIrCounts() for profiling output.
/// Every counter starts at zero.
struct MergeIrCounts {
  // Top-level operations found directly in the function.
  uint64_t topLevelComputeCount{0};
  uint64_t topLevelComputeBatchCount{0};

  // Operations found while walking those top-level operations.
  uint64_t scalarChannelSendCount{0};
  uint64_t scalarChannelReceiveCount{0};
  uint64_t wvmmCount{0};
  uint64_t vaddCount{0};
  uint64_t scfForCount{0};
};
|
||||
|
||||
/// Tallies the operations of interest inside `funcOp`.
///
/// Each SpatCompute and SpatComputeBatch yielded by funcOp.getOps<>() bumps
/// the matching top-level counter and is then walked; every operation visited
/// by that walk is classified into at most one of the send / receive / VMM /
/// VAdd / scf.for counters.
///
/// \returns the filled-in counters.
MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
  MergeIrCounts tally;

  // Bumps the counter matching `candidate`, if any; the checks are exclusive.
  auto tallyOne = [&tally](Operation* candidate) {
    if (isa<spatial::SpatChannelSendOp>(candidate)) {
      ++tally.scalarChannelSendCount;
      return;
    }
    if (isa<spatial::SpatChannelReceiveOp>(candidate)) {
      ++tally.scalarChannelReceiveCount;
      return;
    }
    if (isa<spatial::SpatVMMOp>(candidate)) {
      ++tally.wvmmCount;
      return;
    }
    if (isa<spatial::SpatVAddOp>(candidate)) {
      ++tally.vaddCount;
      return;
    }
    if (isa<scf::ForOp>(candidate))
      ++tally.scfForCount;
  };

  auto tallyTree = [&](Operation* root) { root->walk([&](Operation* visited) { tallyOne(visited); }); };

  for (SpatCompute computeOp : funcOp.getOps<SpatCompute>()) {
    ++tally.topLevelComputeCount;
    tallyTree(computeOp.getOperation());
  }

  for (SpatComputeBatch batchOp : funcOp.getOps<SpatComputeBatch>()) {
    ++tally.topLevelComputeBatchCount;
    tallyTree(batchOp.getOperation());
  }

  return tally;
}
|
||||
|
||||
/// Writes one "[merge-profile] <phase> counts: ..." line to llvm::errs() with
/// the tallies from collectMergeIrCounts(funcOp).
///
/// Does nothing (and does not collect counts) unless isMergeProfilingEnabled()
/// returns true.
///
/// \param phaseName label identifying the point in the pass being reported.
/// \param funcOp    function whose operations are counted.
void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
  if (isMergeProfilingEnabled()) {
    const MergeIrCounts tally = collectMergeIrCounts(funcOp);

    llvm::raw_ostream& os = llvm::errs();
    os << "[merge-profile] " << phaseName << " counts:";
    os << " compute=" << tally.topLevelComputeCount;
    os << " compute_batch=" << tally.topLevelComputeBatchCount;
    os << " scalar_send=" << tally.scalarChannelSendCount;
    os << " scalar_recv=" << tally.scalarChannelReceiveCount;
    os << " wvmm=" << tally.wvmmCount;
    os << " vadd=" << tally.vaddCount;
    os << " scf_for=" << tally.scfForCount;
    os << "\n";
  }
}
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
||||
@@ -138,16 +54,6 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/// Per-compute measurements used by the motif profile; both start at zero.
struct ComputeMotifInfo {
  // Result of spatial::countComputeBodyInstructions for the compute body.
  uint64_t instructionCount{0};
  // Number of SpatVMMOp operations found in the compute body.
  uint64_t weightedVmmCount{0};
};
|
||||
|
||||
/// Appends `value` to `values` unless an equal element is already stored.
/// Existing elements keep their order; membership is checked by linear scan.
void appendUnique(SmallVector<size_t>& values, size_t value) {
  const bool alreadyPresent = llvm::is_contained(values, value);
  if (alreadyPresent)
    return;
  values.push_back(value);
}
|
||||
|
||||
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
||||
if (!compute->hasOneUse())
|
||||
return false;
|
||||
@@ -266,212 +172,6 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints a structural profile of the SpatCompute dependency graph of `funcOp`
/// to llvm::errs(), as lines prefixed with "[DCP-MOTIF]".
///
/// The function is a no-op unless the DCP_MOTIF_PROFILE environment variable
/// is set. Nodes are the SpatCompute operations returned by
/// funcOp.getOps<SpatCompute>(); an edge parent -> child exists when one of the
/// child's inputs is defined by the parent (duplicate edges and self edges are
/// ignored). The report covers: node/edge totals, weighted-VMM node and op
/// counts, serial-chain statistics, fan-in / fan-out maxima and threshold
/// counts, a level histogram of weighted-VMM nodes, and the five most frequent
/// weighted-VMM "shapes".
///
/// \param funcOp function whose top-level computes are profiled; it is only
///               inspected here, never rewritten.
void emitMotifProfile(func::FuncOp funcOp) {
  if (!std::getenv("DCP_MOTIF_PROFILE"))
    return;

  // Assign each compute a dense index so the graph can live in plain vectors.
  SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
  DenseMap<SpatCompute, size_t> computeToIndex;
  computeToIndex.reserve(computes.size());
  for (auto [index, compute] : llvm::enumerate(computes))
    computeToIndex[compute] = index;

  SmallVector<ComputeMotifInfo> computeInfos(computes.size());
  SmallVector<SmallVector<size_t>> parents(computes.size());
  SmallVector<SmallVector<size_t>> children(computes.size());

  uint64_t weightedVmmNodeCount = 0;
  uint64_t weightedVmmOpCount = 0;
  uint64_t edgeCount = 0;

  // Pass 1: per-node measurements and deduplicated parent/child adjacency.
  for (auto [index, compute] : llvm::enumerate(computes)) {
    ComputeMotifInfo& info = computeInfos[index];
    info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
    compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
    if (info.weightedVmmCount > 0) {
      weightedVmmNodeCount++;
      weightedVmmOpCount += info.weightedVmmCount;
    }

    for (Value input : compute.getInputs()) {
      // An input that is a block argument has no defining operation.
      // getDefiningOp<OpTy>() tolerates that and yields a null op, whereas
      // dyn_cast on the raw null Operation* is not permitted.
      auto parent = input.getDefiningOp<SpatCompute>();
      if (!parent || parent == compute)
        continue;
      auto parentIt = computeToIndex.find(parent);
      if (parentIt == computeToIndex.end())
        continue;

      size_t parentIndex = parentIt->second;
      size_t oldParentCount = parents[index].size();
      appendUnique(parents[index], parentIndex);
      // Only a newly recorded parent counts as a new edge.
      if (parents[index].size() != oldParentCount) {
        appendUnique(children[parentIndex], index);
        edgeCount++;
      }
    }
  }

  // Pass 2: fan-in / fan-out maxima and how many nodes reach each threshold.
  uint64_t maxFanIn = 0;
  uint64_t maxFanOut = 0;
  uint64_t fanIn16 = 0;
  uint64_t fanIn64 = 0;
  uint64_t fanIn256 = 0;
  uint64_t fanOut16 = 0;
  uint64_t fanOut64 = 0;
  uint64_t fanOut256 = 0;
  for (size_t index = 0; index < computes.size(); ++index) {
    uint64_t fanIn = parents[index].size();
    uint64_t fanOut = children[index].size();
    maxFanIn = std::max(maxFanIn, fanIn);
    maxFanOut = std::max(maxFanOut, fanOut);
    fanIn16 += fanIn >= 16;
    fanIn64 += fanIn >= 64;
    fanIn256 += fanIn >= 256;
    fanOut16 += fanOut >= 16;
    fanOut64 += fanOut >= 64;
    fanOut256 += fanOut >= 256;
  }

  // Pass 3: serial chains, i.e. runs of nodes linked by single-child /
  // single-parent edges. A node that is the sole child of a single-child
  // parent is the interior of some chain and is skipped so each chain is
  // measured once, from its head.
  uint64_t serialChainCount = 0;
  uint64_t serialChainNodeCount = 0;
  uint64_t maxSerialChain = 0;
  for (size_t index = 0; index < computes.size(); ++index) {
    if (parents[index].size() == 1 && children[parents[index][0]].size() == 1)
      continue;

    uint64_t chainLength = 1;
    size_t current = index;
    while (children[current].size() == 1) {
      size_t child = children[current][0];
      if (parents[child].size() != 1)
        break;
      chainLength++;
      current = child;
    }

    if (chainLength >= 2) {
      serialChainCount++;
      serialChainNodeCount += chainLength;
      maxSerialChain = std::max(maxSerialChain, chainLength);
    }
  }

  // Pass 4: worklist topological traversal assigning each node the length of
  // the longest parent path leading to it. readyNodes doubles as the visit
  // order, so its final size is the number of nodes reached.
  SmallVector<size_t> incomingEdgeCount;
  incomingEdgeCount.reserve(parents.size());
  for (ArrayRef<size_t> parentList : parents)
    incomingEdgeCount.push_back(parentList.size());

  SmallVector<uint64_t> level(computes.size(), 0);
  SmallVector<size_t> readyNodes;
  readyNodes.reserve(computes.size());
  for (size_t index = 0; index < computes.size(); ++index)
    if (incomingEdgeCount[index] == 0)
      readyNodes.push_back(index);

  size_t readyIndex = 0;
  while (readyIndex != readyNodes.size()) {
    size_t current = readyNodes[readyIndex++];
    for (size_t child : children[current]) {
      level[child] = std::max(level[child], level[current] + 1);
      assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
      incomingEdgeCount[child]--;
      if (incomingEdgeCount[child] == 0)
        readyNodes.push_back(child);
    }
  }

  // Pass 5: histogram of weighted-VMM nodes per level.
  SmallVector<uint64_t> weightedVmmNodesByLevel;
  for (size_t index = 0; index < computes.size(); ++index) {
    if (computeInfos[index].weightedVmmCount == 0)
      continue;
    if (weightedVmmNodesByLevel.size() <= level[index])
      weightedVmmNodesByLevel.resize(level[index] + 1, 0);
    weightedVmmNodesByLevel[level[index]]++;
  }

  uint64_t maxWeightedVmmLevel = 0;
  uint64_t wideWeightedVmmLevels64 = 0;
  uint64_t wideWeightedVmmLevels256 = 0;
  for (uint64_t count : weightedVmmNodesByLevel) {
    maxWeightedVmmLevel = std::max(maxWeightedVmmLevel, count);
    wideWeightedVmmLevels64 += count >= 64;
    wideWeightedVmmLevels256 += count >= 256;
  }

  // Pass 6: group weighted-VMM nodes by shape
  // (instructions, VMM ops, weights, inputs, fan-in, fan-out).
  using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
  SmallVector<ShapeKey> weightedVmmShapeKeys;
  for (auto [index, compute] : llvm::enumerate(computes)) {
    const ComputeMotifInfo& info = computeInfos[index];
    if (info.weightedVmmCount == 0)
      continue;
    weightedVmmShapeKeys.push_back({info.instructionCount,
                                    info.weightedVmmCount,
                                    static_cast<uint64_t>(compute.getWeights().size()),
                                    static_cast<uint64_t>(compute.getInputs().size()),
                                    static_cast<uint64_t>(parents[index].size()),
                                    static_cast<uint64_t>(children[index].size())});
  }

  // Sorting makes equal shapes adjacent so they can be run-length counted.
  llvm::sort(weightedVmmShapeKeys);
  SmallVector<std::pair<uint64_t, ShapeKey>> weightedVmmShapeCounts;
  for (size_t index = 0; index < weightedVmmShapeKeys.size();) {
    size_t next = index + 1;
    while (next < weightedVmmShapeKeys.size() && weightedVmmShapeKeys[next] == weightedVmmShapeKeys[index])
      next++;
    weightedVmmShapeCounts.push_back({next - index, weightedVmmShapeKeys[index]});
    index = next;
  }
  // Most frequent shape first; ties broken by the shape tuple itself.
  llvm::sort(weightedVmmShapeCounts, [](const auto& lhs, const auto& rhs) {
    if (lhs.first != rhs.first)
      return lhs.first > rhs.first;
    return lhs.second < rhs.second;
  });

  llvm::errs() << llvm::formatv("[DCP-MOTIF] computes={0} edges={1} wvmmNodes={2} wvmmOps={3} "
                                "serialChains={4} serialChainNodes={5} maxSerialChain={6} "
                                "maxFanIn={7} maxFanOut={8} fanIn>=16/64/256={9}/{10}/{11} "
                                "fanOut>=16/64/256={12}/{13}/{14} topoVisited={15}\n",
                                computes.size(),
                                edgeCount,
                                weightedVmmNodeCount,
                                weightedVmmOpCount,
                                serialChainCount,
                                serialChainNodeCount,
                                maxSerialChain,
                                maxFanIn,
                                maxFanOut,
                                fanIn16,
                                fanIn64,
                                fanIn256,
                                fanOut16,
                                fanOut64,
                                fanOut256,
                                readyNodes.size());

  llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmLevels={0} maxWvmmLevel={1} wideWvmmLevels>=64/256={2}/{3} "
                                "shapeGroups={4}\n",
                                weightedVmmNodesByLevel.size(),
                                maxWeightedVmmLevel,
                                wideWeightedVmmLevels64,
                                wideWeightedVmmLevels256,
                                weightedVmmShapeCounts.size());

  // Report at most the five most common shapes.
  for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
    auto [count, shape] = weightedVmmShapeCounts[rank];
    auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape;
    llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
                                  "weights={4} inputs={5} fanIn={6} fanOut={7}\n",
                                  rank,
                                  count,
                                  insts,
                                  vmmOps,
                                  weights,
                                  inputs,
                                  fanIn,
                                  fanOut);
  }
}
|
||||
|
||||
void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) {
|
||||
std::fstream file = openReportFile(name);
|
||||
if (!file.is_open())
|
||||
@@ -628,44 +328,27 @@ public:
|
||||
|
||||
void runOnOperation() override {
|
||||
func::FuncOp func = getOperation();
|
||||
{
|
||||
ScopedMergePhaseTimer timer("trivial-serial-merge");
|
||||
mergeTriviallyConnectedComputes(func);
|
||||
}
|
||||
if (std::getenv("DCP_MOTIF_PROFILE"))
|
||||
emitMotifProfile(func);
|
||||
mergeTriviallyConnectedComputes(func);
|
||||
|
||||
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
||||
{
|
||||
ScopedMergePhaseTimer timer("scheduling-analysis");
|
||||
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("schedule-materialization");
|
||||
if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
||||
if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
emitMergeIrCounts("after-materialization", func);
|
||||
|
||||
{
|
||||
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
||||
if (!sortTopologically(&func.getBody().front())) {
|
||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (failed(verifySpatialCommunicationInvariants(func))) {
|
||||
func.emitOpError("merged Spatial communication invariant verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
emitMergeIrCounts("final-post-merge", func);
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
|
||||
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
||||
if (!sortTopologically(&func.getBody().front())) {
|
||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (failed(verifySpatialCommunicationInvariants(func))) {
|
||||
func.emitOpError("merged Spatial communication invariant verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
|
||||
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user