much faster MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
+322
-151
@@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
@@ -117,6 +118,66 @@ struct ProducerKeyInfo {
|
|||||||
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
|
static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct SameClassConsumerLookupKey {
|
||||||
|
Operation* sourceOp = nullptr;
|
||||||
|
size_t resultIndex = 0;
|
||||||
|
ClassId classId = 0;
|
||||||
|
|
||||||
|
bool operator==(const SameClassConsumerLookupKey& other) const {
|
||||||
|
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SameClassConsumerLookupKeyInfo {
|
||||||
|
static SameClassConsumerLookupKey getEmptyKey() {
|
||||||
|
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||||
|
std::numeric_limits<ClassId>::max()};
|
||||||
|
}
|
||||||
|
|
||||||
|
static SameClassConsumerLookupKey getTombstoneKey() {
|
||||||
|
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||||
|
std::numeric_limits<ClassId>::max()};
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned getHashValue(const SameClassConsumerLookupKey& key) {
|
||||||
|
return llvm::hash_combine(llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp), key.resultIndex, key.classId);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WholeBatchAssemblyLookupKey {
|
||||||
|
Operation* sourceOp = nullptr;
|
||||||
|
size_t resultIndex = 0;
|
||||||
|
ClassId classId = 0;
|
||||||
|
|
||||||
|
bool operator==(const WholeBatchAssemblyLookupKey& other) const {
|
||||||
|
return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WholeBatchAssemblyLookupKeyInfo {
|
||||||
|
static WholeBatchAssemblyLookupKey getEmptyKey() {
|
||||||
|
return {llvm::DenseMapInfo<Operation*>::getEmptyKey(), std::numeric_limits<size_t>::max(),
|
||||||
|
std::numeric_limits<ClassId>::max()};
|
||||||
|
}
|
||||||
|
|
||||||
|
static WholeBatchAssemblyLookupKey getTombstoneKey() {
|
||||||
|
return {llvm::DenseMapInfo<Operation*>::getTombstoneKey(), std::numeric_limits<size_t>::max(),
|
||||||
|
std::numeric_limits<ClassId>::max()};
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) {
|
||||||
|
return llvm::hash_combine(llvm::DenseMapInfo<Operation*>::getHashValue(key.sourceOp), key.resultIndex, key.classId);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
using ClassSlotKey = std::pair<ClassId, SlotId>;
|
using ClassSlotKey = std::pair<ClassId, SlotId>;
|
||||||
|
|
||||||
struct MaterializedClass {
|
struct MaterializedClass {
|
||||||
@@ -270,9 +331,36 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
|
|||||||
|
|
||||||
class AvailableValueStore {
|
class AvailableValueStore {
|
||||||
public:
|
public:
|
||||||
void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; }
|
struct ExactBatchFragmentRecord {
|
||||||
|
ProducerKey key;
|
||||||
|
Value value;
|
||||||
|
};
|
||||||
|
|
||||||
void recordPackedRun(PackedScalarRunValue run) { packedScalarRuns.push_back(std::move(run)); }
|
void record(ProducerKey key, ClassId classId, Value value) {
|
||||||
|
exactValues[key][classId] = value;
|
||||||
|
|
||||||
|
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||||
|
if (!batch || key.instance.laneCount == 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId};
|
||||||
|
SmallVector<ExactBatchFragmentRecord, 16>& bucket = exactBatchFragmentsByProducerResultClass[lookupKey];
|
||||||
|
for (ExactBatchFragmentRecord& record : bucket) {
|
||||||
|
if (!(record.key == key))
|
||||||
|
continue;
|
||||||
|
record.value = value;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
bucket.push_back({key, value});
|
||||||
|
}
|
||||||
|
|
||||||
|
void recordPackedRun(PackedScalarRunValue run) {
|
||||||
|
size_t runIndex = packedScalarRuns.size();
|
||||||
|
packedScalarRuns.push_back(std::move(run));
|
||||||
|
const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex];
|
||||||
|
WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass};
|
||||||
|
packedRunsByProducerResultClass[lookupKey].push_back(runIndex);
|
||||||
|
}
|
||||||
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
|
void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); }
|
||||||
|
|
||||||
std::optional<Value> lookupExact(ProducerKey key, ClassId classId) const;
|
std::optional<Value> lookupExact(ProducerKey key, ClassId classId) const;
|
||||||
@@ -280,7 +368,21 @@ public:
|
|||||||
std::optional<Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
|
std::optional<Value> lookup(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||||
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
|
IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId);
|
||||||
|
|
||||||
SmallVectorImpl<PackedScalarRunValue>& getPackedScalarRuns() { return packedScalarRuns; }
|
ArrayRef<size_t> getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||||
|
auto it = packedRunsByProducerResultClass.find(key);
|
||||||
|
if (it == packedRunsByProducerResultClass.end())
|
||||||
|
return {};
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<ExactBatchFragmentRecord> getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const {
|
||||||
|
auto it = exactBatchFragmentsByProducerResultClass.find(key);
|
||||||
|
if (it == exactBatchFragmentsByProducerResultClass.end())
|
||||||
|
return {};
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::optional<Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
|
std::optional<Value> lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId);
|
||||||
@@ -288,6 +390,10 @@ private:
|
|||||||
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> exactValues;
|
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> exactValues;
|
||||||
SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
|
SmallVector<PackedScalarRunValue, 8> packedScalarRuns;
|
||||||
SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
|
SmallVector<IndexedBatchRunValue, 8> indexedBatchRuns;
|
||||||
|
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<ExactBatchFragmentRecord, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||||
|
exactBatchFragmentsByProducerResultClass;
|
||||||
|
DenseMap<WholeBatchAssemblyLookupKey, SmallVector<size_t, 16>, WholeBatchAssemblyLookupKeyInfo>
|
||||||
|
packedRunsByProducerResultClass;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MaterializerState {
|
struct MaterializerState {
|
||||||
@@ -296,7 +402,6 @@ struct MaterializerState {
|
|||||||
IRRewriter rewriter;
|
IRRewriter rewriter;
|
||||||
OperationFolder constantFolder;
|
OperationFolder constantFolder;
|
||||||
int64_t& nextChannelId;
|
int64_t& nextChannelId;
|
||||||
|
|
||||||
SmallVector<MaterializedClass, 8> classes;
|
SmallVector<MaterializedClass, 8> classes;
|
||||||
DenseMap<CpuId, ClassId> cpuToClass;
|
DenseMap<CpuId, ClassId> cpuToClass;
|
||||||
DenseMap<CpuId, SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
|
DenseMap<CpuId, SmallVector<ComputeInstance, 32>> logicalInstancesByCpu;
|
||||||
@@ -305,7 +410,8 @@ struct MaterializerState {
|
|||||||
DenseSet<ClassSlotKey> materializedLogicalSlots;
|
DenseSet<ClassSlotKey> materializedLogicalSlots;
|
||||||
|
|
||||||
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||||
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> sameClassConsumers;
|
DenseMap<SameClassConsumerLookupKey, SmallVector<ProducerKey, 4>, SameClassConsumerLookupKeyInfo>
|
||||||
|
sameClassConsumerIndex;
|
||||||
DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo> projectedInputMatches;
|
DenseMap<ProjectedBatchInputKey, AffineProjectedInputSliceMatch, ProjectedBatchInputKeyInfo> projectedInputMatches;
|
||||||
DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
|
DenseSet<ProjectedBatchInputKey, ProjectedBatchInputKeyInfo> nonProjectedInputs;
|
||||||
DenseMap<Value, bool> liveExternalUseCache;
|
DenseMap<Value, bool> liveExternalUseCache;
|
||||||
@@ -317,7 +423,9 @@ struct MaterializerState {
|
|||||||
DenseMap<Value, Value> hostReplacements;
|
DenseMap<Value, Value> hostReplacements;
|
||||||
DenseSet<Operation*> oldComputeOps;
|
DenseSet<Operation*> oldComputeOps;
|
||||||
|
|
||||||
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
|
MaterializerState(func::FuncOp func,
|
||||||
|
const MergeScheduleResult& schedule,
|
||||||
|
int64_t& nextChannelId)
|
||||||
: func(func),
|
: func(func),
|
||||||
schedule(schedule),
|
schedule(schedule),
|
||||||
rewriter(func.getContext()),
|
rewriter(func.getContext()),
|
||||||
@@ -428,6 +536,14 @@ std::optional<ProducerKey> getContiguousProducerRangeForKeys(ArrayRef<ProducerKe
|
|||||||
return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex);
|
return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) {
|
||||||
|
return {sourceOp, resultIndex, classId};
|
||||||
|
}
|
||||||
|
|
||||||
|
WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) {
|
||||||
|
return makeWholeBatchAssemblyLookupKey(key.instance.op, key.resultIndex, classId);
|
||||||
|
}
|
||||||
|
|
||||||
FailureOr<RankedTensorType> getPackedBatchTensorType(Type laneType, size_t laneCount) {
|
FailureOr<RankedTensorType> getPackedBatchTensorType(Type laneType, size_t laneCount) {
|
||||||
auto tensorType = dyn_cast<RankedTensorType>(laneType);
|
auto tensorType = dyn_cast<RankedTensorType>(laneType);
|
||||||
if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
|
if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0)
|
||||||
@@ -1172,14 +1288,12 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
|
|||||||
|
|
||||||
std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) {
|
std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) {
|
||||||
for (PackedScalarRunValue& run : packedScalarRuns) {
|
for (PackedScalarRunValue& run : packedScalarRuns) {
|
||||||
if (run.targetClass != classId)
|
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||||
continue;
|
|
||||||
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) {
|
for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) {
|
||||||
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys);
|
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
|
||||||
if (!slotKey || !containsProducerKey(*slotKey, key))
|
if (!contiguousKey || !containsProducerKey(*contiguousKey, key))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
FailureOr<RankedTensorType> slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size());
|
FailureOr<RankedTensorType> slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size());
|
||||||
@@ -1197,12 +1311,13 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
|
|||||||
Value slotPacked =
|
Value slotPacked =
|
||||||
getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc());
|
getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc());
|
||||||
|
|
||||||
if (*slotKey == key) {
|
if (*contiguousKey == key) {
|
||||||
record(key, classId, slotPacked);
|
record(key, classId, slotPacked);
|
||||||
return slotPacked;
|
return slotPacked;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<Value> sliced = extractPackedProducerSlice(state, materializedClass, *slotKey, slotPacked, key);
|
std::optional<Value> sliced =
|
||||||
|
extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key);
|
||||||
if (!sliced)
|
if (!sliced)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
@@ -1216,57 +1331,45 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
|
|||||||
|
|
||||||
IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) {
|
IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) {
|
||||||
for (IndexedBatchRunValue& run : indexedBatchRuns) {
|
for (IndexedBatchRunValue& run : indexedBatchRuns) {
|
||||||
if (run.targetClass != classId)
|
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||||
continue;
|
continue;
|
||||||
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
for (const PackedScalarRunSlot& slot : run.slots) {
|
||||||
continue;
|
if (!llvm::is_contained(slot.keys, key))
|
||||||
|
continue;
|
||||||
for (const PackedScalarRunSlot& slot : run.slots)
|
return &run;
|
||||||
if (llvm::is_contained(slot.keys, key))
|
}
|
||||||
return &run;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<Value> AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) {
|
std::optional<Value> AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) {
|
||||||
if (std::optional<Value> exact = lookupExact(key, classId))
|
|
||||||
|
if (std::optional<Value> exact = lookupExact(key, classId)) {
|
||||||
return exact;
|
return exact;
|
||||||
|
}
|
||||||
|
|
||||||
if (std::optional<Value> packedRunValue = lookupPackedRun(state, key, classId))
|
if (std::optional<Value> packedRunValue = lookupPackedRun(state, key, classId))
|
||||||
return packedRunValue;
|
return packedRunValue;
|
||||||
|
|
||||||
MaterializedClass& materializedClass = state.classes[classId];
|
MaterializedClass& materializedClass = state.classes[classId];
|
||||||
|
|
||||||
ProducerKey containingKey;
|
for (const auto& [candidateKey, classValues] : exactValues) {
|
||||||
Value containingValue;
|
if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key))
|
||||||
bool foundContainingValue = false;
|
|
||||||
|
|
||||||
for (auto& entry : exactValues) {
|
|
||||||
ProducerKey candidateKey = entry.first;
|
|
||||||
if (!containsProducerKey(candidateKey, key))
|
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto valueIt = entry.second.find(classId);
|
auto valueIt = classValues.find(classId);
|
||||||
if (valueIt == entry.second.end())
|
if (valueIt == classValues.end())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
containingKey = candidateKey;
|
std::optional<Value> slice =
|
||||||
containingValue = valueIt->second;
|
extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key);
|
||||||
foundContainingValue = true;
|
if (!slice)
|
||||||
break;
|
return std::nullopt;
|
||||||
|
|
||||||
|
record(key, classId, *slice);
|
||||||
|
return *slice;
|
||||||
}
|
}
|
||||||
|
return std::nullopt;
|
||||||
if (!foundContainingValue)
|
|
||||||
return std::nullopt;
|
|
||||||
|
|
||||||
std::optional<Value> slice =
|
|
||||||
extractPackedProducerSlice(state, materializedClass, containingKey, containingValue, key);
|
|
||||||
if (!slice)
|
|
||||||
return std::nullopt;
|
|
||||||
|
|
||||||
record(key, classId, *slice);
|
|
||||||
return *slice;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
||||||
@@ -1389,13 +1492,13 @@ Value createIndexedIndexValue(MaterializerState& state,
|
|||||||
bool allowExhaustiveTiledSearch) {
|
bool allowExhaustiveTiledSearch) {
|
||||||
assert(!values.empty() && "expected at least one indexed value");
|
assert(!values.empty() && "expected at least one indexed value");
|
||||||
|
|
||||||
if (allEqual(values))
|
if (allEqual(values)) {
|
||||||
return getOrCreateIndexConstant(state.constantFolder, anchor, values.front());
|
return getOrCreateIndexConstant(state.constantFolder, anchor, values.front());
|
||||||
|
}
|
||||||
|
|
||||||
if (std::optional<IndexedIndexPattern> pattern =
|
if (std::optional<IndexedIndexPattern> pattern =
|
||||||
getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch))
|
getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch))
|
||||||
return createAffineIndexValue(state, *pattern, index, loc);
|
return createAffineIndexValue(state, *pattern, index, loc);
|
||||||
|
|
||||||
Value table = createIndexTensorConstant(state, anchor, values);
|
Value table = createIndexTensorConstant(state, anchor, values);
|
||||||
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
|
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
|
||||||
}
|
}
|
||||||
@@ -1578,7 +1681,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
|||||||
|
|
||||||
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
|
ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
|
||||||
if (sourceClass == targetClass) {
|
if (sourceClass == targetClass) {
|
||||||
state.sameClassConsumers[producerKey].insert(targetClass);
|
SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass};
|
||||||
|
SmallVector<ProducerKey, 4>& bucket = state.sameClassConsumerIndex[lookupKey];
|
||||||
|
if (!llvm::is_contained(bucket, producerKey))
|
||||||
|
bucket.push_back(producerKey);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2899,11 +3005,6 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct WholeBatchAssemblyRange {
|
|
||||||
uint32_t laneStart = 0;
|
|
||||||
uint32_t laneCount = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DirectWholeBatchFragment {
|
struct DirectWholeBatchFragment {
|
||||||
ProducerKey key;
|
ProducerKey key;
|
||||||
Value fragment;
|
Value fragment;
|
||||||
@@ -2933,31 +3034,60 @@ struct WholeBatchFragmentGroup {
|
|||||||
struct WholeBatchAssemblyPlan {
|
struct WholeBatchAssemblyPlan {
|
||||||
RankedTensorType resultType;
|
RankedTensorType resultType;
|
||||||
int64_t rowsPerLane = 0;
|
int64_t rowsPerLane = 0;
|
||||||
|
uint32_t batchLaneCount = 0;
|
||||||
|
uint32_t coveredLaneCount = 0;
|
||||||
|
|
||||||
SmallVector<WholeBatchAssemblyRange, 16> coveredRanges;
|
SmallVector<uint8_t, 64> coveredLanes;
|
||||||
SmallVector<PackedScalarRunValue*, 8> packedRuns;
|
SmallVector<PackedScalarRunValue*, 8> packedRuns;
|
||||||
SmallVector<DirectWholeBatchFragment, 16> directFragments;
|
SmallVector<DirectWholeBatchFragment, 16> directFragments;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool wholeBatchRangeOverlaps(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t laneStart, uint32_t laneCount) {
|
bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) {
|
||||||
uint32_t laneEnd = laneStart + laneCount;
|
return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0;
|
||||||
for (WholeBatchAssemblyRange range : ranges) {
|
|
||||||
uint32_t rangeEnd = range.laneStart + range.laneCount;
|
|
||||||
if (laneStart < rangeEnd && range.laneStart < laneEnd)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool wholeBatchLaneCovered(ArrayRef<WholeBatchAssemblyRange> ranges, uint32_t lane) {
|
bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
|
||||||
for (WholeBatchAssemblyRange range : ranges)
|
if (laneCount == 0)
|
||||||
if (range.laneStart <= lane && lane < range.laneStart + range.laneCount)
|
return false;
|
||||||
|
if (laneStart >= plan.coveredLanes.size())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
uint32_t laneEnd = std::min<uint32_t>(laneStart + laneCount, plan.coveredLanes.size());
|
||||||
|
for (uint32_t lane = laneStart; lane < laneEnd; ++lane)
|
||||||
|
if (plan.coveredLanes[lane] != 0)
|
||||||
return true;
|
return true;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
|
void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) {
|
||||||
plan.coveredRanges.push_back({laneStart, laneCount});
|
assert(laneCount != 0 && "cannot cover an empty whole-batch range");
|
||||||
|
assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds");
|
||||||
|
|
||||||
|
for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) {
|
||||||
|
if (plan.coveredLanes[lane] != 0)
|
||||||
|
continue;
|
||||||
|
plan.coveredLanes[lane] = 1;
|
||||||
|
++plan.coveredLaneCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool localLaneRangeOverlaps(ArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
|
||||||
|
if (laneCount == 0)
|
||||||
|
return false;
|
||||||
|
if (laneStart >= covered.size())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
uint32_t laneEnd = std::min<uint32_t>(laneStart + laneCount, covered.size());
|
||||||
|
for (uint32_t lane = laneStart; lane < laneEnd; ++lane)
|
||||||
|
if (covered[lane] != 0)
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void markLocalLaneRangeCovered(MutableArrayRef<uint8_t> covered, uint32_t laneStart, uint32_t laneCount) {
|
||||||
|
assert(laneStart + laneCount <= covered.size() && "local coverage out of bounds");
|
||||||
|
for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane)
|
||||||
|
covered[lane] = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
@@ -3118,13 +3248,14 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
|
|||||||
MaterializedClass& targetClass,
|
MaterializedClass& targetClass,
|
||||||
ProducerKey key,
|
ProducerKey key,
|
||||||
WholeBatchAssemblyPlan& plan) {
|
WholeBatchAssemblyPlan& plan) {
|
||||||
for (PackedScalarRunValue& run : state.availableValues.getPackedScalarRuns()) {
|
WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
|
||||||
if (run.targetClass != targetClass.id)
|
ArrayRef<size_t> runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey);
|
||||||
continue;
|
|
||||||
if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
SmallVector<WholeBatchAssemblyRange, 16> runRanges;
|
for (size_t runIndex : runIndices) {
|
||||||
|
PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex);
|
||||||
|
|
||||||
|
SmallVector<ProducerKey, 16> runKeys;
|
||||||
|
SmallVector<uint8_t, 64> runCoveredLanes(plan.batchLaneCount, 0);
|
||||||
|
|
||||||
for (const PackedScalarRunSlot& slot : run.slots) {
|
for (const PackedScalarRunSlot& slot : run.slots) {
|
||||||
for (ProducerKey fragmentKey : slot.keys) {
|
for (ProducerKey fragmentKey : slot.keys) {
|
||||||
@@ -3134,23 +3265,24 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
|
|||||||
if (fragmentKey.instance.laneCount == 0)
|
if (fragmentKey.instance.laneCount == 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (wholeBatchRangeOverlaps(plan.coveredRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (wholeBatchRangeOverlaps(runRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
runRanges.push_back({fragmentKey.instance.laneStart, fragmentKey.instance.laneCount});
|
markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount);
|
||||||
|
runKeys.push_back(fragmentKey);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (runRanges.empty())
|
if (runKeys.empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
plan.packedRuns.push_back(&run);
|
plan.packedRuns.push_back(&run);
|
||||||
|
|
||||||
for (WholeBatchAssemblyRange range : runRanges)
|
for (ProducerKey runKey : runKeys)
|
||||||
recordWholeBatchCoverage(plan, range.laneStart, range.laneCount);
|
recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
@@ -3161,44 +3293,77 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state,
|
|||||||
SpatComputeBatch batch,
|
SpatComputeBatch batch,
|
||||||
ProducerKey key,
|
ProducerKey key,
|
||||||
WholeBatchAssemblyPlan& plan) {
|
WholeBatchAssemblyPlan& plan) {
|
||||||
|
struct CandidateFragment {
|
||||||
|
ProducerKey key;
|
||||||
|
Value value;
|
||||||
|
};
|
||||||
|
|
||||||
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
|
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
|
||||||
uint32_t lane = 0;
|
if (plan.coveredLaneCount == plan.batchLaneCount) {
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
while (lane < batchLaneCount) {
|
WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id);
|
||||||
if (wholeBatchLaneCovered(plan.coveredRanges, lane)) {
|
ArrayRef<AvailableValueStore::ExactBatchFragmentRecord> indexedFragments =
|
||||||
++lane;
|
state.availableValues.getExactFragmentsForWholeBatch(lookupKey);
|
||||||
|
|
||||||
|
SmallVector<CandidateFragment, 16> candidates;
|
||||||
|
candidates.reserve(indexedFragments.size());
|
||||||
|
for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) {
|
||||||
|
ProducerKey candidateKey = record.key;
|
||||||
|
if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex
|
||||||
|
|| candidateKey.instance.laneCount == 0)
|
||||||
continue;
|
continue;
|
||||||
|
if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto fragmentType = dyn_cast<RankedTensorType>(record.value.getType());
|
||||||
|
if (!fragmentType)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(candidateKey.instance.laneCount);
|
||||||
|
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
candidates.push_back({candidateKey, record.value});
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) {
|
||||||
|
if (lhs.key.instance.laneStart != rhs.key.instance.laneStart)
|
||||||
|
return lhs.key.instance.laneStart < rhs.key.instance.laneStart;
|
||||||
|
return lhs.key.instance.laneCount > rhs.key.instance.laneCount;
|
||||||
|
});
|
||||||
|
|
||||||
|
size_t candidateCursor = 0;
|
||||||
|
uint32_t lane = 0;
|
||||||
|
while (lane < batchLaneCount) {
|
||||||
|
while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) {
|
||||||
|
++lane;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool foundFragment = false;
|
if (lane >= batchLaneCount)
|
||||||
|
|
||||||
for (uint32_t laneCount = batchLaneCount - lane; laneCount != 0; --laneCount) {
|
|
||||||
if (wholeBatchRangeOverlaps(plan.coveredRanges, lane, laneCount))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
ProducerKey candidate = getBatchLaneProducerKey(batch, lane, laneCount, key.resultIndex);
|
|
||||||
std::optional<Value> fragment = state.availableValues.lookupExact(candidate, targetClass.id);
|
|
||||||
if (!fragment)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto fragmentType = dyn_cast<RankedTensorType>(fragment->getType());
|
|
||||||
if (!fragmentType)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
int64_t expectedRows = plan.rowsPerLane * static_cast<int64_t>(laneCount);
|
|
||||||
if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
plan.directFragments.push_back({candidate, *fragment});
|
|
||||||
recordWholeBatchCoverage(plan, lane, laneCount);
|
|
||||||
|
|
||||||
lane += laneCount;
|
|
||||||
foundFragment = true;
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane)
|
||||||
|
++candidateCursor;
|
||||||
|
|
||||||
|
size_t candidateIndex = candidateCursor;
|
||||||
|
const CandidateFragment* best = nullptr;
|
||||||
|
while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) {
|
||||||
|
const CandidateFragment& candidate = candidates[candidateIndex];
|
||||||
|
if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) {
|
||||||
|
best = &candidate;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++candidateIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!foundFragment)
|
if (!best)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
plan.directFragments.push_back({best->key, best->value});
|
||||||
|
recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount);
|
||||||
|
lane += best->key.instance.laneCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
@@ -3291,11 +3456,11 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) {
|
for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) {
|
||||||
std::optional<ProducerKey> slotKey = getContiguousProducerRangeForKeys(slot.keys);
|
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
|
||||||
if (!slotKey)
|
if (!contiguousKey)
|
||||||
return failure();
|
return failure();
|
||||||
groupIt->slotIndices.push_back(slotIndex);
|
groupIt->slotIndices.push_back(slotIndex);
|
||||||
groupIt->outputOffsets.push_back(static_cast<int64_t>(slotKey->instance.laneStart) * plan.rowsPerLane);
|
groupIt->outputOffsets.push_back(static_cast<int64_t>(contiguousKey->instance.laneStart) * plan.rowsPerLane);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3409,10 +3574,15 @@ FailureOr<WholeBatchAssemblyPlan> buildWholeBatchAssemblyPlan(MaterializerState&
|
|||||||
WholeBatchAssemblyPlan plan;
|
WholeBatchAssemblyPlan plan;
|
||||||
plan.resultType = resultTensorType;
|
plan.resultType = resultTensorType;
|
||||||
plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast<int64_t>(batchLaneCount);
|
plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast<int64_t>(batchLaneCount);
|
||||||
|
plan.batchLaneCount = batchLaneCount;
|
||||||
|
plan.coveredLanes.assign(batchLaneCount, 0);
|
||||||
|
|
||||||
if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan)))
|
if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
if (plan.coveredLaneCount == plan.batchLaneCount)
|
||||||
|
return plan;
|
||||||
|
|
||||||
if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan)))
|
if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -4181,7 +4351,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
|||||||
auto sourceBatch = cast<SpatComputeBatch>(sourceOp);
|
auto sourceBatch = cast<SpatComputeBatch>(sourceOp);
|
||||||
SmallVector<Type, 4>& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch);
|
SmallVector<Type, 4>& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch);
|
||||||
SmallVector<Value, 4> initValues;
|
SmallVector<Value, 4> initValues;
|
||||||
|
|
||||||
for (size_t resultIndex : group.resultIndices) {
|
for (size_t resultIndex : group.resultIndices) {
|
||||||
if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex])
|
if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex])
|
||||||
return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run");
|
return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run");
|
||||||
@@ -4197,7 +4366,6 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
|||||||
|
|
||||||
SmallVector<int64_t, 8> logicalLanes;
|
SmallVector<int64_t, 8> logicalLanes;
|
||||||
logicalLanes.reserve(run.size());
|
logicalLanes.reserve(run.size());
|
||||||
|
|
||||||
for (const MaterializationRunSlot& slot : run) {
|
for (const MaterializationRunSlot& slot : run) {
|
||||||
if (slot.peers.size() != 1)
|
if (slot.peers.size() != 1)
|
||||||
return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot");
|
return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot");
|
||||||
@@ -4215,34 +4383,34 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
|||||||
|
|
||||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||||
auto loop = buildNormalizedScfFor(
|
auto loop = buildNormalizedScfFor(
|
||||||
state.rewriter,
|
state.rewriter,
|
||||||
loc,
|
loc,
|
||||||
lowerBound,
|
lowerBound,
|
||||||
upperBound,
|
upperBound,
|
||||||
step,
|
step,
|
||||||
ValueRange(initValues),
|
ValueRange(initValues),
|
||||||
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
[&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||||
Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc);
|
Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc);
|
||||||
|
|
||||||
FailureOr<SmallVector<Value, 4>> produced =
|
FailureOr<SmallVector<Value, 4>> produced =
|
||||||
cloneBatchBodyForLane(state,
|
cloneBatchBodyForLane(state,
|
||||||
targetClass,
|
targetClass,
|
||||||
run.front().peers.front(),
|
run.front().peers.front(),
|
||||||
sourceLane,
|
sourceLane,
|
||||||
group.resultIndices,
|
group.resultIndices,
|
||||||
CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex});
|
CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex});
|
||||||
if (failed(produced))
|
if (failed(produced))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
yielded.reserve(produced->size());
|
yielded.reserve(produced->size());
|
||||||
for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
|
for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
|
||||||
auto fragmentType = cast<RankedTensorType>(output.getType());
|
auto fragmentType = cast<RankedTensorType>(output.getType());
|
||||||
Value acc = iterArgs[outputIndex];
|
Value acc = iterArgs[outputIndex];
|
||||||
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
|
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
|
||||||
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
|
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
if (failed(loop))
|
if (failed(loop))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -4466,14 +4634,14 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
|
bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) {
|
||||||
for (const auto& [key, consumers] : state.sameClassConsumers) {
|
SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId};
|
||||||
if (!consumers.contains(classId))
|
auto it = state.sameClassConsumerIndex.find(lookupKey);
|
||||||
continue;
|
if (it == state.sameClassConsumerIndex.end())
|
||||||
if (!sameProducerResult(key, producerKey))
|
return false;
|
||||||
continue;
|
|
||||||
if (containsProducerKey(key, producerKey) || containsProducerKey(producerKey, key))
|
for (ProducerKey existing : it->second)
|
||||||
|
if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing))
|
||||||
return true;
|
return true;
|
||||||
}
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4488,6 +4656,7 @@ bool canCompactBatchClassRun(MaterializerState& state,
|
|||||||
ArrayRef<Value> outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front());
|
ArrayRef<Value> outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front());
|
||||||
|
|
||||||
for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) {
|
for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) {
|
||||||
|
(void) ignored;
|
||||||
for (const MaterializationRunSlot& slot : run) {
|
for (const MaterializationRunSlot& slot : run) {
|
||||||
if (slot.peers.empty())
|
if (slot.peers.empty())
|
||||||
return false;
|
return false;
|
||||||
@@ -4533,7 +4702,8 @@ Value createBatchClassRunSourceLane(MaterializerState& state,
|
|||||||
SmallVector<int64_t, 16> sourceLanes;
|
SmallVector<int64_t, 16> sourceLanes;
|
||||||
sourceLanes.reserve(run.size() * targetClass.cpus.size());
|
sourceLanes.reserve(run.size() * targetClass.cpus.size());
|
||||||
|
|
||||||
for (const MaterializationRunSlot& slot : run) {
|
for (auto [runSlotIndex, slot] : llvm::enumerate(run)) {
|
||||||
|
(void) runSlotIndex;
|
||||||
assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane");
|
assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane");
|
||||||
for (const ComputeInstance& peer : slot.peers)
|
for (const ComputeInstance& peer : slot.peers)
|
||||||
sourceLanes.push_back(peer.laneStart);
|
sourceLanes.push_back(peer.laneStart);
|
||||||
@@ -4577,7 +4747,6 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
|
|||||||
plan.messages.targetCoreIds.reserve(messageCount);
|
plan.messages.targetCoreIds.reserve(messageCount);
|
||||||
|
|
||||||
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
|
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
|
||||||
(void) slotIndex;
|
|
||||||
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
|
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
|
||||||
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
|
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
|
||||||
if (failed(checkedSourceCpu))
|
if (failed(checkedSourceCpu))
|
||||||
@@ -4590,6 +4759,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
|
|||||||
return failure();
|
return failure();
|
||||||
plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
||||||
}
|
}
|
||||||
|
(void) slotIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
plans.push_back(std::move(plan));
|
plans.push_back(std::move(plan));
|
||||||
@@ -4773,7 +4943,8 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) {
|
LogicalResult materializeInstanceSlot(MaterializerState& state,
|
||||||
|
const ComputeInstance& instance) {
|
||||||
auto cpuIt = state.schedule.computeToCpuMap.find(instance);
|
auto cpuIt = state.schedule.computeToCpuMap.find(instance);
|
||||||
if (cpuIt == state.schedule.computeToCpuMap.end())
|
if (cpuIt == state.schedule.computeToCpuMap.end())
|
||||||
return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance");
|
return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance");
|
||||||
@@ -4794,8 +4965,7 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
|
|||||||
return success();
|
return success();
|
||||||
|
|
||||||
if (isa<SpatComputeBatch>(instance.op)) {
|
if (isa<SpatComputeBatch>(instance.op)) {
|
||||||
FailureOr<MaterializationRun> run =
|
FailureOr<MaterializationRun> run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
|
||||||
collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
|
|
||||||
|
|
||||||
if (succeeded(run)) {
|
if (succeeded(run)) {
|
||||||
if (!targetClass.isBatch)
|
if (!targetClass.isBatch)
|
||||||
@@ -4924,6 +5094,7 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody());
|
LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody());
|
||||||
|
(void) _;
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
#include "mlir/Analysis/TopologicalSortUtils.h"
|
#include "mlir/Analysis/TopologicalSortUtils.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
@@ -14,20 +13,14 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <chrono>
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdlib>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <tuple>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -51,83 +44,6 @@ using SpatCompute = spatial::SpatCompute;
|
|||||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||||
using spatial::getProducerValueRef;
|
using spatial::getProducerValueRef;
|
||||||
|
|
||||||
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
|
|
||||||
|
|
||||||
class ScopedMergePhaseTimer {
|
|
||||||
public:
|
|
||||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
|
||||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
|
||||||
if (enabled)
|
|
||||||
start = std::chrono::steady_clock::now();
|
|
||||||
}
|
|
||||||
|
|
||||||
~ScopedMergePhaseTimer() {
|
|
||||||
if (!enabled)
|
|
||||||
return;
|
|
||||||
auto elapsed = std::chrono::steady_clock::now() - start;
|
|
||||||
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
|
||||||
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool enabled = false;
|
|
||||||
std::string phase;
|
|
||||||
std::chrono::steady_clock::time_point start;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct MergeIrCounts {
|
|
||||||
uint64_t topLevelComputeCount = 0;
|
|
||||||
uint64_t topLevelComputeBatchCount = 0;
|
|
||||||
uint64_t scalarChannelSendCount = 0;
|
|
||||||
uint64_t scalarChannelReceiveCount = 0;
|
|
||||||
uint64_t wvmmCount = 0;
|
|
||||||
uint64_t vaddCount = 0;
|
|
||||||
uint64_t scfForCount = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
|
|
||||||
MergeIrCounts counts;
|
|
||||||
|
|
||||||
auto countComputeBodyOps = [&](Operation* op) {
|
|
||||||
op->walk([&](Operation* nestedOp) {
|
|
||||||
if (isa<spatial::SpatChannelSendOp>(nestedOp))
|
|
||||||
++counts.scalarChannelSendCount;
|
|
||||||
else if (isa<spatial::SpatChannelReceiveOp>(nestedOp))
|
|
||||||
++counts.scalarChannelReceiveCount;
|
|
||||||
else if (isa<spatial::SpatVMMOp>(nestedOp))
|
|
||||||
++counts.wvmmCount;
|
|
||||||
else if (isa<spatial::SpatVAddOp>(nestedOp))
|
|
||||||
++counts.vaddCount;
|
|
||||||
else if (isa<scf::ForOp>(nestedOp))
|
|
||||||
++counts.scfForCount;
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<SpatCompute>()) {
|
|
||||||
++counts.topLevelComputeCount;
|
|
||||||
countComputeBodyOps(compute.getOperation());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto batch : funcOp.getOps<SpatComputeBatch>()) {
|
|
||||||
++counts.topLevelComputeBatchCount;
|
|
||||||
countComputeBodyOps(batch.getOperation());
|
|
||||||
}
|
|
||||||
|
|
||||||
return counts;
|
|
||||||
}
|
|
||||||
|
|
||||||
void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
|
||||||
if (!isMergeProfilingEnabled())
|
|
||||||
return;
|
|
||||||
|
|
||||||
MergeIrCounts counts = collectMergeIrCounts(funcOp);
|
|
||||||
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
|
||||||
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
|
||||||
<< " scalar_send=" << counts.scalarChannelSendCount
|
|
||||||
<< " scalar_recv=" << counts.scalarChannelReceiveCount << " wvmm=" << counts.wvmmCount
|
|
||||||
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
||||||
@@ -138,16 +54,6 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ComputeMotifInfo {
|
|
||||||
uint64_t instructionCount = 0;
|
|
||||||
uint64_t weightedVmmCount = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
void appendUnique(SmallVector<size_t>& values, size_t value) {
|
|
||||||
if (!llvm::is_contained(values, value))
|
|
||||||
values.push_back(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
||||||
if (!compute->hasOneUse())
|
if (!compute->hasOneUse())
|
||||||
return false;
|
return false;
|
||||||
@@ -266,212 +172,6 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void emitMotifProfile(func::FuncOp funcOp) {
|
|
||||||
if (!std::getenv("DCP_MOTIF_PROFILE"))
|
|
||||||
return;
|
|
||||||
|
|
||||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
|
||||||
DenseMap<SpatCompute, size_t> computeToIndex;
|
|
||||||
computeToIndex.reserve(computes.size());
|
|
||||||
for (auto [index, compute] : llvm::enumerate(computes))
|
|
||||||
computeToIndex[compute] = index;
|
|
||||||
|
|
||||||
SmallVector<ComputeMotifInfo> computeInfos(computes.size());
|
|
||||||
SmallVector<SmallVector<size_t>> parents(computes.size());
|
|
||||||
SmallVector<SmallVector<size_t>> children(computes.size());
|
|
||||||
|
|
||||||
uint64_t weightedVmmNodeCount = 0;
|
|
||||||
uint64_t weightedVmmOpCount = 0;
|
|
||||||
uint64_t edgeCount = 0;
|
|
||||||
|
|
||||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
|
||||||
ComputeMotifInfo& info = computeInfos[index];
|
|
||||||
info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
|
|
||||||
compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
|
|
||||||
if (info.weightedVmmCount > 0) {
|
|
||||||
weightedVmmNodeCount++;
|
|
||||||
weightedVmmOpCount += info.weightedVmmCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (Value input : compute.getInputs()) {
|
|
||||||
auto parent = dyn_cast<SpatCompute>(input.getDefiningOp());
|
|
||||||
if (!parent || parent == compute)
|
|
||||||
continue;
|
|
||||||
auto parentIt = computeToIndex.find(parent);
|
|
||||||
if (parentIt == computeToIndex.end())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
size_t parentIndex = parentIt->second;
|
|
||||||
size_t oldParentCount = parents[index].size();
|
|
||||||
appendUnique(parents[index], parentIndex);
|
|
||||||
if (parents[index].size() != oldParentCount) {
|
|
||||||
appendUnique(children[parentIndex], index);
|
|
||||||
edgeCount++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t maxFanIn = 0;
|
|
||||||
uint64_t maxFanOut = 0;
|
|
||||||
uint64_t fanIn16 = 0;
|
|
||||||
uint64_t fanIn64 = 0;
|
|
||||||
uint64_t fanIn256 = 0;
|
|
||||||
uint64_t fanOut16 = 0;
|
|
||||||
uint64_t fanOut64 = 0;
|
|
||||||
uint64_t fanOut256 = 0;
|
|
||||||
for (size_t index = 0; index < computes.size(); ++index) {
|
|
||||||
uint64_t fanIn = parents[index].size();
|
|
||||||
uint64_t fanOut = children[index].size();
|
|
||||||
maxFanIn = std::max(maxFanIn, fanIn);
|
|
||||||
maxFanOut = std::max(maxFanOut, fanOut);
|
|
||||||
fanIn16 += fanIn >= 16;
|
|
||||||
fanIn64 += fanIn >= 64;
|
|
||||||
fanIn256 += fanIn >= 256;
|
|
||||||
fanOut16 += fanOut >= 16;
|
|
||||||
fanOut64 += fanOut >= 64;
|
|
||||||
fanOut256 += fanOut >= 256;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t serialChainCount = 0;
|
|
||||||
uint64_t serialChainNodeCount = 0;
|
|
||||||
uint64_t maxSerialChain = 0;
|
|
||||||
for (size_t index = 0; index < computes.size(); ++index) {
|
|
||||||
if (parents[index].size() == 1 && children[parents[index][0]].size() == 1)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
uint64_t chainLength = 1;
|
|
||||||
size_t current = index;
|
|
||||||
while (children[current].size() == 1) {
|
|
||||||
size_t child = children[current][0];
|
|
||||||
if (parents[child].size() != 1)
|
|
||||||
break;
|
|
||||||
chainLength++;
|
|
||||||
current = child;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (chainLength >= 2) {
|
|
||||||
serialChainCount++;
|
|
||||||
serialChainNodeCount += chainLength;
|
|
||||||
maxSerialChain = std::max(maxSerialChain, chainLength);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<size_t> incomingEdgeCount;
|
|
||||||
incomingEdgeCount.reserve(parents.size());
|
|
||||||
for (ArrayRef<size_t> parentList : parents)
|
|
||||||
incomingEdgeCount.push_back(parentList.size());
|
|
||||||
|
|
||||||
SmallVector<uint64_t> level(computes.size(), 0);
|
|
||||||
SmallVector<size_t> readyNodes;
|
|
||||||
readyNodes.reserve(computes.size());
|
|
||||||
for (size_t index = 0; index < computes.size(); ++index)
|
|
||||||
if (incomingEdgeCount[index] == 0)
|
|
||||||
readyNodes.push_back(index);
|
|
||||||
|
|
||||||
size_t readyIndex = 0;
|
|
||||||
while (readyIndex != readyNodes.size()) {
|
|
||||||
size_t current = readyNodes[readyIndex++];
|
|
||||||
for (size_t child : children[current]) {
|
|
||||||
level[child] = std::max(level[child], level[current] + 1);
|
|
||||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
|
||||||
incomingEdgeCount[child]--;
|
|
||||||
if (incomingEdgeCount[child] == 0)
|
|
||||||
readyNodes.push_back(child);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<uint64_t> weightedVmmNodesByLevel;
|
|
||||||
for (size_t index = 0; index < computes.size(); ++index) {
|
|
||||||
if (computeInfos[index].weightedVmmCount == 0)
|
|
||||||
continue;
|
|
||||||
if (weightedVmmNodesByLevel.size() <= level[index])
|
|
||||||
weightedVmmNodesByLevel.resize(level[index] + 1, 0);
|
|
||||||
weightedVmmNodesByLevel[level[index]]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t maxWeightedVmmLevel = 0;
|
|
||||||
uint64_t wideWeightedVmmLevels64 = 0;
|
|
||||||
uint64_t wideWeightedVmmLevels256 = 0;
|
|
||||||
for (uint64_t count : weightedVmmNodesByLevel) {
|
|
||||||
maxWeightedVmmLevel = std::max(maxWeightedVmmLevel, count);
|
|
||||||
wideWeightedVmmLevels64 += count >= 64;
|
|
||||||
wideWeightedVmmLevels256 += count >= 256;
|
|
||||||
}
|
|
||||||
|
|
||||||
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
|
|
||||||
SmallVector<ShapeKey> weightedVmmShapeKeys;
|
|
||||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
|
||||||
const ComputeMotifInfo& info = computeInfos[index];
|
|
||||||
if (info.weightedVmmCount == 0)
|
|
||||||
continue;
|
|
||||||
weightedVmmShapeKeys.push_back({info.instructionCount,
|
|
||||||
info.weightedVmmCount,
|
|
||||||
static_cast<uint64_t>(compute.getWeights().size()),
|
|
||||||
static_cast<uint64_t>(compute.getInputs().size()),
|
|
||||||
static_cast<uint64_t>(parents[index].size()),
|
|
||||||
static_cast<uint64_t>(children[index].size())});
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::sort(weightedVmmShapeKeys);
|
|
||||||
SmallVector<std::pair<uint64_t, ShapeKey>> weightedVmmShapeCounts;
|
|
||||||
for (size_t index = 0; index < weightedVmmShapeKeys.size();) {
|
|
||||||
size_t next = index + 1;
|
|
||||||
while (next < weightedVmmShapeKeys.size() && weightedVmmShapeKeys[next] == weightedVmmShapeKeys[index])
|
|
||||||
next++;
|
|
||||||
weightedVmmShapeCounts.push_back({next - index, weightedVmmShapeKeys[index]});
|
|
||||||
index = next;
|
|
||||||
}
|
|
||||||
llvm::sort(weightedVmmShapeCounts, [](const auto& lhs, const auto& rhs) {
|
|
||||||
if (lhs.first != rhs.first)
|
|
||||||
return lhs.first > rhs.first;
|
|
||||||
return lhs.second < rhs.second;
|
|
||||||
});
|
|
||||||
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-MOTIF] computes={0} edges={1} wvmmNodes={2} wvmmOps={3} "
|
|
||||||
"serialChains={4} serialChainNodes={5} maxSerialChain={6} "
|
|
||||||
"maxFanIn={7} maxFanOut={8} fanIn>=16/64/256={9}/{10}/{11} "
|
|
||||||
"fanOut>=16/64/256={12}/{13}/{14} topoVisited={15}\n",
|
|
||||||
computes.size(),
|
|
||||||
edgeCount,
|
|
||||||
weightedVmmNodeCount,
|
|
||||||
weightedVmmOpCount,
|
|
||||||
serialChainCount,
|
|
||||||
serialChainNodeCount,
|
|
||||||
maxSerialChain,
|
|
||||||
maxFanIn,
|
|
||||||
maxFanOut,
|
|
||||||
fanIn16,
|
|
||||||
fanIn64,
|
|
||||||
fanIn256,
|
|
||||||
fanOut16,
|
|
||||||
fanOut64,
|
|
||||||
fanOut256,
|
|
||||||
readyNodes.size());
|
|
||||||
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmLevels={0} maxWvmmLevel={1} wideWvmmLevels>=64/256={2}/{3} "
|
|
||||||
"shapeGroups={4}\n",
|
|
||||||
weightedVmmNodesByLevel.size(),
|
|
||||||
maxWeightedVmmLevel,
|
|
||||||
wideWeightedVmmLevels64,
|
|
||||||
wideWeightedVmmLevels256,
|
|
||||||
weightedVmmShapeCounts.size());
|
|
||||||
|
|
||||||
for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
|
|
||||||
auto [count, shape] = weightedVmmShapeCounts[rank];
|
|
||||||
auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape;
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
|
|
||||||
"weights={4} inputs={5} fanIn={6} fanOut={7}\n",
|
|
||||||
rank,
|
|
||||||
count,
|
|
||||||
insts,
|
|
||||||
vmmOps,
|
|
||||||
weights,
|
|
||||||
inputs,
|
|
||||||
fanIn,
|
|
||||||
fanOut);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) {
|
void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) {
|
||||||
std::fstream file = openReportFile(name);
|
std::fstream file = openReportFile(name);
|
||||||
if (!file.is_open())
|
if (!file.is_open())
|
||||||
@@ -628,44 +328,27 @@ public:
|
|||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
{
|
mergeTriviallyConnectedComputes(func);
|
||||||
ScopedMergePhaseTimer timer("trivial-serial-merge");
|
|
||||||
mergeTriviallyConnectedComputes(func);
|
|
||||||
}
|
|
||||||
if (std::getenv("DCP_MOTIF_PROFILE"))
|
|
||||||
emitMotifProfile(func);
|
|
||||||
|
|
||||||
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
||||||
{
|
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
||||||
ScopedMergePhaseTimer timer("scheduling-analysis");
|
if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
|
||||||
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
signalPassFailure();
|
||||||
}
|
return;
|
||||||
{
|
|
||||||
ScopedMergePhaseTimer timer("schedule-materialization");
|
|
||||||
if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) {
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
emitMergeIrCounts("after-materialization", func);
|
if (!sortTopologically(&func.getBody().front())) {
|
||||||
|
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||||
{
|
signalPassFailure();
|
||||||
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
return;
|
||||||
if (!sortTopologically(&func.getBody().front())) {
|
|
||||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (failed(verifySpatialCommunicationInvariants(func))) {
|
|
||||||
func.emitOpError("merged Spatial communication invariant verification failed");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
emitMergeIrCounts("final-post-merge", func);
|
|
||||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
|
|
||||||
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
|
||||||
}
|
}
|
||||||
|
if (failed(verifySpatialCommunicationInvariants(func))) {
|
||||||
|
func.emitOpError("merged Spatial communication invariant verification failed");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
|
||||||
|
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user