automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-03 19:43:56 +02:00
parent dc5edd032c
commit 69021d56aa
12 changed files with 187 additions and 195 deletions
+26 -24
View File
@@ -98,8 +98,8 @@ static int32_t getVectorByteSizeOrCrash(ShapedType type) {
return pim::checkedI32OrCrash(*byteSize, "vector byte size");
}
static Operation *getDiagnosticAnchor(mlir::Value value) {
if (Operation *definingOp = value.getDefiningOp())
static Operation* getDiagnosticAnchor(mlir::Value value) {
if (Operation* definingOp = value.getDefiningOp())
return definingOp;
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParentOp();
@@ -111,7 +111,7 @@ static Operation *getDiagnosticAnchor(mlir::Value value) {
// the non-negative int32_t range.
static constexpr size_t kPimAddressLimit = static_cast<size_t>(std::numeric_limits<int32_t>::max());
static FailureOr<size_t> checkedAlignTo(size_t value, size_t alignment, Operation *anchor, StringRef fieldName) {
static FailureOr<size_t> checkedAlignTo(size_t value, size_t alignment, Operation* anchor, StringRef fieldName) {
if (alignment == 0)
return value;
size_t remainder = value % alignment;
@@ -121,7 +121,7 @@ static FailureOr<size_t> checkedAlignTo(size_t value, size_t alignment, Operatio
}
static void printMemoryOverflowDiagnostic(mlir::Value value,
const MemoryValueKey &key,
const MemoryValueKey& key,
size_t requestedSize,
size_t currentFirstAvailableAddress,
size_t alignedEndAddress) {
@@ -136,7 +136,7 @@ static void printMemoryOverflowDiagnostic(mlir::Value value,
value.print(llvm::errs());
llvm::errs() << "\n";
llvm::errs() << "Value type: " << value.getType() << "\n";
if (Operation *definingOp = value.getDefiningOp()) {
if (Operation* definingOp = value.getDefiningOp()) {
llvm::errs() << "Defining op:\n";
definingOp->print(llvm::errs());
llvm::errs() << "\n";
@@ -170,7 +170,7 @@ void PimMemory::allocateGatheredMemory() {
void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind) {
memEntry.address = firstAvailableAddress;
Operation *anchor = getDiagnosticAnchor(key.value);
Operation* anchor = getDiagnosticAnchor(key.value);
auto checkedEnd = pim::checkedAdd(memEntry.address, memEntry.size, anchor, "local memory end");
FailureOr<size_t> checkedAlignedEnd = failure();
if (succeeded(checkedEnd))
@@ -179,12 +179,11 @@ void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memE
bool endFits = succeeded(checkedEnd) && *checkedEnd <= kPimAddressLimit;
bool alignedEndFits = succeeded(checkedAlignedEnd) && *checkedAlignedEnd <= kPimAddressLimit;
if (!startFits || !endFits || !alignedEndFits) {
printMemoryOverflowDiagnostic(
key.value,
key,
memEntry.size,
firstAvailableAddress,
succeeded(checkedAlignedEnd) ? *checkedAlignedEnd : kPimAddressLimit);
printMemoryOverflowDiagnostic(key.value,
key,
memEntry.size,
firstAvailableAddress,
succeeded(checkedAlignedEnd) ? *checkedAlignedEnd : kPimAddressLimit);
llvm_unreachable("PIM local memory allocation overflow");
}
firstAvailableAddress = *checkedAlignedEnd;
@@ -209,7 +208,7 @@ PhysicalSlotInfo PimMemory::allocatePhysicalSlot(size_t slotSize, const MemoryVa
slot.address = firstAvailableAddress;
slot.size = slotSize;
Operation *anchor = getDiagnosticAnchor(key.value);
Operation* anchor = getDiagnosticAnchor(key.value);
auto checkedEnd = pim::checkedAdd(slot.address, slot.size, anchor, "local memory end");
FailureOr<size_t> checkedAlignedEnd = failure();
if (succeeded(checkedEnd))
@@ -218,8 +217,11 @@ PhysicalSlotInfo PimMemory::allocatePhysicalSlot(size_t slotSize, const MemoryVa
bool endFits = succeeded(checkedEnd) && *checkedEnd <= kPimAddressLimit;
bool alignedEndFits = succeeded(checkedAlignedEnd) && *checkedAlignedEnd <= kPimAddressLimit;
if (!startFits || !endFits || !alignedEndFits) {
printMemoryOverflowDiagnostic(
key.value, key, slot.size, firstAvailableAddress, succeeded(checkedAlignedEnd) ? *checkedAlignedEnd : kPimAddressLimit);
printMemoryOverflowDiagnostic(key.value,
key,
slot.size,
firstAvailableAddress,
succeeded(checkedAlignedEnd) ? *checkedAlignedEnd : kPimAddressLimit);
llvm_unreachable("PIM local memory allocation overflow");
}
@@ -273,8 +275,8 @@ void PimMemory::allocateCore(Operation* op, std::optional<unsigned> lane) {
SmallVector<size_t> slotOrder(plannedSlots.size());
std::iota(slotOrder.begin(), slotOrder.end(), 0);
llvm::stable_sort(slotOrder, [&](size_t lhsIndex, size_t rhsIndex) {
const PlannedPhysicalSlot &lhs = plannedSlots[lhsIndex];
const PlannedPhysicalSlot &rhs = plannedSlots[rhsIndex];
const PlannedPhysicalSlot& lhs = plannedSlots[lhsIndex];
const PlannedPhysicalSlot& rhs = plannedSlots[rhsIndex];
if (lhs.requiredSize != rhs.requiredSize)
return lhs.requiredSize > rhs.requiredSize;
return lhs.id < rhs.id;
@@ -282,7 +284,7 @@ void PimMemory::allocateCore(Operation* op, std::optional<unsigned> lane) {
SmallVector<bool, 16> usedExistingSlots(localPhysicalSlots.size(), false);
for (size_t slotIndex : slotOrder) {
PlannedPhysicalSlot &slot = plannedSlots[slotIndex];
PlannedPhysicalSlot& slot = plannedSlots[slotIndex];
size_t bestExistingIndex = std::numeric_limits<size_t>::max();
auto bestKey = std::tuple<size_t, size_t, size_t>(
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max());
@@ -290,11 +292,11 @@ void PimMemory::allocateCore(Operation* op, std::optional<unsigned> lane) {
for (size_t existingIndex = 0; existingIndex < localPhysicalSlots.size(); ++existingIndex) {
if (usedExistingSlots[existingIndex])
continue;
const PhysicalSlotInfo &existingSlot = localPhysicalSlots[existingIndex];
const PhysicalSlotInfo& existingSlot = localPhysicalSlots[existingIndex];
if (existingSlot.size < slot.requiredSize)
continue;
auto candidateKey = std::tuple<size_t, size_t, size_t>(
existingSlot.size - slot.requiredSize, existingSlot.size, existingSlot.id);
auto candidateKey =
std::tuple<size_t, size_t, size_t>(existingSlot.size - slot.requiredSize, existingSlot.size, existingSlot.id);
if (candidateKey < bestKey) {
bestKey = candidateKey;
bestExistingIndex = existingIndex;
@@ -302,7 +304,7 @@ void PimMemory::allocateCore(Operation* op, std::optional<unsigned> lane) {
}
if (bestExistingIndex != std::numeric_limits<size_t>::max()) {
const PhysicalSlotInfo &existingSlot = localPhysicalSlots[bestExistingIndex];
const PhysicalSlotInfo& existingSlot = localPhysicalSlots[bestExistingIndex];
slot.id = existingSlot.id;
slot.address = existingSlot.address;
slot.size = existingSlot.size;
@@ -317,7 +319,7 @@ void PimMemory::allocateCore(Operation* op, std::optional<unsigned> lane) {
}
for (size_t intervalIndex : slot.intervalIndices) {
LocalAllocInterval &interval = intervals[intervalIndex];
LocalAllocInterval& interval = intervals[intervalIndex];
interval.physicalSlotId = slot.id;
interval.assignedAddress = slot.address;
interval.physicalSlotSize = slot.size;
@@ -375,7 +377,7 @@ MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row = reportRow;
row.numAlloca = localPhysicalSlots.size();
row.sizeAlloca = 0;
for (const PhysicalSlotInfo &slot : localPhysicalSlots)
for (const PhysicalSlotInfo& slot : localPhysicalSlots)
row.sizeAlloca += slot.size;
return row;
}
+2 -1
View File
@@ -26,7 +26,8 @@ llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
"pim-memory-report",
llvm::cl::desc("Emit a human-readable PIM memory planning report"),
llvm::cl::values(clEnumValN(PimMemoryReportNone, "none", "Do not emit any PIM memory planning report")),
llvm::cl::values(clEnumValN(PimMemoryReportSummary, "summary", "Emit a concise slot reuse report with key offenders")),
llvm::cl::values(
clEnumValN(PimMemoryReportSummary, "summary", "Emit a concise slot reuse report with key offenders")),
llvm::cl::values(clEnumValN(PimMemoryReportFull, "full", "Emit the full detailed PIM memory planning report")),
llvm::cl::init(PimMemoryReportNone),
llvm::cl::cat(OnnxMlirOptions));
+87 -96
View File
@@ -5,8 +5,8 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
@@ -42,10 +42,10 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigne
struct MemoryTouchInterval {
uint64_t start = 0;
uint64_t end = 0;
Operation *startOp = nullptr;
Operation *endOp = nullptr;
Operation *firstTouchOp = nullptr;
Operation *lastTouchOp = nullptr;
Operation* startOp = nullptr;
Operation* endOp = nullptr;
Operation* firstTouchOp = nullptr;
Operation* lastTouchOp = nullptr;
uint64_t firstTouchPosition = 0;
uint64_t lastTouchPosition = 0;
bool hasRuntimeUse = false;
@@ -57,8 +57,8 @@ struct MemoryTouchInterval {
};
struct OperationOrdering {
llvm::DenseMap<Operation *, uint64_t> position;
llvm::DenseMap<Operation *, uint64_t> subtreeEnd;
llvm::DenseMap<Operation*, uint64_t> position;
llvm::DenseMap<Operation*, uint64_t> subtreeEnd;
uint64_t nextPosition = 0;
};
@@ -70,7 +70,7 @@ static std::string printValueToString(mlir::Value value) {
return text;
}
static std::string printOperationToString(Operation *op) {
static std::string printOperationToString(Operation* op) {
if (!op)
return "<none>";
std::string text;
@@ -116,7 +116,7 @@ static std::string summarizeValue(mlir::Value value, size_t maxLen = 72) {
return abbreviate(collapseWhitespace(printValueToString(value)), maxLen);
}
static std::string summarizeOperation(Operation *op, size_t maxLen = 96) {
static std::string summarizeOperation(Operation* op, size_t maxLen = 96) {
if (!op)
return "<none>";
std::string prefix = op->getName().getStringRef().str();
@@ -130,34 +130,34 @@ static std::string summarizeLocation(Location loc, size_t maxLen = 88) {
return abbreviate(collapseWhitespace(printLocationToString(loc)), maxLen);
}
static void assignOperationOrdering(Operation *op, OperationOrdering &ordering) {
static void assignOperationOrdering(Operation* op, OperationOrdering& ordering) {
uint64_t position = ordering.nextPosition++;
ordering.position[op] = position;
uint64_t end = position;
for (Region &region : op->getRegions())
for (Block &block : region)
for (Operation &nestedOp : block) {
for (Region& region : op->getRegions())
for (Block& block : region)
for (Operation& nestedOp : block) {
assignOperationOrdering(&nestedOp, ordering);
end = std::max(end, ordering.subtreeEnd.lookup(&nestedOp));
}
ordering.subtreeEnd[op] = end;
}
static OperationOrdering buildOperationOrdering(Operation *coreLikeOp) {
static OperationOrdering buildOperationOrdering(Operation* coreLikeOp) {
OperationOrdering ordering;
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return ordering;
for (Operation &op : coreLikeOp->getRegion(0).front())
for (Operation& op : coreLikeOp->getRegion(0).front())
assignOperationOrdering(&op, ordering);
return ordering;
}
static bool isSupportedAliasOp(Operation *op) {
static bool isSupportedAliasOp(Operation* op) {
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
}
static bool isRuntimeMemoryTouchOp(Operation *op) {
static bool isRuntimeMemoryTouchOp(Operation* op) {
return isa<pim::PimMemCopyHostToDevOp,
pim::PimMemCopyDevToHostOp,
pim::PimMemCopyOp,
@@ -178,27 +178,27 @@ static bool isRuntimeMemoryTouchOp(Operation *op) {
pim::PimVSoftmaxOp>(op);
}
static bool isIgnoredLivenessUser(Operation *op) {
static bool isIgnoredLivenessUser(Operation* op) {
return isSupportedAliasOp(op) || isa<scf::ForOp, scf::YieldOp, memref::DeallocOp>(op) || isCoreStaticAddressOp(op);
}
static bool isWithin(mlir::Value value, Region *region) {
static bool isWithin(mlir::Value value, Region* region) {
if (!region)
return false;
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent() == region;
if (Operation *definingOp = value.getDefiningOp())
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion() == region || region->isAncestor(definingOp->getParentRegion());
return false;
}
static bool isNestedAllocation(Operation *coreLikeOp, memref::AllocOp allocOp) {
static bool isNestedAllocation(Operation* coreLikeOp, memref::AllocOp allocOp) {
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return false;
return allocOp->getBlock() != &coreLikeOp->getRegion(0).front();
}
static void addFallbackReason(std::string &reason, StringRef newReason) {
static void addFallbackReason(std::string& reason, StringRef newReason) {
if (newReason.empty())
return;
if (!reason.empty())
@@ -206,7 +206,7 @@ static void addFallbackReason(std::string &reason, StringRef newReason) {
reason += newReason.str();
}
static void appendAliasDescription(llvm::SmallVectorImpl<std::string> &aliases, mlir::Value value) {
static void appendAliasDescription(llvm::SmallVectorImpl<std::string>& aliases, mlir::Value value) {
std::string text = printValueToString(value);
if (!llvm::is_contained(aliases, text))
aliases.push_back(std::move(text));
@@ -215,16 +215,15 @@ static void appendAliasDescription(llvm::SmallVectorImpl<std::string> &aliases,
struct OrderedTouchRange {
uint64_t start = 0;
uint64_t end = 0;
Operation *startOp = nullptr;
Operation *endOp = nullptr;
Operation* startOp = nullptr;
Operation* endOp = nullptr;
bool escapedLoop = false;
};
static OrderedTouchRange
getEffectiveTouchRange(mlir::Value definingValue, Operation *user, const OperationOrdering &ordering) {
OrderedTouchRange range {
ordering.position.lookup(user), ordering.position.lookup(user), user, user, false};
for (Operation *current = user; current; current = current->getParentOp()) {
getEffectiveTouchRange(mlir::Value definingValue, Operation* user, const OperationOrdering& ordering) {
OrderedTouchRange range {ordering.position.lookup(user), ordering.position.lookup(user), user, user, false};
for (Operation* current = user; current; current = current->getParentOp()) {
auto forOp = dyn_cast<scf::ForOp>(current);
if (!forOp || isWithin(definingValue, &forOp.getRegion()))
continue;
@@ -238,7 +237,7 @@ getEffectiveTouchRange(mlir::Value definingValue, Operation *user, const Operati
}
static MemoryTouchInterval
computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering &ordering, uint64_t fallbackEnd) {
computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering& ordering, uint64_t fallbackEnd) {
MemoryTouchInterval interval;
interval.start = ordering.position.lookup(allocOp);
interval.end = interval.start;
@@ -246,7 +245,7 @@ computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering &ord
interval.endOp = allocOp;
SmallPtrSet<mlir::Value, 16> visitedValues;
SmallPtrSet<Operation *, 32> visitedUsers;
SmallPtrSet<Operation*, 32> visitedUsers;
SmallVector<mlir::Value> pendingValues;
pendingValues.push_back(allocOp.getResult());
auto parentLoop = allocOp->getParentOfType<scf::ForOp>();
@@ -256,7 +255,7 @@ computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering &ord
if (!visitedValues.insert(value).second)
continue;
for (Operation *user : value.getUsers()) {
for (Operation* user : value.getUsers()) {
if (!visitedUsers.insert(user).second)
continue;
@@ -269,7 +268,7 @@ computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering &ord
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) {
OpOperand *tiedOperand = dpsOp.getTiedOpOperand(result);
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
if (!tiedOperand || tiedOperand->get() != value)
continue;
pendingValues.push_back(result);
@@ -379,11 +378,11 @@ static FailureOr<size_t> getAllocSizeBytes(memref::AllocOp allocOp) {
return pim::checkedSize(*checkedBytes, allocOp, "memory allocation byte size");
}
static bool intervalsOverlap(const LocalAllocInterval &lhs, const LocalAllocInterval &rhs) {
static bool intervalsOverlap(const LocalAllocInterval& lhs, const LocalAllocInterval& rhs) {
return !(lhs.end < rhs.start || rhs.end < lhs.start);
}
static uint64_t getSlotLogicalBytes(const PlannedPhysicalSlot &slot, ArrayRef<LocalAllocInterval> intervals) {
static uint64_t getSlotLogicalBytes(const PlannedPhysicalSlot& slot, ArrayRef<LocalAllocInterval> intervals) {
uint64_t slotLogicalBytes = 0;
for (size_t intervalIndex : slot.intervalIndices)
slotLogicalBytes += intervals[intervalIndex].size;
@@ -392,7 +391,7 @@ static uint64_t getSlotLogicalBytes(const PlannedPhysicalSlot &slot, ArrayRef<Lo
} // namespace
SmallVector<LocalAllocInterval, 0> onnx_mlir::buildLocalAllocIntervals(Operation *coreLikeOp,
SmallVector<LocalAllocInterval, 0> onnx_mlir::buildLocalAllocIntervals(Operation* coreLikeOp,
std::optional<unsigned> lane) {
SmallVector<LocalAllocInterval, 0> intervals;
OperationOrdering ordering = buildOperationOrdering(coreLikeOp);
@@ -442,8 +441,8 @@ SmallVector<PlannedPhysicalSlot, 0> onnx_mlir::planPhysicalSlots(MutableArrayRef
SmallVector<size_t> intervalOrder(intervals.size());
std::iota(intervalOrder.begin(), intervalOrder.end(), 0);
llvm::stable_sort(intervalOrder, [&](size_t lhsIndex, size_t rhsIndex) {
const LocalAllocInterval &lhs = intervals[lhsIndex];
const LocalAllocInterval &rhs = intervals[rhsIndex];
const LocalAllocInterval& lhs = intervals[lhsIndex];
const LocalAllocInterval& rhs = intervals[rhsIndex];
if (lhs.size != rhs.size)
return lhs.size > rhs.size;
if (lhs.start != rhs.start)
@@ -454,16 +453,15 @@ SmallVector<PlannedPhysicalSlot, 0> onnx_mlir::planPhysicalSlots(MutableArrayRef
});
for (size_t intervalIndex : intervalOrder) {
LocalAllocInterval &interval = intervals[intervalIndex];
PlannedPhysicalSlot *bestSlot = nullptr;
auto bestKey = std::tuple<size_t, size_t, size_t, size_t>(
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max());
LocalAllocInterval& interval = intervals[intervalIndex];
PlannedPhysicalSlot* bestSlot = nullptr;
auto bestKey = std::tuple<size_t, size_t, size_t, size_t>(std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max());
for (size_t slotIndex = 0; slotIndex < slots.size(); ++slotIndex) {
PlannedPhysicalSlot &slot = slots[slotIndex];
PlannedPhysicalSlot& slot = slots[slotIndex];
bool compatible = true;
for (size_t otherIndex : slot.intervalIndices) {
if (intervalsOverlap(interval, intervals[otherIndex])) {
@@ -476,8 +474,8 @@ SmallVector<PlannedPhysicalSlot, 0> onnx_mlir::planPhysicalSlots(MutableArrayRef
size_t resultingSize = std::max(slot.requiredSize, interval.size);
size_t growth = resultingSize - slot.requiredSize;
auto candidateKey = std::tuple<size_t, size_t, size_t, size_t>(
growth, resultingSize, slot.intervalIndices.size(), slot.id);
auto candidateKey =
std::tuple<size_t, size_t, size_t, size_t>(growth, resultingSize, slot.intervalIndices.size(), slot.id);
if (candidateKey < bestKey) {
bestKey = candidateKey;
bestSlot = &slot;
@@ -503,7 +501,7 @@ SmallVector<PlannedPhysicalSlot, 0> onnx_mlir::planPhysicalSlots(MutableArrayRef
return slots;
}
MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation* coreLikeOp,
std::optional<unsigned> lane,
ArrayRef<LocalAllocInterval> intervals,
ArrayRef<PlannedPhysicalSlot> slots,
@@ -522,7 +520,7 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
size_t largestPhysicalSlot = 0;
size_t maximumAssignedAddress = 0;
for (const LocalAllocInterval &interval : intervals) {
for (const LocalAllocInterval& interval : intervals) {
totalLogicalBytes += interval.size;
largestLogicalAllocation = std::max(largestLogicalAllocation, interval.size);
maximumAssignedAddress = std::max(maximumAssignedAddress, interval.assignedAddress + interval.physicalSlotSize);
@@ -535,7 +533,7 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
if (interval.escapesLoop)
++loopEscapingIntervals;
}
for (const PlannedPhysicalSlot &slot : slots) {
for (const PlannedPhysicalSlot& slot : slots) {
totalPhysicalBytes += slot.size;
largestPhysicalSlot = std::max(largestPhysicalSlot, slot.size);
if (slot.intervalIndices.size() > 1)
@@ -553,7 +551,8 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
os << "Lane: " << *lane << "\n";
os << "Summary:\n";
os << " logical allocation bytes: " << formatReportMemory(totalLogicalBytes) << " (" << totalLogicalBytes << ")\n";
os << " physical allocation bytes: " << formatReportMemory(totalPhysicalBytes) << " (" << totalPhysicalBytes << ")\n";
os << " physical allocation bytes: " << formatReportMemory(totalPhysicalBytes) << " (" << totalPhysicalBytes
<< ")\n";
os << " saved bytes: " << formatReportMemory(savedBytes) << " (" << savedBytes << ")\n";
os << " saved percent: " << format("%.2f%%", savedPercent) << "\n";
os << " intervals: " << intervals.size() << "\n";
@@ -566,7 +565,8 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
os << " largest logical allocation: " << largestLogicalAllocation << "\n";
os << " largest physical slot: " << largestPhysicalSlot << "\n";
os << " address limit: " << addressLimit << "\n";
os << " peak physical memory: " << formatReportMemory(maximumAssignedAddress) << " (" << maximumAssignedAddress << ")\n";
os << " peak physical memory: " << formatReportMemory(maximumAssignedAddress) << " (" << maximumAssignedAddress
<< ")\n";
os << " maximum assigned address: " << maximumAssignedAddress << "\n";
os << "\nHow To Read:\n";
@@ -575,16 +575,15 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
os << " Large single-use slots, fallback intervals, and nested single-use allocations are the best places\n";
os << " to inspect if allocations should be moved, sunk, or made easier to coalesce earlier in the pipeline.\n";
SmallVector<const PlannedPhysicalSlot *> reusedSlots;
SmallVector<const PlannedPhysicalSlot *> singleUseSlots;
for (const PlannedPhysicalSlot &slot : slots) {
SmallVector<const PlannedPhysicalSlot*> reusedSlots;
SmallVector<const PlannedPhysicalSlot*> singleUseSlots;
for (const PlannedPhysicalSlot& slot : slots)
if (slot.intervalIndices.size() > 1)
reusedSlots.push_back(&slot);
else
singleUseSlots.push_back(&slot);
}
llvm::stable_sort(reusedSlots, [&](const PlannedPhysicalSlot *lhs, const PlannedPhysicalSlot *rhs) {
llvm::stable_sort(reusedSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
uint64_t lhsLogicalBytes = getSlotLogicalBytes(*lhs, intervals);
uint64_t rhsLogicalBytes = getSlotLogicalBytes(*rhs, intervals);
if (lhs->intervalIndices.size() != rhs->intervalIndices.size())
@@ -595,7 +594,7 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
return lhs->size > rhs->size;
return lhs->id < rhs->id;
});
llvm::stable_sort(singleUseSlots, [&](const PlannedPhysicalSlot *lhs, const PlannedPhysicalSlot *rhs) {
llvm::stable_sort(singleUseSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
if (lhs->size != rhs->size)
return lhs->size > rhs->size;
return lhs->id < rhs->id;
@@ -607,18 +606,16 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
os << "\nBest Reuse:\n";
if (reusedSlots.empty()) {
os << " no slots were shared by multiple intervals\n";
} else {
for (const PlannedPhysicalSlot *slot : ArrayRef(reusedSlots).take_front(kSummaryReuseLimit)) {
}
else {
for (const PlannedPhysicalSlot* slot : ArrayRef(reusedSlots).take_front(kSummaryReuseLimit)) {
uint64_t slotLogicalBytes = getSlotLogicalBytes(*slot, intervals);
os << " slot #" << slot->id
<< " addr=" << slot->address
<< " size=" << formatReportMemory(slot->size)
<< " intervals=" << slot->intervalIndices.size()
<< " logical_sum=" << formatReportMemory(slotLogicalBytes) << "\n";
os << " slot #" << slot->id << " addr=" << slot->address << " size=" << formatReportMemory(slot->size)
<< " intervals=" << slot->intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
<< "\n";
for (size_t intervalIndex : slot->intervalIndices) {
const LocalAllocInterval &interval = intervals[intervalIndex];
os << " #" << interval.id
<< " [" << interval.start << "," << interval.end << "]"
const LocalAllocInterval& interval = intervals[intervalIndex];
os << " #" << interval.id << " [" << interval.start << "," << interval.end << "]"
<< " logical=" << formatReportMemory(interval.size)
<< " first=" << summarizeOperation(interval.firstTouchOp, 40)
<< " last=" << summarizeOperation(interval.lastTouchOp, 40) << "\n";
@@ -628,12 +625,11 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
os << "\nTop Offenders:\n";
bool printedAttention = false;
for (const PlannedPhysicalSlot *slot : ArrayRef(singleUseSlots).take_front(kSummaryOffenderLimit)) {
const LocalAllocInterval &interval = intervals[slot->intervalIndices.front()];
for (const PlannedPhysicalSlot* slot : ArrayRef(singleUseSlots).take_front(kSummaryOffenderLimit)) {
const LocalAllocInterval& interval = intervals[slot->intervalIndices.front()];
printedAttention = true;
os << " slot #" << slot->id << " is single-use"
<< " size=" << formatReportMemory(slot->size)
<< " interval=#" << interval.id
<< " size=" << formatReportMemory(slot->size) << " interval=#" << interval.id
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
os << " first=" << summarizeOperation(interval.firstTouchOp, 40)
<< " last=" << summarizeOperation(interval.lastTouchOp, 40)
@@ -641,28 +637,26 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no") << "\n";
}
size_t fallbackPrinted = 0;
for (const LocalAllocInterval &interval : intervals) {
for (const LocalAllocInterval& interval : intervals) {
if (!(interval.startUsedAllocFallback || interval.endUsedFallback) || fallbackPrinted >= kSummaryOffenderLimit)
continue;
printedAttention = true;
++fallbackPrinted;
os << " fallback interval #" << interval.id
<< " size=" << formatReportMemory(interval.size)
os << " fallback interval #" << interval.id << " size=" << formatReportMemory(interval.size)
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
os << " reason: " << (interval.fallbackReason.empty() ? "<none>" : interval.fallbackReason) << "\n";
}
size_t nestedPrinted = 0;
for (const LocalAllocInterval &interval : intervals) {
for (const LocalAllocInterval& interval : intervals) {
if (nestedPrinted >= kSummaryOffenderLimit)
break;
if (!(interval.insideNestedRegion && slots[interval.slotPlanIndex].intervalIndices.size() == 1))
continue;
printedAttention = true;
++nestedPrinted;
os << " nested single-use interval #" << interval.id
<< " slot #" << interval.physicalSlotId
<< " size=" << formatReportMemory(interval.size)
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
os << " nested single-use interval #" << interval.id << " slot #" << interval.physicalSlotId
<< " size=" << formatReportMemory(interval.size) << " value=" << summarizeValue(interval.key.value, 56)
<< "\n";
os << " hint: move or sink this alloc inside the nested region if the IR allows it.\n";
}
if (!printedAttention)
@@ -670,18 +664,17 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
if (reportLevel == PimMemoryReportFull) {
os << "\nSlot Reuse:\n";
for (const PlannedPhysicalSlot &slot : slots) {
for (const PlannedPhysicalSlot& slot : slots) {
uint64_t slotLogicalBytes = getSlotLogicalBytes(slot, intervals);
os << " slot #" << slot.id << " addr=" << slot.address << " size=" << formatReportMemory(slot.size) << " ("
<< slot.size << ")"
<< " intervals=" << slot.intervalIndices.size()
<< " logical_sum=" << formatReportMemory(slotLogicalBytes) << "\n";
<< " intervals=" << slot.intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
<< "\n";
for (size_t intervalIndex : slot.intervalIndices) {
const LocalAllocInterval &interval = intervals[intervalIndex];
const LocalAllocInterval& interval = intervals[intervalIndex];
mlir::Value allocValue = interval.key.value;
os << " [" << interval.start << "," << interval.end << "]"
<< " #" << interval.id
<< " logical=" << formatReportMemory(interval.size)
<< " #" << interval.id << " logical=" << formatReportMemory(interval.size)
<< " nested=" << (interval.insideNestedRegion ? "yes" : "no")
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no")
<< " first=" << summarizeOperation(interval.firstTouchOp, 48)
@@ -693,16 +686,14 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
if (reportLevel == PimMemoryReportFull) {
os << "\nInterval Details:\n";
for (const LocalAllocInterval &interval : intervals) {
const PlannedPhysicalSlot &slot = slots[interval.slotPlanIndex];
for (const LocalAllocInterval& interval : intervals) {
const PlannedPhysicalSlot& slot = slots[interval.slotPlanIndex];
mlir::Value allocValue = interval.key.value;
Operation *definingOp = allocValue.getDefiningOp();
os << " #" << interval.id
<< " slot=" << slot.id
<< " live=[" << interval.start << "," << interval.end << "]"
Operation* definingOp = allocValue.getDefiningOp();
os << " #" << interval.id << " slot=" << slot.id << " live=[" << interval.start << "," << interval.end << "]"
<< " logical=" << formatReportMemory(interval.size)
<< " slot_size=" << formatReportMemory(interval.physicalSlotSize)
<< " addr=" << interval.assignedAddress << "\n";
<< " slot_size=" << formatReportMemory(interval.physicalSlotSize) << " addr=" << interval.assignedAddress
<< "\n";
os << " value=" << summarizeValue(allocValue, 88) << "\n";
os << " type=" << allocValue.getType() << "\n";
os << " loc="
@@ -731,7 +722,7 @@ MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp,
os << " fallback_reason=" << interval.fallbackReason << "\n";
if (!interval.aliasesFollowed.empty()) {
os << " aliases_followed=" << interval.aliasesFollowed.size() << "\n";
for (const std::string &alias : interval.aliasesFollowed)
for (const std::string& alias : interval.aliasesFollowed)
os << " - " << abbreviate(collapseWhitespace(alias), 108) << "\n";
}
}
+6 -6
View File
@@ -21,10 +21,10 @@ struct LocalAllocInterval {
uint64_t start = 0;
uint64_t end = 0;
size_t size = 0;
mlir::Operation *startOp = nullptr;
mlir::Operation *endOp = nullptr;
mlir::Operation *firstTouchOp = nullptr;
mlir::Operation *lastTouchOp = nullptr;
mlir::Operation* startOp = nullptr;
mlir::Operation* endOp = nullptr;
mlir::Operation* firstTouchOp = nullptr;
mlir::Operation* lastTouchOp = nullptr;
uint64_t firstTouchPosition = 0;
uint64_t lastTouchPosition = 0;
bool startUsedAllocFallback = false;
@@ -48,12 +48,12 @@ struct PlannedPhysicalSlot {
llvm::SmallVector<size_t, 8> intervalIndices;
};
llvm::SmallVector<LocalAllocInterval, 0> buildLocalAllocIntervals(mlir::Operation *coreLikeOp,
llvm::SmallVector<LocalAllocInterval, 0> buildLocalAllocIntervals(mlir::Operation* coreLikeOp,
std::optional<unsigned> lane);
llvm::SmallVector<PlannedPhysicalSlot, 0> planPhysicalSlots(llvm::MutableArrayRef<LocalAllocInterval> intervals);
MemoryPlanArtifacts buildMemoryPlanArtifacts(mlir::Operation *coreLikeOp,
MemoryPlanArtifacts buildMemoryPlanArtifacts(mlir::Operation* coreLikeOp,
std::optional<unsigned> lane,
llvm::ArrayRef<LocalAllocInterval> intervals,
llvm::ArrayRef<PlannedPhysicalSlot> slots,
+1 -2
View File
@@ -19,8 +19,7 @@ using namespace mlir;
namespace onnx_mlir {
namespace {} // namespace
WeightEmissionResult
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
WeightEmissionResult createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
+2 -2
View File
@@ -23,7 +23,7 @@ struct WeightEmissionResult {
uint64_t totalWeightBytes = 0;
};
WeightEmissionResult
createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath);
WeightEmissionResult createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -19,7 +19,7 @@ namespace pim {
namespace {
static bool isSupportedAliasOp(Operation *op) {
static bool isSupportedAliasOp(Operation* op) {
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
}
@@ -32,20 +32,20 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
}
static Operation *getTopLevelAncestorInBlock(Operation *op, Block &block) {
Operation *current = op;
static Operation* getTopLevelAncestorInBlock(Operation* op, Block& block) {
Operation* current = op;
while (current && current->getBlock() != &block)
current = current->getParentOp();
return current;
}
static void analyzeBlock(Block &block, MemoryCoalescingAnalysis &analysis);
static void analyzeBlock(Block& block, MemoryCoalescingAnalysis& analysis);
static FailureOr<uint64_t>
getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap<Operation *, uint64_t> &opOrder) {
getLastUseInstruction(memref::AllocOp allocOp, Block& scopeBlock, const DenseMap<Operation*, uint64_t>& opOrder) {
uint64_t endInstruction = opOrder.lookup(allocOp);
SmallPtrSet<Value, 16> visitedValues;
SmallPtrSet<Operation *, 16> visitedUsers;
SmallPtrSet<Operation*, 16> visitedUsers;
SmallVector<Value> pendingValues;
pendingValues.push_back(allocOp.getResult());
@@ -54,7 +54,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
if (!visitedValues.insert(value).second)
continue;
for (Operation *user : value.getUsers()) {
for (Operation* user : value.getUsers()) {
if (!visitedUsers.insert(user).second)
continue;
@@ -63,7 +63,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) {
OpOperand *tiedOperand = dpsOp.getTiedOpOperand(result);
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
if (tiedOperand && tiedOperand->get() == value)
pendingValues.push_back(result);
}
@@ -87,7 +87,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
pendingValues.push_back(forOp.getResult(index));
}
Operation *orderedUser = getTopLevelAncestorInBlock(user, scopeBlock);
Operation* orderedUser = getTopLevelAncestorInBlock(user, scopeBlock);
if (!orderedUser)
return failure();
@@ -101,21 +101,21 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
return endInstruction;
}
static void analyzeBlock(Block &block, MemoryCoalescingAnalysis &analysis) {
for (Operation &op : block)
for (Region &region : op.getRegions())
for (Block &nestedBlock : region)
static void analyzeBlock(Block& block, MemoryCoalescingAnalysis& analysis) {
for (Operation& op : block)
for (Region& region : op.getRegions())
for (Block& nestedBlock : region)
analyzeBlock(nestedBlock, analysis);
DenseMap<Operation *, uint64_t> opOrder;
DenseMap<Operation*, uint64_t> opOrder;
uint64_t nextInstruction = 0;
for (Operation &op : block)
for (Operation& op : block)
opOrder.try_emplace(&op, nextInstruction++);
MemoryCoalescingBlockAnalysis blockAnalysis;
blockAnalysis.block = &block;
for (Operation &op : block) {
for (Operation& op : block) {
auto allocOp = dyn_cast<memref::AllocOp>(&op);
if (!allocOp)
continue;
@@ -145,12 +145,12 @@ static void analyzeBlock(Block &block, MemoryCoalescingAnalysis &analysis) {
uint64_t MemoryCoalescingAnalysis::getCandidateCount() const {
uint64_t total = 0;
for (const MemoryCoalescingBlockAnalysis &block : blocks)
for (const MemoryCoalescingBlockAnalysis& block : blocks)
total += block.candidates.size();
return total;
}
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation *coreLikeOp) {
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation* coreLikeOp) {
MemoryCoalescingAnalysis analysis;
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return analysis;
@@ -160,15 +160,15 @@ MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation *coreLikeOp
}
MemoryCoalescingStats
coalesceMemory(Operation *coreLikeOp, const MemoryCoalescingAnalysis &analysis, RewriterBase &rewriter) {
coalesceMemory(Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, RewriterBase& rewriter) {
(void) coreLikeOp;
MemoryCoalescingStats stats;
stats.skippedAllocations = analysis.skippedAllocations;
for (const MemoryCoalescingBlockAnalysis &blockAnalysis : analysis.blocks) {
for (const MemoryCoalescingBlockAnalysis& blockAnalysis : analysis.blocks) {
auto candidates = blockAnalysis.candidates;
llvm::sort(candidates, [](const AllocationCandidate &lhs, const AllocationCandidate &rhs) {
llvm::sort(candidates, [](const AllocationCandidate& lhs, const AllocationCandidate& rhs) {
if (lhs.startInstruction != rhs.startInstruction)
return lhs.startInstruction < rhs.startInstruction;
return lhs.endInstruction < rhs.endInstruction;
@@ -182,7 +182,7 @@ coalesceMemory(Operation *coreLikeOp, const MemoryCoalescingAnalysis &analysis,
SmallVector<ActiveStorage> active;
SmallVector<memref::AllocOp> freeList;
for (AllocationCandidate &candidate : candidates) {
for (AllocationCandidate& candidate : candidates) {
for (auto it = active.begin(); it != active.end();) {
if (it->endInstruction < candidate.startInstruction) {
freeList.push_back(it->root);
@@ -10,14 +10,14 @@ namespace pim {
struct AllocationCandidate {
mlir::memref::AllocOp alloc;
mlir::Block *scopeBlock = nullptr;
mlir::Block* scopeBlock = nullptr;
uint64_t startInstruction = 0;
uint64_t endInstruction = 0;
uint64_t sizeBytes = 0;
};
struct MemoryCoalescingBlockAnalysis {
mlir::Block *block = nullptr;
mlir::Block* block = nullptr;
llvm::SmallVector<AllocationCandidate> candidates;
uint64_t skippedAllocations = 0;
};
@@ -448,7 +448,10 @@ collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> l
auto result = dyn_cast<OpResult>(value);
if (!result)
return {};
keys.push_back({{compute.getOperation(), 0, 1}, result.getResultNumber()});
keys.push_back({
{compute.getOperation(), 0, 1},
result.getResultNumber()
});
return keys;
}
@@ -476,8 +479,8 @@ collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> l
return keys;
}
std::optional<ProducerKey>
getInputRequestProducerKey(Value value, std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
std::optional<ProducerKey> getInputRequestProducerKey(Value value,
std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
// Input resolution may request a whole-batch key for scalar consumers that read
// a complete resultful compute_batch value.
Operation* definingOp = value.getDefiningOp();
@@ -511,7 +514,10 @@ getInputRequestProducerKey(Value value, std::optional<ComputeInstance> logicalCo
auto result = dyn_cast<OpResult>(value);
if (!result)
return std::nullopt;
return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()};
return ProducerKey {
{compute.getOperation(), 0, 1},
result.getResultNumber()
};
}
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
@@ -675,7 +681,7 @@ LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState&
return state.func.emitError("materialized class logical slot source op mismatch");
if (isa<SpatComputeBatch>(referenceInstance.op) != isa<SpatComputeBatch>(currentInstance.op))
return state.func.emitError("materialized class logical slot batch/scalar mismatch");
(void)slot;
(void) slot;
}
}
}
@@ -1707,7 +1713,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
continue;
}
(void)logicalSlot;
(void) logicalSlot;
descriptor.offsetsByLane[targetLane].push_back(*offset);
}
@@ -3268,9 +3274,7 @@ FailureOr<Value> materializeIndexedBatchRunReceive(MaterializerState& state,
}
FailureOr<SmallVector<Value, 4>>
cloneInstanceBody(MaterializerState& state,
MaterializedClass& targetClass,
ArrayRef<ComputeInstance> peers) {
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
assert(!peers.empty() && "expected at least one peer instance");
const ComputeInstance& instance = peers.front();
Operation* sourceOp = instance.op;
@@ -3620,8 +3624,7 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart);
return cloneBatchBodyForLane(
state, targetClass, item, laneValue, group.resultIndices, {});
return cloneBatchBodyForLane(state, targetClass, item, laneValue, group.resultIndices, {});
}
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
@@ -3733,8 +3736,7 @@ FailureOr<MaterializationRun> collectBatchMaterializationRun(MaterializerState&
if (state.materializedLogicalSlots.contains(classSlot))
break;
FailureOr<SmallVector<ComputeInstance, 8>> peers =
getMaterializationRunSlotPeers(state, targetClass, slot);
FailureOr<SmallVector<ComputeInstance, 8>> peers = getMaterializationRunSlotPeers(state, targetClass, slot);
if (failed(peers) || peers->empty())
break;
@@ -3818,10 +3820,9 @@ bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state,
const OutputDestinationGroup& group) {
for (size_t resultIndex : group.resultIndices) {
for (const MaterializationRunSlot& slot : run) {
for (const ComputeInstance& peer : slot.peers) {
for (const ComputeInstance& peer : slot.peers)
if (hasSameClassConsumer(state, {peer, resultIndex}, classId))
return true;
}
}
}
@@ -4020,7 +4021,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
plan.messages.targetCoreIds.reserve(messageCount);
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
(void)slotIndex;
(void) slotIndex;
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
if (failed(checkedSourceCpu))
@@ -4236,7 +4237,8 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
return success();
if (isa<SpatComputeBatch>(instance.op)) {
FailureOr<MaterializationRun> run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
FailureOr<MaterializationRun> run =
collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
if (succeeded(run)) {
if (!targetClass.isBatch)
@@ -169,8 +169,8 @@ uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance&
uint32_t lhsEnd = lhs.laneStart + lhs.laneCount;
uint32_t rhsEnd = rhs.laneStart + rhs.laneCount;
return std::max(lhs.laneStart, rhs.laneStart) < std::min(lhsEnd, rhsEnd)
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
: 0;
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
: 0;
}
Cost scaleTransferCostByLaneCount(Cost totalCost, uint32_t totalLaneCount, uint32_t fragmentLaneCount) {
@@ -197,16 +197,13 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
return producers;
}
if (isa<SpatComputeBatch>(consumerInstance.op)) {
if (isa<SpatComputeBatch>(consumerInstance.op))
for (ComputeInstance instance :
getBatchChunksForRange(batch, consumerInstance.laneStart, consumerInstance.laneCount))
producers.push_back({instance, 0});
}
else {
for (ComputeInstance instance :
getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
else
for (ComputeInstance instance : getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
producers.push_back({instance, 0});
}
return producers;
}
@@ -217,17 +214,18 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
}
if (auto compute = dyn_cast<SpatCompute>(op)) {
producers.push_back({ComputeInstance {compute.getOperation(), 0, 1},
static_cast<size_t>(cast<OpResult>(value).getResultNumber())});
producers.push_back({
ComputeInstance {compute.getOperation(), 0, 1},
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
});
return producers;
}
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
if (batch.getNumResults() != 0) {
uint32_t laneStart = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneStart : 0;
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op)
? consumerInstance.laneCount
: static_cast<uint32_t>(batch.getLaneCount());
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneCount
: static_cast<uint32_t>(batch.getLaneCount());
for (ComputeInstance instance : getBatchChunksForRange(batch, laneStart, laneCount))
producers.push_back({instance, 0});
return producers;
@@ -242,7 +240,9 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
return producers;
}
Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstance, const ProducerValueRef& producerRef) {
Cost getProducerTransferCost(Value input,
const ComputeInstance& consumerInstance,
const ProducerValueRef& producerRef) {
Cost transferCost = getInputTransferCost(consumerInstance, input);
auto producerBatch = dyn_cast<SpatComputeBatch>(producerRef.instance.op);
if (!producerBatch || producerBatch.getNumResults() == 0)
@@ -256,9 +256,8 @@ Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstanc
}
}
return scaleTransferCostByLaneCount(transferCost,
static_cast<uint32_t>(producerBatch.getLaneCount()),
producerRef.instance.laneCount);
return scaleTransferCostByLaneCount(
transferCost, static_cast<uint32_t>(producerBatch.getLaneCount()), producerRef.instance.laneCount);
}
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
@@ -64,9 +64,8 @@ ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
return getBatchChunkForIndex(batch, getBatchChunkIndexForLane(batch.getLaneCount(), lane));
}
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
uint32_t laneStart,
uint32_t laneCount) {
llvm::SmallVector<ComputeInstance, 4>
getBatchChunksForRange(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount) {
llvm::SmallVector<ComputeInstance, 4> chunks;
if (laneCount == 0)
return chunks;
@@ -32,9 +32,8 @@ BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex);
size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
uint32_t laneStart,
uint32_t laneCount);
llvm::SmallVector<ComputeInstance, 4>
getBatchChunksForRange(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
const ComputeInstance* consumerInstance = nullptr);