Files
Raptor/src/PIM/Compiler/PimCompilerOptions.cpp
T
2026-06-24 15:52:07 +02:00

142 lines
7.2 KiB
C++

#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir {
// Selects which PIM-related artifact the compiler emits.
// The option has no name of its own: every alternative is declared with
// clEnumVal, which uses the enumerator's identifier as its spelling.
// Initial value: EmitPimCodegen.
llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
    llvm::cl::desc("[Optional] Choose PIM-related target to emit (once selected it will cancel the other targets):"),
    llvm::cl::values(
        clEnumVal(EmitSpatial, "Lower model to spatial IR"),
        clEnumVal(EmitPim, "Lower model to PIM IR"),
        clEnumVal(EmitPimBufferized, "Lower model to PIM IR and bufferize it"),
        clEnumVal(EmitPimCodegen, "Lower model to PIM IR and generate code for PIM")),
    llvm::cl::init(EmitPimCodegen),
    llvm::cl::cat(OnnxMlirOptions));
// --pim-merge-scheduler: scheduler used by the Spatial merge-compute-nodes
// pass. "peft" is the only value declared here and is also the initial value.
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
    "pim-merge-scheduler",
    llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
    llvm::cl::values(
        clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
    llvm::cl::init(MergeSchedulerPeft),
    llvm::cl::cat(OnnxMlirOptions));
// --pim-memory-report: verbosity of the PIM memory planning report.
// Accepts "none", "summary" or "full"; the initial value is "none".
llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
    "pim-memory-report",
    llvm::cl::desc("Emit a human-readable PIM memory planning report"),
    llvm::cl::values(
        clEnumValN(PimMemoryReportNone, "none", "Do not emit any PIM memory planning report"),
        clEnumValN(PimMemoryReportSummary, "summary", "Emit a concise slot reuse report with key offenders"),
        clEnumValN(PimMemoryReportFull, "full", "Emit the full detailed PIM memory planning report")),
    llvm::cl::init(PimMemoryReportNone),
    llvm::cl::cat(OnnxMlirOptions));
// --pim-conv-lowering: convolution lowering strategy for PIM.
// All alternatives are listed in a single values() clause, in the same order
// as before; the initial value is "auto".
llvm::cl::opt<PimConvLoweringType> pimConvLowering(
    "pim-conv-lowering",
    llvm::cl::desc("Convolution lowering strategy for PIM"),
    llvm::cl::values(
        clEnumValN(PimConvLoweringAuto, "auto", "Select the Conv lowering strategy automatically"),
        clEnumValN(PimConvLoweringLegacy, "legacy", "Use the legacy explicit-im2col Conv lowering"),
        clEnumValN(PimConvLoweringDepthwise, "depthwise", "Force the depthwise-specialized Conv lowering"),
        clEnumValN(PimConvLoweringPackedIm2Col, "packed-im2col", "Use explicit im2col with packed multi-position GEMM"),
        clEnumValN(PimConvLoweringStreamedPatch,
            "streamed-patch",
            "Use streamed/chunked im2col rows without multi-position packing"),
        clEnumValN(PimConvLoweringStreamedPacked,
            "streamed-packed",
            "Use streamed/chunked im2col rows with packed multi-position GEMM"),
        clEnumValN(PimConvLoweringOutputChannelTiled,
            "output-channel-tiled",
            "Force Conv lowering that relies on Gemm output-channel tiling"),
        clEnumValN(PimConvLoweringInputKTiled, "input-k-tiled", "Force Conv lowering that relies on Gemm K tiling"),
        clEnumValN(PimConvLoweringTiled2D, "tiled-2d", "Force Conv lowering that relies on Gemm 2D K/C tiling")),
    llvm::cl::init(PimConvLoweringAuto),
    llvm::cl::cat(OnnxMlirOptions));
// --pim-only-codegen: run code generation only; off unless requested.
llvm::cl::opt<bool> pimOnlyCodegen(
    "pim-only-codegen",
    llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));

// --pim-disable-memory-coalescing: developer diagnostic switch; off unless
// requested.
llvm::cl::opt<bool> pimDisableMemoryCoalescing(
    "pim-disable-memory-coalescing",
    llvm::cl::desc("Skip the PIM memory coalescing pass (developer diagnostic option)"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));
// --use-experimental-conv-impl: opt in to the experimental convolution
// implementation; off unless requested.
llvm::cl::opt<bool> useExperimentalConvImpl(
    "use-experimental-conv-impl",
    llvm::cl::desc("Use experimental implementation for convolution"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));

