fix output paths
add gemm test
This commit is contained in:
@@ -1,17 +1,40 @@
|
|||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
llvm::FailureOr<Operation *> getOtherEndOfChannel(
|
std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
|
||||||
Operation *op, bool opIsReceive, RewriterBase &rewriter) {
|
|
||||||
|
void createDirectory(const std::string& directory) {
|
||||||
|
std::error_code errorCode;
|
||||||
|
std::filesystem::create_directories(directory, errorCode);
|
||||||
|
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||||
|
}
|
||||||
|
|
||||||
|
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||||
|
std::string dialectsDir = getOutputDir() + "/dialects";
|
||||||
|
createDirectory(dialectsDir);
|
||||||
|
|
||||||
|
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||||
|
llvm::raw_os_ostream os(file);
|
||||||
|
os << *moduleOp;
|
||||||
|
os.flush();
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
if (!channelNewOp) {
|
if (!channelNewOp) {
|
||||||
op->emitError(
|
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||||
"User of Channel must have the first operand created by ChannelNewOp.");
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
// channelNewOp should have two users: `op` and a
|
// channelNewOp should have two users: `op` and a
|
||||||
@@ -38,9 +61,11 @@ llvm::FailureOr<Operation *> getOtherEndOfChannel(
|
|||||||
Operation* notOpUser;
|
Operation* notOpUser;
|
||||||
if (firstUser == op) {
|
if (firstUser == op) {
|
||||||
notOpUser = secondUser;
|
notOpUser = secondUser;
|
||||||
} else if (secondUser == op) {
|
}
|
||||||
|
else if (secondUser == op) {
|
||||||
notOpUser = firstUser;
|
notOpUser = firstUser;
|
||||||
} else {
|
}
|
||||||
|
else {
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||||
"and one of them must be me, but"
|
"and one of them must be me, but"
|
||||||
"none of them is actually me.");
|
"none of them is actually me.");
|
||||||
@@ -54,7 +79,8 @@ llvm::FailureOr<Operation *> getOtherEndOfChannel(
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
return notOpUser;
|
return notOpUser;
|
||||||
} else {
|
}
|
||||||
|
else {
|
||||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
||||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||||
"me, the other is not a ChannelReceiveOp.");
|
"me, the other is not a ChannelReceiveOp.");
|
||||||
|
|||||||
@@ -3,14 +3,22 @@
|
|||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME =
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
"pim.constant.should_allocate";
|
|
||||||
|
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Operation *> getOtherEndOfChannel(
|
std::string getOutputDir();
|
||||||
mlir::Operation *op, bool opIsReceive, mlir::RewriterBase &rewriter);
|
|
||||||
|
void createDirectory(const std::string &directory);
|
||||||
|
|
||||||
|
void dumpModule(mlir::ModuleOp moduleOp, const std::string &name);
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::Operation*>
|
||||||
|
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -433,14 +433,7 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
return std::to_string(size) + " Bytes";
|
return std::to_string(size) + " Bytes";
|
||||||
}
|
}
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const OwningOpRef<ModuleOp>& moduleOpRef, std::string& outputDirPath) {
|
OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(ModuleOp& moduleOp, std::string& outputDirPath) {
|
||||||
ModuleOp moduleOp = moduleOpRef.get();
|
|
||||||
|
|
||||||
if (pimEmissionTarget != EmitPimCodegen) {
|
|
||||||
moduleOp.dump();
|
|
||||||
return CompilerSuccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!outputDirPath.empty()) {
|
if (!outputDirPath.empty()) {
|
||||||
if (auto error = llvm::sys::fs::create_directory(outputDirPath)) {
|
if (auto error = llvm::sys::fs::create_directory(outputDirPath)) {
|
||||||
llvm::errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
|
llvm::errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
|
||||||
@@ -696,8 +689,6 @@ OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const OwningOpRef<ModuleOp>& m
|
|||||||
jsonOS << llvm::json::Value(std::move(configJson)) << '\n';
|
jsonOS << llvm::json::Value(std::move(configJson)) << '\n';
|
||||||
jsonOS.close();
|
jsonOS.close();
|
||||||
|
|
||||||
showCompilePhase("Code generated into " + configPath);
|
|
||||||
|
|
||||||
return CompilerSuccess;
|
return CompilerSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,11 +17,6 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
llvm::cl::opt<std::string> pimOutputDir("pim-output-dir",
|
|
||||||
llvm::cl::desc("Directory where pim json code will be emitted"),
|
|
||||||
llvm::cl::init("pim"),
|
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
|
||||||
|
|
||||||
llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||||
llvm::cl::desc("[Optional] Choose PIM-related target to emit (once selected it will cancel the other targets):"),
|
llvm::cl::desc("[Optional] Choose PIM-related target to emit (once selected it will cancel the other targets):"),
|
||||||
llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")),
|
llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")),
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ typedef enum {
|
|||||||
} PimEmissionTargetType;
|
} PimEmissionTargetType;
|
||||||
|
|
||||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||||
extern llvm::cl::opt<std::string> pimOutputDir;
|
|
||||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||||
|
|
||||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||||
|
|||||||
@@ -1,18 +1,11 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
||||||
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||||
#include "src/Compiler/CompilerPasses.hpp"
|
#include "src/Compiler/CompilerPasses.hpp"
|
||||||
|
#include "src/Compiler/CompilerUtils.hpp"
|
||||||
#include "llvm/Support/JSON.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <cstddef>
|
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerUtils"
|
#define DEBUG_TYPE "PimCompilerUtils"
|
||||||
|
|
||||||
@@ -37,19 +30,25 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
if (pimEmissionTarget >= EmitSpatial) {
|
if (pimEmissionTarget >= EmitSpatial) {
|
||||||
pm.addPass(createONNXToSpatialPass());
|
pm.addPass(createONNXToSpatialPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("ONNX lowered to SPATIAL"));
|
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPim) {
|
if (pimEmissionTarget >= EmitPim) {
|
||||||
pm.addPass(createSpatialToPIMPass());
|
pm.addPass(createSpatialToPIMPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("SPATIAL lowered to PIM"));
|
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||||
pm.addPass(createBufferizePimPass());
|
pm.addPass(createBufferizePimPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("PIM bufferized"));
|
pm.addPass(createMessagePass("Pim bufferized"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
|
pm.addPass(createEmitPimJsonPass());
|
||||||
|
// pm.addPass(createCountInstructionPass());
|
||||||
|
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ void addPassesPim(mlir::OwningOpRef<mlir::ModuleOp>& module,
|
|||||||
EmissionTargetType& emissionTarget,
|
EmissionTargetType& emissionTarget,
|
||||||
std::string outputNameNoExt);
|
std::string outputNameNoExt);
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const mlir::OwningOpRef<mlir::ModuleOp>& moduleOpRef,
|
OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||||
std::string& outputDirName);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "Common/PIMCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||||
#include "ONNXToSpatialPass.hpp"
|
#include "ONNXToSpatialPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
@@ -27,7 +28,7 @@ namespace spatial {
|
|||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
|
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
|
||||||
|
|
||||||
ModuleOp module = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
|
|
||||||
RewritePatternSet mergeActivationPatterns(ctx);
|
RewritePatternSet mergeActivationPatterns(ctx);
|
||||||
@@ -38,11 +39,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
||||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
||||||
|
|
||||||
if (failed(applyPatternsAndFoldGreedily(module, std::move(mergeActivationPatterns))))
|
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||||
|
|
||||||
IRRewriter rewriter(module);
|
IRRewriter rewriter(moduleOp);
|
||||||
func::FuncOp funcOp = *module.getOps<func::FuncOp>().begin();
|
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
|
||||||
if (annotateReplication(funcOp, rewriter).failed()) {
|
if (annotateReplication(funcOp, rewriter).failed()) {
|
||||||
llvm::dbgs() << "Failed during annotation for replication analysis\n";
|
llvm::dbgs() << "Failed during annotation for replication analysis\n";
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
@@ -78,7 +79,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||||
populateReduceMeanConversionPattern(patterns, ctx);
|
populateReduceMeanConversionPattern(patterns, ctx);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -101,19 +102,13 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
|
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
|
||||||
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
|
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
|
||||||
|
|
||||||
if (failed(applyPatternsAndFoldGreedily(module, std::move(removeUnusedHelperOpsPatterns))))
|
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
|
||||||
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
|
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
|
||||||
|
|
||||||
annotateWeightsConstants(funcOp);
|
annotateWeightsConstants(funcOp);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
|
dumpModule(moduleOp, "spatial");
|
||||||
std::filesystem::create_directory(outputDir);
|
|
||||||
std::fstream file(outputDir + "/spatial.mlir", std::ios::out);
|
|
||||||
llvm::raw_os_ostream os(file);
|
|
||||||
os << *module;
|
|
||||||
os.flush();
|
|
||||||
file.close();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <fstream>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
@@ -75,13 +74,7 @@ void SpatialToPIMPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
|
dumpModule(moduleOp, "pim");
|
||||||
std::filesystem::create_directory(outputDir);
|
|
||||||
std::fstream file(outputDir + "/pim.mlir", std::ios::out);
|
|
||||||
llvm::raw_os_ostream os(file);
|
|
||||||
os << *moduleOp;
|
|
||||||
os.flush();
|
|
||||||
file.close();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include <filesystem>
|
#include "Common/PIMCommon.hpp"
|
||||||
|
|
||||||
#include "Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "Compiler/PimCompilerUtils.hpp"
|
#include "Compiler/PimCompilerUtils.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -19,20 +17,13 @@ struct EmitPimJsonPass : PassWrapper<EmitPimJsonPass, OperationPass<ModuleOp>> {
|
|||||||
EmitPimJsonPass() {}
|
EmitPimJsonPass() {}
|
||||||
EmitPimJsonPass(const EmitPimJsonPass& pass) {}
|
EmitPimJsonPass(const EmitPimJsonPass& pass) {}
|
||||||
|
|
||||||
void runOnOperation() final {
|
void runOnOperation() override {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
|
|
||||||
std::filesystem::path pimDir(pimOutputDir.data());
|
std::string pimDir = getOutputDir() + "/pim";
|
||||||
|
createDirectory(pimDir);
|
||||||
|
|
||||||
std::error_code error_code;
|
int compiler_error_code = compileModuleToPIMJSON(moduleOp, pimDir);
|
||||||
std::filesystem::create_directories(pimDir, error_code);
|
|
||||||
if (error_code) {
|
|
||||||
moduleOp.emitError("Failed to create PIM output directory: " + error_code.message());
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int compiler_error_code = compileModuleToPIMJSON(moduleOp, pimOutputDir);
|
|
||||||
if (compiler_error_code != CompilerSuccess)
|
if (compiler_error_code != CompilerSuccess)
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "src/Compiler/CompilerUtils.hpp"
|
#include "src/Compiler/CompilerUtils.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -7,21 +8,15 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct MessagePass : public PassWrapper<MessagePass, OperationPass<ModuleOp>> {
|
struct MessagePass : PassWrapper<MessagePass, OperationPass<ModuleOp>> {
|
||||||
|
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MessagePass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MessagePass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "message-pass"; }
|
StringRef getArgument() const override { return "message-pass"; }
|
||||||
|
StringRef getDescription() const override { return "Print compilation status"; }
|
||||||
|
|
||||||
StringRef getDescription() const override {
|
MessagePass(std::string message)
|
||||||
return "Lower ONNX ops to Spatial ops.";
|
: message(message) {}
|
||||||
}
|
MessagePass(const MessagePass& pass) {}
|
||||||
|
|
||||||
// Make sure that we have a valid default constructor and copy
|
|
||||||
// constructor to make sure that the options are initialized properly.
|
|
||||||
MessagePass(std::string message) : message(message) {}
|
|
||||||
MessagePass(const MessagePass &pass)
|
|
||||||
: PassWrapper<MessagePass, OperationPass<ModuleOp>>() {}
|
|
||||||
void runOnOperation() final { showCompilePhase(message); }
|
void runOnOperation() final { showCompilePhase(message); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -30,8 +25,6 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createMessagePass(std::string message) {
|
std::unique_ptr<Pass> createMessagePass(std::string message) { return std::make_unique<MessagePass>(message); }
|
||||||
return std::make_unique<MessagePass>(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -7,9 +7,7 @@
|
|||||||
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <filesystem>
|
#include "Common/PIMCommon.hpp"
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "PimBufferizationPass.hpp"
|
#include "PimBufferizationPass.hpp"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
@@ -59,14 +57,7 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
ModuleOp module = getOperation();
|
dumpModule(moduleOp, "pim_buf");
|
||||||
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
|
|
||||||
std::filesystem::create_directory(outputDir);
|
|
||||||
std::fstream file(outputDir + "/pim_buf.mlir", std::ios::out);
|
|
||||||
llvm::raw_os_ostream os(file);
|
|
||||||
os << *module;
|
|
||||||
os.flush();
|
|
||||||
file.close();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
|
|||||||
Binary file not shown.
Reference in New Issue
Block a user