refactor Pim passes layout
This commit is contained in:
@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
|
||||
add_pim_library(OMPimCompilerUtils
|
||||
PimCompilerUtils.cpp
|
||||
PimCodeGen.cpp
|
||||
../Pass/EmitPimJsonPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerPasses.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerUtils"
|
||||
@@ -45,10 +45,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimConstantFoldingPass());
|
||||
pm.addPass(createConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim constants folded"));
|
||||
pm.addPass(createPimMaterializeConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMaterializeConstantsPass());
|
||||
pm.addPass(createVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
add_pim_library(OMPimPasses
|
||||
CountInstructionPass.cpp
|
||||
MessagePass.cpp
|
||||
PimConstantFolding/Common.cpp
|
||||
PimConstantFolding/Patterns/Constant.cpp
|
||||
PimConstantFolding/PimConstantFoldingPass.cpp
|
||||
PimConstantFolding/Patterns/Subview.cpp
|
||||
PimMaterializeConstantsPass.cpp
|
||||
PimVerificationPass.cpp
|
||||
Pim/ConstantFolding/Common.cpp
|
||||
Pim/ConstantFolding/Patterns/Constant.cpp
|
||||
Pim/ConstantFolding/ConstantFoldingPass.cpp
|
||||
Pim/ConstantFolding/Patterns/Subview.cpp
|
||||
Pim/MaterializeConstantsPass.cpp
|
||||
Pim/VerificationPass.cpp
|
||||
Pim/EmitPimJsonPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
std::unique_ptr<mlir::Pass> createVerificationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||
|
||||
@@ -12,8 +12,8 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
@@ -48,6 +48,6 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
||||
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -31,9 +31,8 @@ static int64_t getValueSizeInBytes(Value value) {
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
struct PimMaterializeConstantsPass
|
||||
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
|
||||
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
@@ -128,8 +127,8 @@ struct PimMaterializeConstantsPass
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
|
||||
return std::make_unique<PimMaterializeConstantsPass>();
|
||||
std::unique_ptr<Pass> createMaterializeConstantsPass() {
|
||||
return std::make_unique<MaterializeConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -43,16 +43,16 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
return false;
|
||||
}
|
||||
|
||||
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
|
||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
||||
}
|
||||
|
||||
PimVerificationPass() {}
|
||||
PimVerificationPass(const PimVerificationPass& pass) {}
|
||||
VerificationPass() {}
|
||||
VerificationPass(const VerificationPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -189,6 +189,6 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
|
||||
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -17,7 +17,7 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimAccelerator"
|
||||
@@ -73,9 +73,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
||||
registerPass(createSpatialToGraphvizPass);
|
||||
registerPass(createSpatialToPimPass);
|
||||
registerPass(createBufferizePimPass);
|
||||
registerPass(createPimConstantFoldingPass);
|
||||
registerPass(createPimMaterializeConstantsPass);
|
||||
registerPass(createPimVerificationPass);
|
||||
registerPass(createConstantFoldingPass);
|
||||
registerPass(createMaterializeConstantsPass);
|
||||
registerPass(createVerificationPass);
|
||||
registerPass(createEmitPimJsonPass);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user