fix dcp merge bug
Some checks failed
Validate Operations / validate-operations (push) Failing after 15m54s

This commit is contained in:
NiccoloN
2026-05-04 15:58:14 +02:00
parent 5b9bb0c191
commit bdacb9871d
3 changed files with 18 additions and 7 deletions

View File

@@ -5,6 +5,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -412,6 +413,10 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("compute_batch core_id array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
return emitError("compute_batch core_id values must be positive");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch core_id values must be distinct");
}
Block& block = getBody().front();

View File

@@ -56,12 +56,10 @@ struct WindowScheduleResult {
size_t maxMergeGroupSize = 0;
};
constexpr CPU kDefaultMaxCpuCount = 1000;
/// Returns the maximum number of CPUs the scheduler is allowed to use.
///
/// When `coresCount` holds a positive value, that value is the budget.
/// Otherwise no cap is applied and the largest representable `size_t` is
/// returned, so callers can treat the budget as "unlimited".
///
/// @return The CPU budget; always greater than zero.
size_t getSchedulingCpuBudget() {
  if (coresCount.getValue() > 0)
    return static_cast<size_t>(coresCount.getValue());
  // No explicit core count requested: do not fall back to an arbitrary fixed
  // cap, leave the budget unbounded.
  return std::numeric_limits<size_t>::max();
}
size_t getBatchChunkTargetCount(int32_t laneCount) {

View File

@@ -54,10 +54,9 @@ struct ProducerValueRef {
std::optional<ProducerValueRef> getProducerValueRef(Value value);
/// Returns the maximum number of CPUs the fast path is allowed to use.
///
/// When `coresCount` holds a positive value, that value is the budget.
/// Otherwise no cap is applied and the largest representable `size_t` is
/// returned, so callers can treat the budget as "unlimited".
///
/// @return The CPU budget; always greater than zero.
static size_t getFastPathCpuBudget() {
  if (coresCount.getValue() > 0)
    return static_cast<size_t>(coresCount.getValue());
  // No explicit core count requested: do not fall back to an arbitrary fixed
  // cap, leave the budget unbounded.
  return std::numeric_limits<size_t>::max();
}
static size_t getBatchChunkTargetCount(int32_t laneCount) {
@@ -670,6 +669,9 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
continue;
SmallVector<SpatCompute> group {anchor};
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
if (auto coreId = getComputeCoreId(anchor))
usedCoreIds.insert(*coreId);
if (!anchor.getResults().empty())
continue;
for (size_t candidateIndex = index + 1; candidateIndex < computes.size(); ++candidateIndex) {
@@ -680,8 +682,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
continue;
if (!candidate.getResults().empty())
continue;
if (areEquivalentForRebatch(anchor, candidate))
group.push_back(candidate);
if (!areEquivalentForRebatch(anchor, candidate))
continue;
if (auto coreId = getComputeCoreId(candidate))
if (!usedCoreIds.insert(*coreId).second)
continue;
group.push_back(candidate);
}
if (group.size() <= 1)