fix missed failing tests for channels
moderate refactor
This commit is contained in:
@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveOp>(op);
|
||||
|
||||
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
|
||||
op,
|
||||
outputBufferOpt->getType(),
|
||||
*outputBufferOpt,
|
||||
receiveOp.getSizeAttr(),
|
||||
receiveOp.getSourceCoreIdAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
|
||||
@@ -3,10 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
|
||||
add_pim_library(SpatialOps
|
||||
SpatialOps.cpp
|
||||
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
|
||||
DCPGraph/Graph.cpp
|
||||
DCPGraph/Task.cpp
|
||||
DCPGraph/DCPAnalysis.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
#include <iterator>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
||||
@@ -7,11 +7,11 @@
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "Task.hpp"
|
||||
#include "Uniqueworklist.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "Graph.hpp"
|
||||
#include "Task.hpp"
|
||||
#include "Uniqueworklist.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
|
||||
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
||||
std::optional<Edge_t> oldEdge = std::nullopt;
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||
@@ -9,7 +9,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
using CPU = int;
|
||||
@@ -20,7 +20,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "DCPGraph/DCPAnalysis.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -81,7 +81,7 @@ public:
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||
};
|
||||
|
||||
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
|
||||
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
private:
|
||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
@@ -89,11 +89,11 @@ private:
|
||||
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
||||
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-merge-node-pass"; }
|
||||
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
|
||||
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
||||
"execution time";
|
||||
}
|
||||
|
||||
@@ -346,6 +346,6 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
|
||||
std::unique_ptr<Pass> createMergeComputeNodesPass() { return std::make_unique<MergeComputeNodesPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user