Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation add spatial compute helper minor refactors
60 lines
1.8 KiB
C++
60 lines
1.8 KiB
C++
#include "mlir/Transforms/Passes.h"
|
|
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
#include "src/Compiler/CompilerPasses.hpp"
|
|
|
|
#define DEBUG_TYPE "PimCompilerUtils"
|
|
|
|
using namespace mlir;
|
|
using namespace onnx_mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Appends the PIM lowering pipeline to `pm`.
///
/// The pipeline is assembled stage by stage. The ONNX-to-MLIR passes are gated
/// on `emissionTarget`; every following stage is gated on `pimEmissionTarget`
/// and is followed by a message pass carrying a stage-specific text.
///
/// If `pimOnlyCodegen` is set, the function returns without adding any pass.
///
/// \param module          Not referenced by this function.
/// \param pm              Pass manager that receives the passes.
/// \param emissionTarget  Compared against `EmitONNXIR` to decide whether the
///                        ONNX-to-MLIR passes are added.
/// \param outputNameNoExt Not referenced by this function.
void addPassesPim(OwningOpRef<ModuleOp>& module,
                  PassManager& pm,
                  EmissionTargetType& emissionTarget,
                  std::string outputNameNoExt) {
  // Codegen-only mode: no lowering pass is scheduled here.
  if (pimOnlyCodegen)
    return;

  // Appends a message pass with the given text to the pipeline.
  const auto announce = [&pm](const char* text) { pm.addPass(createMessagePass(text)); };

  // NOTE: a createCountInstructionPass() hook after each PIM stage is
  // currently disabled (it was commented out in every stage below).

  if (emissionTarget >= EmitONNXIR) {
    addONNXToMLIRPasses(pm, /*target CPU*/ false);
  }

  // Stage: ONNX -> Spatial.
  if (pimEmissionTarget >= EmitSpatial) {
    pm.addPass(createONNXToSpatialPass());
    announce("Onnx lowered to Spatial");
  }

  // Stage: Spatial -> Pim.
  if (pimEmissionTarget >= EmitPim) {
    pm.addPass(createSpatialToPimPass());
    announce("Spatial lowered to Pim");
  }

  // Stage: Pim bufferization.
  if (pimEmissionTarget >= EmitPimBufferized) {
    pm.addPass(createPimBufferizationPass());
    announce("Pim bufferized");
  }

  // Stage: constant folding, constant materialization, verification, and
  // JSON emission, in that order.
  if (pimEmissionTarget >= EmitPimCodegen) {
    pm.addPass(createPimConstantFoldingPass());
    announce("Pim constants folded");

    pm.addPass(createPimMaterializeConstantsPass());
    pm.addPass(createPimVerificationPass());
    announce("Pim verified");

    pm.addPass(createEmitPimJsonPass());
    announce("Pim json code emitted");
  }
}
|
|
|
|
} // namespace onnx_mlir
|