automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-27 16:39:56 +02:00
parent 4bdaa57656
commit 874a2f53e6
23 changed files with 136 additions and 198 deletions
+56 -85
View File
@@ -199,22 +199,20 @@ MemoryReportRow PimMemory::getReportRow() const {
}
void PimMemory::remove(mlir::Value val) {
for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();) {
for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();)
if (it->first.value == val) {
auto eraseIt = it++;
ownedMemEntriesMap.erase(eraseIt);
}
else
++it;
}
for (auto it = globalMemEntriesMap.begin(); it != globalMemEntriesMap.end();) {
for (auto it = globalMemEntriesMap.begin(); it != globalMemEntriesMap.end();)
if (it->first.value == val) {
auto eraseIt = it++;
globalMemEntriesMap.erase(eraseIt);
}
else
++it;
}
}
MemEntry PimMemory::getMemEntry(const MemoryValueKey& key) const {
@@ -275,7 +273,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
return iter->second.address + resolvedAddress->byteOffset;
}
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) const {
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
const StaticValueKnowledge& knowledge) const {
value = resolveCachedAlias(value, knowledge);
auto compiledIt = compiledIndexExprs.find(value);
if (compiledIt == compiledIndexExprs.end()) {
@@ -826,7 +825,8 @@ class ScopedMapBindings {
llvm::SmallVector<std::pair<KeyTy, std::optional<ValueTy>>, 8> savedEntries;
public:
explicit ScopedMapBindings(MapTy& map) : map(map) {}
explicit ScopedMapBindings(MapTy& map)
: map(map) {}
void bind(const KeyTy& key, const ValueTy& value) {
auto it = map.find(key);
@@ -838,12 +838,11 @@ public:
}
~ScopedMapBindings() {
for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it) {
for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it)
if (it->second)
map[it->first] = *it->second;
else
map.erase(it->first);
}
}
};
@@ -929,9 +928,8 @@ static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
return failure();
}
static LogicalResult compileCoreEmissionPlan(Block& block,
Operation* weightOwner,
llvm::SmallVectorImpl<CompiledCoreNode>& plan) {
static LogicalResult
compileCoreEmissionPlan(Block& block, Operation* weightOwner, llvm::SmallVectorImpl<CompiledCoreNode>& plan) {
for (Operation& op : block) {
if (isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
@@ -982,15 +980,14 @@ static LogicalResult compileCoreEmissionPlan(Block& block,
return success();
}
static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
PimCodeGen& coreCodeGen,
StaticValueKnowledge& knowledge,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
const StaticValueKnowledge&)>
resolveWeightSlot,
size_t& processedOperations,
std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLaneCount = std::nullopt) {
static LogicalResult executeCompiledCorePlan(
const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
PimCodeGen& coreCodeGen,
StaticValueKnowledge& knowledge,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
size_t& processedOperations,
std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLaneCount = std::nullopt) {
for (const CompiledCoreNode& node : plan) {
if (node.kind == CompiledCoreNode::Kind::Loop) {
auto lowerBound = node.lowerBound.evaluate(knowledge);
@@ -1010,8 +1007,13 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
aliasBindings.bind(iterArg, iterValue);
if (failed(executeCompiledCorePlan(
*node.loopBody, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount)))
if (failed(executeCompiledCorePlan(*node.loopBody,
coreCodeGen,
knowledge,
resolveWeightSlot,
processedOperations,
batchLane,
batchLaneCount)))
return failure();
auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator());
@@ -1031,18 +1033,10 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
case CompiledCoreOpKind::Store:
coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Lmv:
coreCodeGen.codeGenLmvOp(cast<pim::PimMemCopyOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Receive:
coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Send:
coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Concat:
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Lmv: coreCodeGen.codeGenLmvOp(cast<pim::PimMemCopyOp>(node.op), knowledge); break;
case CompiledCoreOpKind::Receive: coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge); break;
case CompiledCoreOpKind::Send: coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge); break;
case CompiledCoreOpKind::Concat: coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge); break;
case CompiledCoreOpKind::Vmm:
if (auto weightSlot = resolveWeightSlot(cast<pim::PimVMMOp>(node.op), knowledge); succeeded(weightSlot))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, cast<pim::PimVMMOp>(node.op), true, knowledge);
@@ -1052,33 +1046,15 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
case CompiledCoreOpKind::Transpose:
coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VVAdd:
coreCodeGen.codeGenVVAddOp(cast<pim::PimVVAddOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VVSub:
coreCodeGen.codeGenVVSubOp(cast<pim::PimVVSubOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VVMul:
coreCodeGen.codeGenVVMulOp(cast<pim::PimVVMulOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VVMax:
coreCodeGen.codeGenVVMaxOp(cast<pim::PimVVMaxOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VVDMul:
coreCodeGen.codeGenVVDMulOp(cast<pim::PimVVDMulOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VAvg:
coreCodeGen.codeGenVAvgOp(cast<pim::PimVAvgOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VRelu:
coreCodeGen.codeGenVReluOp(cast<pim::PimVReluOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VTanh:
coreCodeGen.codeGenVTanhOp(cast<pim::PimVTanhOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VSigm:
coreCodeGen.codeGenVSigmOp(cast<pim::PimVSigmOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VVAdd: coreCodeGen.codeGenVVAddOp(cast<pim::PimVVAddOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VVSub: coreCodeGen.codeGenVVSubOp(cast<pim::PimVVSubOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VVMul: coreCodeGen.codeGenVVMulOp(cast<pim::PimVVMulOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VVMax: coreCodeGen.codeGenVVMaxOp(cast<pim::PimVVMaxOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VVDMul: coreCodeGen.codeGenVVDMulOp(cast<pim::PimVVDMulOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VAvg: coreCodeGen.codeGenVAvgOp(cast<pim::PimVAvgOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VRelu: coreCodeGen.codeGenVReluOp(cast<pim::PimVReluOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VTanh: coreCodeGen.codeGenVTanhOp(cast<pim::PimVTanhOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VSigm: coreCodeGen.codeGenVSigmOp(cast<pim::PimVSigmOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VSoftmax:
coreCodeGen.codeGenVSoftmaxOp(cast<pim::PimVSoftmaxOp>(node.op), knowledge);
break;
@@ -1131,23 +1107,22 @@ static void aliasMaterializedHostGlobals(CoreLikeOpTy coreLikeOp,
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
/// fully resolved before the JSON instructions are emitted.
/// Returns the number of emitted instructions, or -1 on failure.
static int64_t codeGenCoreOps(Block& block,
PimCodeGen& coreCodeGen,
const StaticValueKnowledge& initialKnowledge,
Operation* weightOwner,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
const StaticValueKnowledge&)>
resolveWeightSlot,
std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLaneCount = std::nullopt) {
static int64_t codeGenCoreOps(
Block& block,
PimCodeGen& coreCodeGen,
const StaticValueKnowledge& initialKnowledge,
Operation* weightOwner,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLaneCount = std::nullopt) {
llvm::SmallVector<CompiledCoreNode, 32> plan;
if (failed(compileCoreEmissionPlan(block, weightOwner, plan)))
return -1;
size_t processedOperations = 0;
StaticValueKnowledge knowledge = initialKnowledge;
auto result =
executeCompiledCorePlan(plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
auto result = executeCompiledCorePlan(
plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
@@ -1219,9 +1194,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
SmallVector<size_t> jobIndices;
SmallVector<size_t> orderedOriginalCoreIds = llvm::to_vector(lanesByCoreId.keys());
llvm::sort(orderedOriginalCoreIds, [&](size_t lhs, size_t rhs) {
return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs);
});
llvm::sort(orderedOriginalCoreIds,
[&](size_t lhs, size_t rhs) { return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs); });
for (size_t originalCoreId : orderedOriginalCoreIds) {
CoreEmissionJob job;
job.coreLikeOp = coreBatchOp;
@@ -1236,9 +1210,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
++nextBatchReportId;
}
auto linkCoreWeights = [&](size_t coreId,
ArrayRef<std::string> weightFiles,
json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
auto linkCoreWeights =
[&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
@@ -1250,8 +1223,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")
<< "\nError:" << error.message() << '\n';
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << "\nError:" << error.message()
<< '\n';
return InvalidOutputFileAccess;
}
}
@@ -1294,8 +1267,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
errorCode = std::error_code();
coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
if (errorCode) {
errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message()
<< '\n';
errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message() << '\n';
result.status = InvalidOutputFileAccess;
return result;
}
@@ -1364,9 +1336,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
};
std::vector<CoreEmissionResult> jobResults(jobs.size());
mlir::parallelFor(moduleOp.getContext(), 0, jobs.size(), [&](size_t index) {
jobResults[index] = emitJob(jobs[index]);
});
mlir::parallelFor(
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
if (jobResults[jobIndex].status != CompilerSuccess)