fix missed failing tests for channels
moderate refactor
This commit is contained in:
@@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPim) {
|
if (pimEmissionTarget >= EmitPim) {
|
||||||
pm.addPass(createMergeComputeNodePass());
|
pm.addPass(createMergeComputeNodesPass());
|
||||||
pm.addPass(createSpatialToPimPass());
|
pm.addPass(createSpatialToPimPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||||
@@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
pm.addPass(createPimConstantFoldingPass());
|
pm.addPass(createPimHostConstantFoldingPass());
|
||||||
pm.addPass(createMessagePass("Pim constants folded"));
|
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||||
pm.addPass(createPimMaterializeConstantsPass());
|
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||||
pm.addPass(createPimVerificationPass());
|
pm.addPass(createPimVerificationPass());
|
||||||
pm.addPass(createMessagePass("Pim verified"));
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimJsonPass());
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -85,6 +86,25 @@ IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value chan
|
|||||||
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
|
||||||
|
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
|
||||||
|
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value createPimReceiveFromSpatialChannel(
|
||||||
|
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
||||||
|
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
|
||||||
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
|
||||||
|
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
|
||||||
|
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||||
auto users = value.getUsers();
|
auto users = value.getUsers();
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,13 @@ mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir
|
|||||||
|
|
||||||
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||||
|
|
||||||
|
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||||
|
|
||||||
|
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||||
|
|
||||||
|
mlir::Value createPimReceiveFromSpatialChannel(
|
||||||
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||||
return std::distance(range.begin(), range.end());
|
return std::distance(range.begin(), range.end());
|
||||||
|
|||||||
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||||
|
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||||
|
"spatial channel has precomputed source core id">;
|
||||||
|
|
||||||
|
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||||
|
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||||
|
"spatial channel has precomputed target core id">;
|
||||||
|
|
||||||
|
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||||
|
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||||
|
|
||||||
def onnxToPimTranspose : Pat<
|
def onnxToPimTranspose : Pat<
|
||||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||||
(PimTransposeOp $data, $perms,
|
(PimTransposeOp $data, $perms,
|
||||||
@@ -73,15 +84,14 @@ def spatChannelSendToPimSend : Pat<
|
|||||||
(SpatChannelSendOp $channel, $input),
|
(SpatChannelSendOp $channel, $input),
|
||||||
(PimSendOp $input,
|
(PimSendOp $input,
|
||||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel))
|
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||||
|
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatChannelReceiveToPimReceive : Pat<
|
def spatChannelReceiveToPimReceive : Pat<
|
||||||
(SpatChannelReceiveOp:$srcOpRes $channel),
|
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||||
(PimReceiveOp
|
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes),
|
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes),
|
|
||||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel))
|
|
||||||
>;
|
>;
|
||||||
|
|
||||||
#endif // SPATIAL_TO_PIM
|
#endif // SPATIAL_TO_PIM
|
||||||
|
|||||||
@@ -641,6 +641,9 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
|
|||||||
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
|
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
|
||||||
markOpToRemove(channelNewOp);
|
markOpToRemove(channelNewOp);
|
||||||
|
|
||||||
|
if (channelNewOp->use_empty())
|
||||||
|
return;
|
||||||
|
|
||||||
spatial::SpatChannelSendOp sendOp;
|
spatial::SpatChannelSendOp sendOp;
|
||||||
spatial::SpatChannelReceiveOp receiveOp;
|
spatial::SpatChannelReceiveOp receiveOp;
|
||||||
spatial::SpatChannelBroadcastSendOp broadcastSendOp;
|
spatial::SpatChannelBroadcastSendOp broadcastSendOp;
|
||||||
|
|||||||
@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto receiveOp = cast<PimReceiveOp>(op);
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
|
||||||
|
op,
|
||||||
|
outputBufferOpt->getType(),
|
||||||
|
*outputBufferOpt,
|
||||||
|
receiveOp.getSizeAttr(),
|
||||||
|
receiveOp.getSourceCoreIdAttr());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
|
|
||||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||||
|
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
|||||||
|
|
||||||
add_pim_library(SpatialOps
|
add_pim_library(SpatialOps
|
||||||
SpatialOps.cpp
|
SpatialOps.cpp
|
||||||
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
DCPGraph/Graph.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||||
DCPGraph/Task.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
|
||||||
DCPGraph/DCPAnalysis.cpp
|
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
#include "Graph.hpp"
|
#include "Graph.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
struct DCPAnalysisResult {
|
struct DCPAnalysisResult {
|
||||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
||||||
@@ -7,11 +7,11 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../../../Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
#include "Graph.hpp"
|
#include "Graph.hpp"
|
||||||
#include "Task.hpp"
|
#include "Task.hpp"
|
||||||
#include "Uniqueworklist.hpp"
|
#include "UniqueWorklist.hpp"
|
||||||
#include "Utils.hpp"
|
#include "Utils.hpp"
|
||||||
|
|
||||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
|
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include "Graph.hpp"
|
#include "Graph.hpp"
|
||||||
#include "Task.hpp"
|
#include "Task.hpp"
|
||||||
#include "Uniqueworklist.hpp"
|
#include "UniqueWorklist.hpp"
|
||||||
|
|
||||||
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
||||||
std::optional<Edge_t> oldEdge = std::nullopt;
|
std::optional<Edge_t> oldEdge = std::nullopt;
|
||||||
@@ -5,7 +5,7 @@
|
|||||||
#include <optional>
|
#include <optional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "Utils.hpp"
|
#include "Utils.hpp"
|
||||||
|
|
||||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
|
|
||||||
using CPU = int;
|
using CPU = int;
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
#include "DCPGraph/DCPAnalysis.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ public:
|
|||||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
|
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||||
@@ -89,11 +89,11 @@ private:
|
|||||||
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "pim-merge-node-pass"; }
|
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
|
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
||||||
"execution time";
|
"execution time";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -346,6 +346,6 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
|
std::unique_ptr<Pass> createMergeComputeNodesPass() { return std::make_unique<MergeComputeNodesPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
add_pim_library(OMPimPasses
|
add_pim_library(OMPimPasses
|
||||||
CountInstructionPass.cpp
|
CountInstructionPass.cpp
|
||||||
MessagePass.cpp
|
MessagePass.cpp
|
||||||
Pim/ConstantFolding/Common.cpp
|
PimCodegen/HostConstantFolding/Common.cpp
|
||||||
Pim/ConstantFolding/Patterns/Constant.cpp
|
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
|
||||||
Pim/ConstantFolding/ConstantFoldingPass.cpp
|
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
|
||||||
Pim/ConstantFolding/Patterns/Subview.cpp
|
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
|
||||||
Pim/MaterializeConstantsPass.cpp
|
PimCodegen/MaterializeHostConstantsPass.cpp
|
||||||
Pim/VerificationPass.cpp
|
PimCodegen/VerificationPass.cpp
|
||||||
Pim/EmitPimJsonPass.cpp
|
PimCodegen/EmitPimJsonPass.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
|
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||||
|
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
|
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
|
||||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||||
|
|
||||||
LogicalResult initialize(MLIRContext* context) override {
|
LogicalResult initialize(MLIRContext* context) override {
|
||||||
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
|
|||||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
|
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||||
}
|
}
|
||||||
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||||
|
return std::make_unique<MaterializeHostConstantsPass>();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -75,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
|||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createPimBufferizationPass);
|
registerPass(createPimBufferizationPass);
|
||||||
registerPass(createMergeComputeNodePass);
|
registerPass(createMergeComputeNodesPass);
|
||||||
registerPass(createPimConstantFoldingPass);
|
registerPass(createPimHostConstantFoldingPass);
|
||||||
registerPass(createPimMaterializeConstantsPass);
|
registerPass(createPimMaterializeHostConstantsPass);
|
||||||
registerPass(createPimVerificationPass);
|
registerPass(createPimVerificationPass);
|
||||||
registerPass(createEmitPimJsonPass);
|
registerPass(createEmitPimJsonPass);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
|
|
||||||
Rimuovere la logica di bufferizazione da spatial,
|
|
||||||
Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode)
|
Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode)
|
||||||
|
|
||||||
AnalisiDCP
|
AnalisiDCP
|
||||||
|
|||||||
Reference in New Issue
Block a user