remove useless MaterializeHostConstantsPass.cpp and fix lowering before instead
Validate Operations / validate-operations (push) Has been cancelled

avoid spammy pim codegen diagnostics
This commit is contained in:
NiccoloN
2026-06-05 10:06:28 +02:00
parent 27410207c4
commit 1e9e61f5a9
20 changed files with 458 additions and 256 deletions
-1
View File
@@ -31,7 +31,6 @@ add_pim_library(OMPimCompilerUtils
OMPimBufferization
OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimHostConstantMaterialization
OMPimVerification
OMPimPasses
OMONNXToSpatial
+66 -9
View File
@@ -32,6 +32,7 @@
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/Diagnostics.hpp"
#include "Common/Support/CheckedArithmetic.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -996,12 +997,44 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
}
struct CoreEmissionResult {
static constexpr size_t kMaxStoredCodegenDiagnostics = 8;
struct DiagnosticRecord {
Operation* op = nullptr;
std::string message;
};
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
MemoryReportRow reportRow;
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
MemoryPlanArtifacts livenessArtifacts;
llvm::SmallVector<DiagnosticRecord, kMaxStoredCodegenDiagnostics> diagnostics;
size_t diagnosticCount = 0;
void recordDiagnostic(Operation* op, StringRef message) {
++diagnosticCount;
if (diagnostics.size() < kMaxStoredCodegenDiagnostics)
diagnostics.push_back({op, message.str()});
}
};
/// Builds the initial static knowledge for a core op: each weight argument of
/// `coreOp` is recorded as an alias of the weight operand at the same index.
static StaticValueKnowledge seedCoreCodegenKnowledge(pim::PimCoreOp coreOp) {
  StaticValueKnowledge seeded;
  for (auto indexedWeight : llvm::enumerate(coreOp.getWeights())) {
    auto weightArgument = coreOp.getWeightArgument(indexedWeight.index());
    seeded.aliases[weightArgument] = indexedWeight.value();
  }
  return seeded;
}
/// Builds the initial static knowledge for one lane of a core-batch op.
/// The lane argument is bound to `lane`, then every weight argument and every
/// input argument is recorded as an alias of the operand at the same index.
static StaticValueKnowledge seedCoreBatchCodegenKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
  StaticValueKnowledge seeded;
  seeded.indexValues[coreBatchOp.getLaneArgument()] = lane;

  // Weights are aliased before inputs, matching the order of the operand groups.
  for (auto indexedWeight : llvm::enumerate(coreBatchOp.getWeights())) {
    auto weightArgument = coreBatchOp.getWeightArgument(indexedWeight.index());
    seeded.aliases[weightArgument] = indexedWeight.value();
  }
  for (auto indexedInput : llvm::enumerate(coreBatchOp.getInputs())) {
    auto inputArgument = coreBatchOp.getInputArgument(indexedInput.index());
    seeded.aliases[inputArgument] = indexedInput.value();
  }
  return seeded;
}
template <typename MapTy>
class ScopedMapBindings {
using KeyTy = typename MapTy::key_type;
@@ -1422,7 +1455,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
if (failed(weightView)) {
vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen");
std::string message;
llvm::raw_string_ostream os(message);
os << "requires a statically resolvable dense global weight view during PIM codegen; weight="
<< vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
result.recordDiagnostic(vmmOp, os.str());
return failure();
}
if (weightView->shape.size() != 2) {
std::string message;
llvm::raw_string_ostream os(message);
os << "requires a rank-2 matrix weight view during PIM codegen; resolved shape=[";
llvm::interleaveComma(weightView->shape, os);
os << "] weight=" << vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
result.recordDiagnostic(vmmOp, os.str());
return failure();
}
@@ -1463,13 +1509,13 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
deviceMemory.allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(
coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot);
StaticValueKnowledge knowledge = seedCoreCodegenKnowledge(coreOp);
int64_t processedOperations =
codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, knowledge, coreOp.getOperation(), resolveWeightSlot);
if (processedOperations < 0) {
result.status = CompilerFailure;
return result;
}
assert(processedOperations > 0);
result.reportRow = deviceMemory.getReportRow();
result.usedWeights = std::move(usedWeights);
result.livenessArtifacts = deviceMemory.getLivenessArtifacts();
@@ -1480,10 +1526,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
for (unsigned lane : job.lanes) {
StaticValueKnowledge knowledge;
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
StaticValueKnowledge knowledge = seedCoreBatchCodegenKnowledge(coreBatchOp, lane);
deviceMemory.allocateCore(coreBatchOp, lane);
coreCodeGen.setBatchLane(lane);
@@ -1498,7 +1541,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
result.status = CompilerFailure;
return result;
}
assert(processedOperations > 0);
}
result.reportRow = deviceMemory.getReportRow();
@@ -1522,6 +1564,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
mlir::parallelFor(
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
pim::CappedDiagnosticReporter diagnostics;
Operation* summaryAnchor = nullptr;
for (const CoreEmissionResult& result : jobResults) {
if (!summaryAnchor && !result.diagnostics.empty())
summaryAnchor = result.diagnostics.front().op;
for (const CoreEmissionResult::DiagnosticRecord& diagnostic : result.diagnostics) {
diagnostics.report(diagnostic.op, [&](Operation* op) { op->emitError() << diagnostic.message; });
}
size_t unreportedCount = result.diagnosticCount - result.diagnostics.size();
diagnostics.noteFailures(static_cast<int64_t>(unreportedCount));
}
if (diagnostics.hasFailure())
diagnostics.emitSuppressedSummary(summaryAnchor ? summaryAnchor : moduleOp.getOperation(),
"PIM codegen diagnostic(s)");
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
if (jobResults[jobIndex].status != CompilerSuccess)
return jobResults[jobIndex].status;
-1
View File
@@ -46,7 +46,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimMemoryCoalescingPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));