diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index 880c37b..e849345 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef& module, } if (pimEmissionTarget >= EmitPim) { - pm.addPass(createMergeComputeNodePass()); + pm.addPass(createMergeComputeNodesPass()); pm.addPass(createSpatialToPimPass()); // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Spatial lowered to Pim")); @@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef& module, } if (pimEmissionTarget >= EmitPimCodegen) { - pm.addPass(createPimConstantFoldingPass()); - pm.addPass(createMessagePass("Pim constants folded")); - pm.addPass(createPimMaterializeConstantsPass()); + pm.addPass(createPimHostConstantFoldingPass()); + pm.addPass(createMessagePass("Pim host constants folded")); + pm.addPass(createPimMaterializeHostConstantsPass()); pm.addPass(createPimVerificationPass()); pm.addPass(createMessagePass("Pim verified")); pm.addPass(createEmitPimJsonPass()); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index d64615c..1c47aba 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -19,7 +19,7 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index dd146ca..5d4af51 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -7,6 +7,7 @@ #include #include "Common.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace llvm; using namespace mlir; @@ -85,6 +86,25 @@ IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value chan return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName); } +bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) { + auto channelNewOp = channel.getDefiningOp(); + return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName); +} + +bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) { + auto channelNewOp = channel.getDefiningOp(); + return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName); +} + +mlir::Value createPimReceiveFromSpatialChannel( + PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) { + mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp()); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output); + auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel); + return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) + .getOutput(); +} + Operation* getEarliestUserWithinBlock(mlir::Value value) { auto users = value.getUsers(); diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index 99819d9..d8e9d9d 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -34,6 +34,13 @@ mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel); +bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel); + +bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel); + +mlir::Value createPimReceiveFromSpatialChannel( + mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel); + template size_t rangeLength(const mlir::iterator_range range) { return std::distance(range.begin(), range.end()); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index 7bcb91d..0e323cd 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td" #endif // OP_BASE +def HasSpatialChannelSourceCoreIdAttr: Constraint< + CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">, + "spatial channel has precomputed source core id">; + +def HasSpatialChannelTargetCoreIdAttr: Constraint< + CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">, + "spatial channel has precomputed target core id">; + +def createPimReceiveFromSpatialChannelValue: NativeCodeCall< + "onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">; + def onnxToPimTranspose : Pat< (ONNXTransposeOp:$srcOpRes $data, $perms), (PimTransposeOp $data, $perms, @@ -73,15 +84,14 @@ def spatChannelSendToPimSend : Pat< (SpatChannelSendOp $channel, $input), (PimSendOp $input, (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input), - (NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)) + (NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)), + [(HasSpatialChannelTargetCoreIdAttr $channel)] >; def spatChannelReceiveToPimReceive : Pat< (SpatChannelReceiveOp:$srcOpRes $channel), - (PimReceiveOp - (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes), - (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes), - (NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel)) + (createPimReceiveFromSpatialChannelValue $srcOpRes, $channel), + [(HasSpatialChannelSourceCoreIdAttr $channel)] >; #endif // SPATIAL_TO_PIM diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index f9beb4c..867b207 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -641,6 +641,9 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) { funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) { markOpToRemove(channelNewOp); + if (channelNewOp->use_empty()) + return; + spatial::SpatChannelSendOp sendOp; spatial::SpatChannelReceiveOp receiveOp; spatial::SpatChannelBroadcastSendOp broadcastSendOp; diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index eae40c2..e217cd6 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface } }; +struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto receiveOp = cast(op); + + auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) + return failure(); + + replaceOpWithNewBufferizedOp(rewriter, + op, + outputBufferOpt->getType(), + *outputBufferOpt, + receiveOp.getSizeAttr(), + receiveOp.getSourceCoreIdAttr()); + return success(); + } +}; + struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); @@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel(*ctx); PimMemCopyHostToDevOp::attachInterface(*ctx); PimMemCopyDevToHostOp::attachInterface(*ctx); PimTransposeOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 3efa16b..286abd7 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -3,10 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td) add_pim_library(SpatialOps SpatialOps.cpp - Transforms/MergeComputeNode/MergeComputeNodePass.cpp - DCPGraph/Graph.cpp - DCPGraph/Task.cpp - DCPGraph/DCPAnalysis.cpp + Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp + Transforms/MergeComputeNodes/DCPGraph/Graph.cpp + Transforms/MergeComputeNodes/DCPGraph/Task.cpp + Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp similarity index 98% rename from src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index 7a0ed07..ebc4156 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -8,7 +8,6 @@ #include -#include "../SpatialOps.hpp" #include "DCPAnalysis.hpp" #include "Graph.hpp" #include "src/Support/TypeUtilities.hpp" diff --git a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp similarity index 92% rename from src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index eee0dd8..f7426cc 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -6,7 +6,7 @@ #include -#include "../SpatialOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" struct DCPAnalysisResult { std::vector dominanceOrderCompute; diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp similarity index 99% rename from src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp index 4f4e1c0..2210705 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp @@ -7,11 +7,11 @@ #include #include -#include "../../../Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "DCPAnalysis.hpp" #include "Graph.hpp" #include "Task.hpp" -#include "Uniqueworklist.hpp" +#include "UniqueWorklist.hpp" #include "Utils.hpp" std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) { diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Graph.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp similarity index 100% rename from src/PIM/Dialect/Spatial/DCPGraph/Graph.hpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.cpp similarity index 97% rename from src/PIM/Dialect/Spatial/DCPGraph/Task.cpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.cpp index d8f0114..8e46388 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.cpp @@ -2,7 +2,7 @@ #include "Graph.hpp" #include "Task.hpp" -#include "Uniqueworklist.hpp" +#include "UniqueWorklist.hpp" std::optional TaskDCP::addChild(TaskDCP* child, Weight_t weight) { std::optional oldEdge = std::nullopt; diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.hpp similarity index 97% rename from src/PIM/Dialect/Spatial/DCPGraph/Task.hpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.hpp index 2368525..90f5438 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.hpp @@ -5,7 +5,7 @@ #include #include -#include "../SpatialOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "Utils.hpp" std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight); diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/UniqueWorklist.hpp similarity index 100% rename from src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/UniqueWorklist.hpp diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp similarity index 96% rename from src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp index 276e769..29ab851 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp @@ -9,7 +9,7 @@ #include #include -#include "../SpatialOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Support/TypeUtilities.hpp" using CPU = int; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp similarity index 96% rename from src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp rename to src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 5909e5d..83dce8a 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -20,7 +20,7 @@ #include #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp" +#include "DCPGraph/DCPAnalysis.hpp" using namespace mlir; @@ -81,7 +81,7 @@ public: ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); } }; -struct MergeComputeNodePass : PassWrapper> { +struct MergeComputeNodesPass : PassWrapper> { private: DenseMap newComputeNodeResults; @@ -89,11 +89,11 @@ private: DenseMap cputToNewComputeMap; public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass) + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass) - StringRef getArgument() const override { return "pim-merge-node-pass"; } + StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; } StringRef getDescription() const override { - return "Merge Spatial-Weighted-Compute-Node in order to reduce the total " + return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total " "execution time"; } @@ -346,6 +346,6 @@ private: } // namespace -std::unique_ptr createMergeComputeNodePass() { return std::make_unique(); } +std::unique_ptr createMergeComputeNodesPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Pass/CMakeLists.txt b/src/PIM/Pass/CMakeLists.txt index 4c4507b..e5291c0 100644 --- a/src/PIM/Pass/CMakeLists.txt +++ b/src/PIM/Pass/CMakeLists.txt @@ -1,13 +1,13 @@ add_pim_library(OMPimPasses CountInstructionPass.cpp MessagePass.cpp - Pim/ConstantFolding/Common.cpp - Pim/ConstantFolding/Patterns/Constant.cpp - Pim/ConstantFolding/ConstantFoldingPass.cpp - Pim/ConstantFolding/Patterns/Subview.cpp - Pim/MaterializeConstantsPass.cpp - Pim/VerificationPass.cpp - Pim/EmitPimJsonPass.cpp + PimCodegen/HostConstantFolding/Common.cpp + PimCodegen/HostConstantFolding/Patterns/Constant.cpp + PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp + PimCodegen/HostConstantFolding/Patterns/Subview.cpp + PimCodegen/MaterializeHostConstantsPass.cpp + PimCodegen/VerificationPass.cpp + PimCodegen/EmitPimJsonPass.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Pass/PIMPasses.h b/src/PIM/Pass/PIMPasses.h index 4417f50..4734d52 100644 --- a/src/PIM/Pass/PIMPasses.h +++ b/src/PIM/Pass/PIMPasses.h @@ -15,11 +15,11 @@ std::unique_ptr createSpatialToPimPass(); std::unique_ptr createPimBufferizationPass(); -std::unique_ptr createMergeComputeNodePass(); +std::unique_ptr createMergeComputeNodesPass(); -std::unique_ptr createPimConstantFoldingPass(); +std::unique_ptr createPimHostConstantFoldingPass(); -std::unique_ptr createPimMaterializeConstantsPass(); +std::unique_ptr createPimMaterializeHostConstantsPass(); std::unique_ptr createPimVerificationPass(); diff --git a/src/PIM/Pass/Pim/EmitPimJsonPass.cpp b/src/PIM/Pass/PimCodegen/EmitPimJsonPass.cpp similarity index 100% rename from src/PIM/Pass/Pim/EmitPimJsonPass.cpp rename to src/PIM/Pass/PimCodegen/EmitPimJsonPass.cpp diff --git a/src/PIM/Pass/Pim/ConstantFolding/Common.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp similarity index 100% rename from src/PIM/Pass/Pim/ConstantFolding/Common.cpp rename to src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp diff --git a/src/PIM/Pass/Pim/ConstantFolding/Common.hpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp similarity index 100% rename from src/PIM/Pass/Pim/ConstantFolding/Common.hpp rename to src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp diff --git a/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp similarity index 78% rename from src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp rename to src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp index 76bd097..469c09c 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp @@ -11,10 +11,10 @@ using namespace mlir; namespace onnx_mlir { namespace { -struct ConstantFoldingPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass) +struct HostConstantFoldingPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass) - StringRef getArgument() const override { return "pim-constant-folding-pass"; } + StringRef getArgument() const override { return "pim-host-constant-folding-pass"; } StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; } LogicalResult initialize(MLIRContext* context) override { @@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper createPimConstantFoldingPass() { return std::make_unique(); } +std::unique_ptr createPimHostConstantFoldingPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns.hpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns.hpp similarity index 100% rename from src/PIM/Pass/Pim/ConstantFolding/Patterns.hpp rename to src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns.hpp diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp similarity index 100% rename from src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp rename to src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp similarity index 100% rename from src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp rename to src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp diff --git a/src/PIM/Pass/Pim/MaterializeConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp similarity index 93% rename from src/PIM/Pass/Pim/MaterializeConstantsPass.cpp rename to src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index fdb9904..dbf8fd1 100644 --- a/src/PIM/Pass/Pim/MaterializeConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) { return type.getNumElements() * type.getElementTypeBitWidth() / 8; } -struct MaterializeConstantsPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass) +struct MaterializeHostConstantsPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass) - StringRef getArgument() const override { return "materialize-pim-constants"; } + StringRef getArgument() const override { return "materialize-pim-host-constants"; } StringRef getDescription() const override { return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops"; } @@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper createPimMaterializeConstantsPass() { return std::make_unique(); } +std::unique_ptr createPimMaterializeHostConstantsPass() { + return std::make_unique(); +} } // namespace onnx_mlir diff --git a/src/PIM/Pass/Pim/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp similarity index 100% rename from src/PIM/Pass/Pim/VerificationPass.cpp rename to src/PIM/Pass/PimCodegen/VerificationPass.cpp diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 6b9ce54..4b204cf 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -75,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const { registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToPimPass); registerPass(createPimBufferizationPass); - registerPass(createMergeComputeNodePass); - registerPass(createPimConstantFoldingPass); - registerPass(createPimMaterializeConstantsPass); + registerPass(createMergeComputeNodesPass); + registerPass(createPimHostConstantFoldingPass); + registerPass(createPimMaterializeHostConstantsPass); registerPass(createPimVerificationPass); registerPass(createEmitPimJsonPass); } diff --git a/src/PIM/TODO.md b/src/PIM/TODO.md index b194d33..267cd4e 100644 --- a/src/PIM/TODO.md +++ b/src/PIM/TODO.md @@ -1,5 +1,3 @@ - -Rimuovere la logica di bufferizazione da spatial, Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode) AnalisiDCP