This commit is contained in:
@@ -19,7 +19,7 @@ namespace pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isSupportedAliasOp(Operation *op) {
|
||||
static bool isSupportedAliasOp(Operation* op) {
|
||||
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
|
||||
}
|
||||
|
||||
@@ -32,20 +32,20 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
|
||||
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
|
||||
}
|
||||
|
||||
static Operation *getTopLevelAncestorInBlock(Operation *op, Block &block) {
|
||||
Operation *current = op;
|
||||
static Operation* getTopLevelAncestorInBlock(Operation* op, Block& block) {
|
||||
Operation* current = op;
|
||||
while (current && current->getBlock() != &block)
|
||||
current = current->getParentOp();
|
||||
return current;
|
||||
}
|
||||
|
||||
static void analyzeBlock(Block &block, MemoryCoalescingAnalysis &analysis);
|
||||
static void analyzeBlock(Block& block, MemoryCoalescingAnalysis& analysis);
|
||||
|
||||
static FailureOr<uint64_t>
|
||||
getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap<Operation *, uint64_t> &opOrder) {
|
||||
getLastUseInstruction(memref::AllocOp allocOp, Block& scopeBlock, const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||
uint64_t endInstruction = opOrder.lookup(allocOp);
|
||||
SmallPtrSet<Value, 16> visitedValues;
|
||||
SmallPtrSet<Operation *, 16> visitedUsers;
|
||||
SmallPtrSet<Operation*, 16> visitedUsers;
|
||||
SmallVector<Value> pendingValues;
|
||||
pendingValues.push_back(allocOp.getResult());
|
||||
|
||||
@@ -54,7 +54,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
|
||||
if (!visitedValues.insert(value).second)
|
||||
continue;
|
||||
|
||||
for (Operation *user : value.getUsers()) {
|
||||
for (Operation* user : value.getUsers()) {
|
||||
if (!visitedUsers.insert(user).second)
|
||||
continue;
|
||||
|
||||
@@ -63,7 +63,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
|
||||
|
||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||
for (OpResult result : user->getResults()) {
|
||||
OpOperand *tiedOperand = dpsOp.getTiedOpOperand(result);
|
||||
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
|
||||
if (tiedOperand && tiedOperand->get() == value)
|
||||
pendingValues.push_back(result);
|
||||
}
|
||||
@@ -87,7 +87,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
}
|
||||
|
||||
Operation *orderedUser = getTopLevelAncestorInBlock(user, scopeBlock);
|
||||
Operation* orderedUser = getTopLevelAncestorInBlock(user, scopeBlock);
|
||||
if (!orderedUser)
|
||||
return failure();
|
||||
|
||||
@@ -101,21 +101,21 @@ getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap
|
||||
return endInstruction;
|
||||
}
|
||||
|
||||
static void analyzeBlock(Block &block, MemoryCoalescingAnalysis &analysis) {
|
||||
for (Operation &op : block)
|
||||
for (Region &region : op.getRegions())
|
||||
for (Block &nestedBlock : region)
|
||||
static void analyzeBlock(Block& block, MemoryCoalescingAnalysis& analysis) {
|
||||
for (Operation& op : block)
|
||||
for (Region& region : op.getRegions())
|
||||
for (Block& nestedBlock : region)
|
||||
analyzeBlock(nestedBlock, analysis);
|
||||
|
||||
DenseMap<Operation *, uint64_t> opOrder;
|
||||
DenseMap<Operation*, uint64_t> opOrder;
|
||||
uint64_t nextInstruction = 0;
|
||||
for (Operation &op : block)
|
||||
for (Operation& op : block)
|
||||
opOrder.try_emplace(&op, nextInstruction++);
|
||||
|
||||
MemoryCoalescingBlockAnalysis blockAnalysis;
|
||||
blockAnalysis.block = &block;
|
||||
|
||||
for (Operation &op : block) {
|
||||
for (Operation& op : block) {
|
||||
auto allocOp = dyn_cast<memref::AllocOp>(&op);
|
||||
if (!allocOp)
|
||||
continue;
|
||||
@@ -145,12 +145,12 @@ static void analyzeBlock(Block &block, MemoryCoalescingAnalysis &analysis) {
|
||||
|
||||
uint64_t MemoryCoalescingAnalysis::getCandidateCount() const {
|
||||
uint64_t total = 0;
|
||||
for (const MemoryCoalescingBlockAnalysis &block : blocks)
|
||||
for (const MemoryCoalescingBlockAnalysis& block : blocks)
|
||||
total += block.candidates.size();
|
||||
return total;
|
||||
}
|
||||
|
||||
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation *coreLikeOp) {
|
||||
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation* coreLikeOp) {
|
||||
MemoryCoalescingAnalysis analysis;
|
||||
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
|
||||
return analysis;
|
||||
@@ -160,15 +160,15 @@ MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation *coreLikeOp
|
||||
}
|
||||
|
||||
MemoryCoalescingStats
|
||||
coalesceMemory(Operation *coreLikeOp, const MemoryCoalescingAnalysis &analysis, RewriterBase &rewriter) {
|
||||
coalesceMemory(Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, RewriterBase& rewriter) {
|
||||
(void) coreLikeOp;
|
||||
|
||||
MemoryCoalescingStats stats;
|
||||
stats.skippedAllocations = analysis.skippedAllocations;
|
||||
|
||||
for (const MemoryCoalescingBlockAnalysis &blockAnalysis : analysis.blocks) {
|
||||
for (const MemoryCoalescingBlockAnalysis& blockAnalysis : analysis.blocks) {
|
||||
auto candidates = blockAnalysis.candidates;
|
||||
llvm::sort(candidates, [](const AllocationCandidate &lhs, const AllocationCandidate &rhs) {
|
||||
llvm::sort(candidates, [](const AllocationCandidate& lhs, const AllocationCandidate& rhs) {
|
||||
if (lhs.startInstruction != rhs.startInstruction)
|
||||
return lhs.startInstruction < rhs.startInstruction;
|
||||
return lhs.endInstruction < rhs.endInstruction;
|
||||
@@ -182,7 +182,7 @@ coalesceMemory(Operation *coreLikeOp, const MemoryCoalescingAnalysis &analysis,
|
||||
SmallVector<ActiveStorage> active;
|
||||
SmallVector<memref::AllocOp> freeList;
|
||||
|
||||
for (AllocationCandidate &candidate : candidates) {
|
||||
for (AllocationCandidate& candidate : candidates) {
|
||||
for (auto it = active.begin(); it != active.end();) {
|
||||
if (it->endInstruction < candidate.startInstruction) {
|
||||
freeList.push_back(it->root);
|
||||
|
||||
@@ -10,14 +10,14 @@ namespace pim {
|
||||
|
||||
struct AllocationCandidate {
|
||||
mlir::memref::AllocOp alloc;
|
||||
mlir::Block *scopeBlock = nullptr;
|
||||
mlir::Block* scopeBlock = nullptr;
|
||||
uint64_t startInstruction = 0;
|
||||
uint64_t endInstruction = 0;
|
||||
uint64_t sizeBytes = 0;
|
||||
};
|
||||
|
||||
struct MemoryCoalescingBlockAnalysis {
|
||||
mlir::Block *block = nullptr;
|
||||
mlir::Block* block = nullptr;
|
||||
llvm::SmallVector<AllocationCandidate> candidates;
|
||||
uint64_t skippedAllocations = 0;
|
||||
};
|
||||
|
||||
@@ -448,7 +448,10 @@ collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> l
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result)
|
||||
return {};
|
||||
keys.push_back({{compute.getOperation(), 0, 1}, result.getResultNumber()});
|
||||
keys.push_back({
|
||||
{compute.getOperation(), 0, 1},
|
||||
result.getResultNumber()
|
||||
});
|
||||
return keys;
|
||||
}
|
||||
|
||||
@@ -476,8 +479,8 @@ collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> l
|
||||
return keys;
|
||||
}
|
||||
|
||||
std::optional<ProducerKey>
|
||||
getInputRequestProducerKey(Value value, std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
|
||||
std::optional<ProducerKey> getInputRequestProducerKey(Value value,
|
||||
std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
|
||||
// Input resolution may request a whole-batch key for scalar consumers that read
|
||||
// a complete resultful compute_batch value.
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
@@ -511,7 +514,10 @@ getInputRequestProducerKey(Value value, std::optional<ComputeInstance> logicalCo
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result)
|
||||
return std::nullopt;
|
||||
return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()};
|
||||
return ProducerKey {
|
||||
{compute.getOperation(), 0, 1},
|
||||
result.getResultNumber()
|
||||
};
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
|
||||
@@ -675,7 +681,7 @@ LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState&
|
||||
return state.func.emitError("materialized class logical slot source op mismatch");
|
||||
if (isa<SpatComputeBatch>(referenceInstance.op) != isa<SpatComputeBatch>(currentInstance.op))
|
||||
return state.func.emitError("materialized class logical slot batch/scalar mismatch");
|
||||
(void)slot;
|
||||
(void) slot;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1707,7 +1713,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
continue;
|
||||
}
|
||||
|
||||
(void)logicalSlot;
|
||||
(void) logicalSlot;
|
||||
descriptor.offsetsByLane[targetLane].push_back(*offset);
|
||||
}
|
||||
|
||||
@@ -3268,9 +3274,7 @@ FailureOr<Value> materializeIndexedBatchRunReceive(MaterializerState& state,
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value, 4>>
|
||||
cloneInstanceBody(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
ArrayRef<ComputeInstance> peers) {
|
||||
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
|
||||
assert(!peers.empty() && "expected at least one peer instance");
|
||||
const ComputeInstance& instance = peers.front();
|
||||
Operation* sourceOp = instance.op;
|
||||
@@ -3620,8 +3624,7 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
|
||||
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart);
|
||||
return cloneBatchBodyForLane(
|
||||
state, targetClass, item, laneValue, group.resultIndices, {});
|
||||
return cloneBatchBodyForLane(state, targetClass, item, laneValue, group.resultIndices, {});
|
||||
}
|
||||
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
@@ -3733,8 +3736,7 @@ FailureOr<MaterializationRun> collectBatchMaterializationRun(MaterializerState&
|
||||
if (state.materializedLogicalSlots.contains(classSlot))
|
||||
break;
|
||||
|
||||
FailureOr<SmallVector<ComputeInstance, 8>> peers =
|
||||
getMaterializationRunSlotPeers(state, targetClass, slot);
|
||||
FailureOr<SmallVector<ComputeInstance, 8>> peers = getMaterializationRunSlotPeers(state, targetClass, slot);
|
||||
if (failed(peers) || peers->empty())
|
||||
break;
|
||||
|
||||
@@ -3818,10 +3820,9 @@ bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state,
|
||||
const OutputDestinationGroup& group) {
|
||||
for (size_t resultIndex : group.resultIndices) {
|
||||
for (const MaterializationRunSlot& slot : run) {
|
||||
for (const ComputeInstance& peer : slot.peers) {
|
||||
for (const ComputeInstance& peer : slot.peers)
|
||||
if (hasSameClassConsumer(state, {peer, resultIndex}, classId))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4020,7 +4021,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
|
||||
plan.messages.targetCoreIds.reserve(messageCount);
|
||||
|
||||
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
|
||||
(void)slotIndex;
|
||||
(void) slotIndex;
|
||||
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
|
||||
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
|
||||
if (failed(checkedSourceCpu))
|
||||
@@ -4236,7 +4237,8 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
|
||||
return success();
|
||||
|
||||
if (isa<SpatComputeBatch>(instance.op)) {
|
||||
FailureOr<MaterializationRun> run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
|
||||
FailureOr<MaterializationRun> run =
|
||||
collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
|
||||
|
||||
if (succeeded(run)) {
|
||||
if (!targetClass.isBatch)
|
||||
|
||||
@@ -169,8 +169,8 @@ uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance&
|
||||
uint32_t lhsEnd = lhs.laneStart + lhs.laneCount;
|
||||
uint32_t rhsEnd = rhs.laneStart + rhs.laneCount;
|
||||
return std::max(lhs.laneStart, rhs.laneStart) < std::min(lhsEnd, rhsEnd)
|
||||
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
|
||||
: 0;
|
||||
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
|
||||
: 0;
|
||||
}
|
||||
|
||||
Cost scaleTransferCostByLaneCount(Cost totalCost, uint32_t totalLaneCount, uint32_t fragmentLaneCount) {
|
||||
@@ -197,16 +197,13 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
|
||||
return producers;
|
||||
}
|
||||
|
||||
if (isa<SpatComputeBatch>(consumerInstance.op)) {
|
||||
if (isa<SpatComputeBatch>(consumerInstance.op))
|
||||
for (ComputeInstance instance :
|
||||
getBatchChunksForRange(batch, consumerInstance.laneStart, consumerInstance.laneCount))
|
||||
producers.push_back({instance, 0});
|
||||
}
|
||||
else {
|
||||
for (ComputeInstance instance :
|
||||
getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
|
||||
else
|
||||
for (ComputeInstance instance : getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
|
||||
producers.push_back({instance, 0});
|
||||
}
|
||||
return producers;
|
||||
}
|
||||
|
||||
@@ -217,17 +214,18 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
|
||||
}
|
||||
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
producers.push_back({ComputeInstance {compute.getOperation(), 0, 1},
|
||||
static_cast<size_t>(cast<OpResult>(value).getResultNumber())});
|
||||
producers.push_back({
|
||||
ComputeInstance {compute.getOperation(), 0, 1},
|
||||
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
|
||||
});
|
||||
return producers;
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||
if (batch.getNumResults() != 0) {
|
||||
uint32_t laneStart = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneStart : 0;
|
||||
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op)
|
||||
? consumerInstance.laneCount
|
||||
: static_cast<uint32_t>(batch.getLaneCount());
|
||||
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneCount
|
||||
: static_cast<uint32_t>(batch.getLaneCount());
|
||||
for (ComputeInstance instance : getBatchChunksForRange(batch, laneStart, laneCount))
|
||||
producers.push_back({instance, 0});
|
||||
return producers;
|
||||
@@ -242,7 +240,9 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
|
||||
return producers;
|
||||
}
|
||||
|
||||
Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstance, const ProducerValueRef& producerRef) {
|
||||
Cost getProducerTransferCost(Value input,
|
||||
const ComputeInstance& consumerInstance,
|
||||
const ProducerValueRef& producerRef) {
|
||||
Cost transferCost = getInputTransferCost(consumerInstance, input);
|
||||
auto producerBatch = dyn_cast<SpatComputeBatch>(producerRef.instance.op);
|
||||
if (!producerBatch || producerBatch.getNumResults() == 0)
|
||||
@@ -256,9 +256,8 @@ Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstanc
|
||||
}
|
||||
}
|
||||
|
||||
return scaleTransferCostByLaneCount(transferCost,
|
||||
static_cast<uint32_t>(producerBatch.getLaneCount()),
|
||||
producerRef.instance.laneCount);
|
||||
return scaleTransferCostByLaneCount(
|
||||
transferCost, static_cast<uint32_t>(producerBatch.getLaneCount()), producerRef.instance.laneCount);
|
||||
}
|
||||
|
||||
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
|
||||
|
||||
+2
-3
@@ -64,9 +64,8 @@ ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||
return getBatchChunkForIndex(batch, getBatchChunkIndexForLane(batch.getLaneCount(), lane));
|
||||
}
|
||||
|
||||
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
|
||||
uint32_t laneStart,
|
||||
uint32_t laneCount) {
|
||||
llvm::SmallVector<ComputeInstance, 4>
|
||||
getBatchChunksForRange(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount) {
|
||||
llvm::SmallVector<ComputeInstance, 4> chunks;
|
||||
if (laneCount == 0)
|
||||
return chunks;
|
||||
|
||||
+2
-3
@@ -32,9 +32,8 @@ BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex);
|
||||
size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane);
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
|
||||
uint32_t laneStart,
|
||||
uint32_t laneCount);
|
||||
llvm::SmallVector<ComputeInstance, 4>
|
||||
getBatchChunksForRange(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount);
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
|
||||
const ComputeInstance* consumerInstance = nullptr);
|
||||
|
||||
Reference in New Issue
Block a user