remove dead logic

This commit is contained in:
NiccoloN
2026-05-19 12:23:01 +02:00
parent e263e05f56
commit a103ba328b
4 changed files with 21 additions and 81 deletions
@@ -9,7 +9,6 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <functional>
#include <optional> #include <optional>
#include <utility> #include <utility>
@@ -54,7 +53,6 @@ public:
replaceExternalUses(); replaceExternalUses();
if (failed(eraseOldScheduledOps())) if (failed(eraseOldScheduledOps()))
return failure(); return failure();
moveExternalUsersBeforeReturn();
return success(); return success();
} }
@@ -97,18 +95,6 @@ private:
| static_cast<uint32_t>(channelInfo.targetCoreId); | static_cast<uint32_t>(channelInfo.targetCoreId);
} }
void collectExternalUsers(Operation* op) {
if (!externalUsersToMove.insert(op).second)
return;
for (Value result : op->getResults()) {
for (Operation* user : result.getUsers()) {
if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
continue;
collectExternalUsers(user);
}
}
}
void collectScheduledTasks() { void collectScheduledTasks() {
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) { for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
oldComputeOps.insert(scheduledInstance.op); oldComputeOps.insert(scheduledInstance.op);
@@ -151,25 +137,22 @@ private:
auto& remoteInputs = remoteInputsByTask[task.computeInstance]; auto& remoteInputs = remoteInputsByTask[task.computeInstance];
remoteInputs.resize(taskInputs.size()); remoteInputs.resize(taskInputs.size());
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
auto producerRef = getProducerValueRef(input); if (auto producerRef = getProducerValueRef(input)) {
if (producerRef) {
auto producerIt = taskByComputeInstance.find(producerRef->instance); auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) { if (producerIt->second.cpu != cpu) {
if (producerIt->second.cpu != cpu) { ChannelInfo info {
ChannelInfo info { (*nextChannelId)++,
(*nextChannelId)++, static_cast<int32_t>(producerIt->second.cpu),
static_cast<int32_t>(producerIt->second.cpu), static_cast<int32_t>(cpu),
static_cast<int32_t>(cpu), };
}; remoteInputs[inputIndex] = info;
remoteInputs[inputIndex] = info; auto& perResultChannels = remoteSendsByTask[producerRef->instance];
auto& perResultChannels = remoteSendsByTask[producerRef->instance]; if (perResultChannels.empty())
if (perResultChannels.empty()) perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size()); perResultChannels[producerRef->resultIndex].push_back(
perResultChannels[producerRef->resultIndex].push_back( {info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
}
continue;
} }
continue;
} }
if (seenExternalInputsByCpu[cpu].insert(input).second) if (seenExternalInputsByCpu[cpu].insert(input).second)
cpuExternalInputs[cpu].push_back(input); cpuExternalInputs[cpu].push_back(input);
@@ -183,8 +166,6 @@ private:
if (oldComputeOps.contains(useOwner)) if (oldComputeOps.contains(useOwner))
continue; continue;
hasExternalUser = true; hasExternalUser = true;
if (!isa<func::ReturnOp>(useOwner))
collectExternalUsers(useOwner);
} }
if (hasExternalUser) if (hasExternalUser)
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex}); cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
@@ -407,7 +388,8 @@ private:
if (producerIt->second.cpu == cpu) { if (producerIt->second.cpu == cpu) {
auto producedIt = producedValuesByTask.find(producerRef->instance); auto producedIt = producedValuesByTask.find(producerRef->instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) { if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
task.computeInstance.op->emitOpError("missing local producer value during per-cpu merge materialization") task.computeInstance.op->emitOpError(
"missing local producer value during per-cpu merge materialization")
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu << " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
<< " producerLaneStart=" << producerRef->instance.laneStart << " producerLaneStart=" << producerRef->instance.laneStart
<< " producerLaneCount=" << producerRef->instance.laneCount; << " producerLaneCount=" << producerRef->instance.laneCount;
@@ -586,18 +568,6 @@ private:
return success(); return success();
} }
void moveExternalUsersBeforeReturn() {
SmallVector<Operation*> orderedUsersToMove;
for (Operation& op : func.getBody().front()) {
if (&op == returnOp.getOperation())
break;
if (externalUsersToMove.contains(&op))
orderedUsersToMove.push_back(&op);
}
for (Operation* op : orderedUsersToMove)
op->moveBefore(returnOp);
}
func::FuncOp func; func::FuncOp func;
const MergeScheduleResult* schedule = nullptr; const MergeScheduleResult* schedule = nullptr;
int64_t* nextChannelId = nullptr; int64_t* nextChannelId = nullptr;
@@ -610,7 +580,6 @@ private:
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu; DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
SmallVector<size_t> orderedCpus; SmallVector<size_t> orderedCpus;
DenseSet<size_t> seenCpus; DenseSet<size_t> seenCpus;
DenseSet<Operation*> externalUsersToMove;
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask; DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask; DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs; DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
@@ -13,7 +13,6 @@
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@@ -28,9 +27,7 @@
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <fstream> #include <fstream>
#include <functional>
#include <iterator> #include <iterator>
#include <limits>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <tuple> #include <tuple>
@@ -39,13 +36,11 @@
#include "MaterializeMergeSchedule.hpp" #include "MaterializeMergeSchedule.hpp"
#include "PostMergeCompaction.hpp" #include "PostMergeCompaction.hpp"
#include "RegularOpCompaction.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/ComputeInstanceUtils.hpp"
#include "Scheduling/MergeSchedulingAnalysis.hpp" #include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -53,10 +48,8 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
using namespace onnx_mlir::compact_asm; using namespace onnx_mlir::compact_asm;
using ProducerValueRef = spatial::ProducerValueRef;
using SpatCompute = spatial::SpatCompute; using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch; using SpatComputeBatch = spatial::SpatComputeBatch;
using spatial::getOriginalSpatCompute;
using spatial::getProducerValueRef; using spatial::getProducerValueRef;
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; } bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
@@ -303,7 +296,7 @@ void emitMotifProfile(func::FuncOp funcOp) {
} }
for (Value input : compute.getInputs()) { for (Value input : compute.getInputs()) {
auto parent = getOriginalSpatCompute(input.getDefiningOp()); auto parent = dyn_cast<SpatCompute>(input.getDefiningOp());
if (!parent || parent == compute) if (!parent || parent == compute)
continue; continue;
auto parentIt = computeToIndex.find(parent); auto parentIt = computeToIndex.find(parent);
@@ -22,7 +22,7 @@ size_t getBatchChunkTargetCount(int32_t laneCount) {
} }
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount()); size_t totalLanes = batch.getLaneCount();
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount; size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount; size_t largeChunkCount = totalLanes % chunkCount;
@@ -33,7 +33,7 @@ ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex)
} }
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) { ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount()); size_t totalLanes = batch.getLaneCount();
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount; size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount; size_t largeChunkCount = totalLanes % chunkCount;
@@ -47,32 +47,11 @@ ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
return getBatchChunkForIndex(batch, chunkIndex); return getBatchChunkForIndex(batch, chunkIndex);
} }
SpatCompute getOriginalSpatCompute(Operation *op) {
if (!op)
return {};
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
return dyn_cast<SpatCompute>(op);
}
std::optional<ProducerValueRef> getProducerValueRef(Value value) { std::optional<ProducerValueRef> getProducerValueRef(Value value) {
Operation *op = value.getDefiningOp(); Operation *op = value.getDefiningOp();
if (!op) if (!op)
return std::nullopt; return std::nullopt;
//TODO Extract Slice is not the only global non compute operation. There are other legal op
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
value = extract.getSource();
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto compute = dyn_cast<SpatCompute>(op)) { if (auto compute = dyn_cast<SpatCompute>(op)) {
return ProducerValueRef { return ProducerValueRef {
ComputeInstance {compute.getOperation(), 0, 1}, ComputeInstance {compute.getOperation(), 0, 1},
@@ -81,9 +60,9 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
} }
if (auto batch = dyn_cast<SpatComputeBatch>(op)) { if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
uint32_t lane = static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()); uint32_t lane = cast<OpResult>(value).getResultNumber();
ComputeInstance instance = getBatchChunkForLane(batch, lane); ComputeInstance instance = getBatchChunkForLane(batch, lane);
size_t resultIndex = static_cast<size_t>(lane - instance.laneStart); size_t resultIndex = lane - instance.laneStart;
return ProducerValueRef {instance, resultIndex}; return ProducerValueRef {instance, resultIndex};
} }
@@ -26,7 +26,6 @@ size_t getBatchChunkTargetCount(int32_t laneCount);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex); ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane); ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
SpatCompute getOriginalSpatCompute(mlir::Operation *op);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value); std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value); std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);