refactor Pim passes layout
This commit is contained in:
@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
|
|||||||
add_pim_library(OMPimCompilerUtils
|
add_pim_library(OMPimCompilerUtils
|
||||||
PimCompilerUtils.cpp
|
PimCompilerUtils.cpp
|
||||||
PimCodeGen.cpp
|
PimCodeGen.cpp
|
||||||
../Pass/EmitPimJsonPass.cpp
|
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerPasses.hpp"
|
#include "src/Compiler/CompilerPasses.hpp"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerUtils"
|
#define DEBUG_TYPE "PimCompilerUtils"
|
||||||
@@ -45,10 +45,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
pm.addPass(createPimConstantFoldingPass());
|
pm.addPass(createConstantFoldingPass());
|
||||||
pm.addPass(createMessagePass("Pim constants folded"));
|
pm.addPass(createMessagePass("Pim constants folded"));
|
||||||
pm.addPass(createPimMaterializeConstantsPass());
|
pm.addPass(createMaterializeConstantsPass());
|
||||||
pm.addPass(createPimVerificationPass());
|
pm.addPass(createVerificationPass());
|
||||||
pm.addPass(createMessagePass("Pim verified"));
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimJsonPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
|
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
add_pim_library(OMPimPasses
|
add_pim_library(OMPimPasses
|
||||||
CountInstructionPass.cpp
|
CountInstructionPass.cpp
|
||||||
MessagePass.cpp
|
MessagePass.cpp
|
||||||
PimConstantFolding/Common.cpp
|
Pim/ConstantFolding/Common.cpp
|
||||||
PimConstantFolding/Patterns/Constant.cpp
|
Pim/ConstantFolding/Patterns/Constant.cpp
|
||||||
PimConstantFolding/PimConstantFoldingPass.cpp
|
Pim/ConstantFolding/ConstantFoldingPass.cpp
|
||||||
PimConstantFolding/Patterns/Subview.cpp
|
Pim/ConstantFolding/Patterns/Subview.cpp
|
||||||
PimMaterializeConstantsPass.cpp
|
Pim/MaterializeConstantsPass.cpp
|
||||||
PimVerificationPass.cpp
|
Pim/VerificationPass.cpp
|
||||||
|
Pim/EmitPimJsonPass.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
std::unique_ptr<mlir::Pass> createVerificationPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||||
|
|
||||||
@@ -12,8 +12,8 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||||
@@ -48,6 +48,6 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -31,9 +31,8 @@ static int64_t getValueSizeInBytes(Value value) {
|
|||||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PimMaterializeConstantsPass
|
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||||
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
|
|
||||||
|
|
||||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
@@ -128,8 +127,8 @@ struct PimMaterializeConstantsPass
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
|
std::unique_ptr<Pass> createMaterializeConstantsPass() {
|
||||||
return std::make_unique<PimMaterializeConstantsPass>();
|
return std::make_unique<MaterializeConstantsPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -43,16 +43,16 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
|
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "verify-pim-pass"; }
|
StringRef getArgument() const override { return "verify-pim-pass"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
||||||
}
|
}
|
||||||
|
|
||||||
PimVerificationPass() {}
|
VerificationPass() {}
|
||||||
PimVerificationPass(const PimVerificationPass& pass) {}
|
VerificationPass(const VerificationPass& pass) {}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
@@ -189,6 +189,6 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
|
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -17,7 +17,7 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimAccelerator"
|
#define DEBUG_TYPE "PimAccelerator"
|
||||||
@@ -73,9 +73,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
|||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createBufferizePimPass);
|
registerPass(createBufferizePimPass);
|
||||||
registerPass(createPimConstantFoldingPass);
|
registerPass(createConstantFoldingPass);
|
||||||
registerPass(createPimMaterializeConstantsPass);
|
registerPass(createMaterializeConstantsPass);
|
||||||
registerPass(createPimVerificationPass);
|
registerPass(createVerificationPass);
|
||||||
registerPass(createEmitPimJsonPass);
|
registerPass(createEmitPimJsonPass);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user