// --pim-conv-im2col-max-elements: per-Conv element budget for a globally
// materialized im2col. Initial value is 1 << 20 (1,048,576).
llvm::cl::opt<uint64_t> pimConvIm2colMaxElements(
    "pim-conv-im2col-max-elements",
    llvm::cl::desc("Maximum number of im2col elements to materialize globally for one Conv before streaming/chunking"),
    llvm::cl::init(1ull << 20),
    llvm::cl::cat(OnnxMlirOptions));

// --pim-conv-stream-chunk-positions: cap on Conv output positions per
// streamed chunk. Initial value is 1024.
llvm::cl::opt<uint64_t> pimConvStreamChunkPositions(
    "pim-conv-stream-chunk-positions",
    llvm::cl::desc("Maximum number of Conv output positions to materialize in one streamed chunk"),
    llvm::cl::init(1024),
    llvm::cl::cat(OnnxMlirOptions));
// --pim-report-conv-lowering: Conv lowering report switch. Unlike the other
// boolean switches in this file, its initial value is true.
llvm::cl::opt<bool> pimReportConvLowering(
    "pim-report-conv-lowering",
    llvm::cl::desc("Emit a bounded Conv lowering report"),
    llvm::cl::init(true),
    llvm::cl::cat(OnnxMlirOptions));

// --pim-emit-json: additionally write per-core JSON instruction files; off
// unless requested.
llvm::cl::opt<bool> pimEmitJson(
    "pim-emit-json",
    llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));
// --pim-detect-communication-deadlock: opt-in deadlock check on the PIM
// send/receive order (described as expensive); off unless requested.
llvm::cl::opt<bool> pimDetectCommunicationDeadlock(
    "pim-detect-communication-deadlock",
    llvm::cl::desc(
        "Expensively simulate the statically expanded PIM send/receive order at verification time and fail if a blocking communication deadlock is found"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));

// --pim-materialize-scalar-fanout-global-order: experimental materializer
// mode (described as expensive); off unless requested.
llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder(
    "pim-materialize-scalar-fanout-global-order",
    llvm::cl::desc(
        "Experimental expensive materializer mode: emit scalar-source fanout as globally ordered communication events instead of all-send fanout loops"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));

// --pim-trace-communication-materialization: verbose diagnostics for Spatial
// communication ops; off unless requested.
llvm::cl::opt<bool> pimTraceCommunicationMaterialization(
    "pim-trace-communication-materialization",
    llvm::cl::desc(
        "Emit verbose materializer-time diagnostics and provenance attributes for every Spatial communication op"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));
// --crossbar-size: width and height of a single crossbar. Initial value is 2.
// NOTE(review): a default of 2 looks small for a crossbar dimension — confirm
// it is intentional.
// Registered under OnnxMlirOptions, matching the other PIM options in this
// file, so it is grouped with them in the categorized option listing.
llvm::cl::opt<size_t> crossbarSize(
    "crossbar-size",
    llvm::cl::desc("Width and height of a single crossbar"),
    llvm::cl::init(2),
    llvm::cl::cat(OnnxMlirOptions));

// --crossbar-count: number of crossbars in each core. Initial value is 256.
// Registered under OnnxMlirOptions for the same reason as --crossbar-size.
llvm::cl::opt<size_t> crossbarCountInCore(
    "crossbar-count",
    llvm::cl::desc("Number of crossbars in each core"),
    llvm::cl::init(256),
    llvm::cl::cat(OnnxMlirOptions));
// --core-count: number of cores in the chip.
// The initial value -1 is a placeholder only: verifyExplicitPimCoreCount()
// rejects both an absent option and any value <= 0.
// Registered under OnnxMlirOptions, matching the other PIM options in this
// file, so that this mandatory option is listed alongside them.
llvm::cl::opt<long> coresCount(
    "core-count",
    llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
    llvm::cl::init(-1),
    llvm::cl::cat(OnnxMlirOptions));
// --ignore-concat-error: tolerate the ConcatOp corner case by simplifying
// instead of asserting; off unless requested.
// Registered under OnnxMlirOptions, matching the other PIM options in this
// file.
llvm::cl::opt<bool> ignoreConcatError(
    "ignore-concat-error",
    llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
    llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirOptions));
/// Returns true when --core-count appeared at least once on the command line,
/// i.e. the value of coresCount is not merely its initial placeholder.
bool hasExplicitPimCoreCount() {
  const int occurrences = coresCount.getNumOccurrences();
  return occurrences > 0;
}
/// Validates the --core-count option for PIM compilation.
/// Reports a fatal error when the option was not given explicitly, and
/// another when the supplied value is not a positive integer.
void verifyExplicitPimCoreCount() {
  const bool wasProvided = hasExplicitPimCoreCount();
  if (!wasProvided) {
    llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
  }

  const long requestedCores = coresCount.getValue();
  if (requestedCores < 1) {
    llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
  }
}
} // namespace onnx_mlir