fix missed failing tests for channels

moderate refactor
This commit is contained in:
NiccoloN
2026-04-14 12:26:41 +02:00
parent 30ee9640d4
commit eade488d13
30 changed files with 115 additions and 50 deletions

View File

@@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPim) { if (pimEmissionTarget >= EmitPim) {
pm.addPass(createMergeComputeNodePass()); pm.addPass(createMergeComputeNodesPass());
pm.addPass(createSpatialToPimPass()); pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim")); pm.addPass(createMessagePass("Spatial lowered to Pim"));
@@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass()); pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded")); pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeConstantsPass()); pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimVerificationPass()); pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified")); pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass()); pm.addPass(createEmitPimJsonPass());

View File

@@ -19,7 +19,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"

View File

@@ -7,6 +7,7 @@
#include <cstddef> #include <cstddef>
#include "Common.hpp" #include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
@@ -85,6 +86,25 @@ IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value chan
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName); return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
} }
/// Reports whether `channel` is produced by a spatial::SpatChannelNewOp that
/// carries the attribute named by kChannelSourceCoreIdAttrName.
///
/// Returns false when `channel` has no defining SpatChannelNewOp (for example
/// when it is defined by a different operation) or when the attribute is absent.
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
  if (auto definingChannelOp = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return definingChannelOp->hasAttr(kChannelSourceCoreIdAttrName);
  return false;
}
/// Reports whether `channel` is produced by a spatial::SpatChannelNewOp that
/// carries the attribute named by kChannelTargetCoreIdAttrName.
///
/// Returns false when `channel` has no defining SpatChannelNewOp (for example
/// when it is defined by a different operation) or when the attribute is absent.
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
  if (auto definingChannelOp = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return definingChannelOp->hasAttr(kChannelTargetCoreIdAttrName);
  return false;
}
/// Creates a pim::PimReceiveOp standing in for a value received over a spatial
/// channel and returns that op's output value.
///
/// \param rewriter Rewriter used to build the new operation and its attributes.
/// \param loc      Location attached to the created PimReceiveOp.
/// \param output   Value whose defining op is used to pick (or allocate) the
///                 destination buffer and whose tensor size gives the size attribute.
/// \param channel  Spatial channel value; its source core id becomes the
///                 receive op's source core id attribute.
/// \return The output of the newly created PimReceiveOp.
mlir::Value createPimReceiveFromSpatialChannel(
    PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
  // NOTE(review): output.getDefiningOp() is forwarded unchecked, so `output` is
  // presumably always an op result here — confirm against callers.
  mlir::Value destination = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, output);
  auto senderCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);

  // The result type mirrors the destination buffer's type.
  auto receiveOp =
      pim::PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteSizeAttr, senderCoreIdAttr);
  return receiveOp.getOutput();
}
Operation* getEarliestUserWithinBlock(mlir::Value value) { Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers(); auto users = value.getUsers();

View File

@@ -34,6 +34,13 @@ mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel); mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
mlir::Value createPimReceiveFromSpatialChannel(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
template <class T> template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) { size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end()); return std::distance(range.begin(), range.end());

View File

@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
// Pattern constraint: holds when onnx_mlir::hasSpatialChannelSourceCoreIdAttr
// returns true for the bound value ($0).
def HasSpatialChannelSourceCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
"spatial channel has precomputed source core id">;
// Pattern constraint: holds when onnx_mlir::hasSpatialChannelTargetCoreIdAttr
// returns true for the bound value ($0).
def HasSpatialChannelTargetCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
"spatial channel has precomputed target core id">;
// Native helper for rewrite results: expands to a call to
// onnx_mlir::createPimReceiveFromSpatialChannel, passing the pattern's builder
// and location followed by the two bound arguments ($0, $1).
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
def onnxToPimTranspose : Pat< def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms), (ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms, (PimTransposeOp $data, $perms,
@@ -73,15 +84,14 @@ def spatChannelSendToPimSend : Pat<
(SpatChannelSendOp $channel, $input), (SpatChannelSendOp $channel, $input),
(PimSendOp $input, (PimSendOp $input,
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input), (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)) (NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
[(HasSpatialChannelTargetCoreIdAttr $channel)]
>; >;
def spatChannelReceiveToPimReceive : Pat< def spatChannelReceiveToPimReceive : Pat<
(SpatChannelReceiveOp:$srcOpRes $channel), (SpatChannelReceiveOp:$srcOpRes $channel),
(PimReceiveOp (createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes), [(HasSpatialChannelSourceCoreIdAttr $channel)]
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes),
(NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel))
>; >;
#endif // SPATIAL_TO_PIM #endif // SPATIAL_TO_PIM

View File

@@ -641,6 +641,9 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) { funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
markOpToRemove(channelNewOp); markOpToRemove(channelNewOp);
if (channelNewOp->use_empty())
return;
spatial::SpatChannelSendOp sendOp; spatial::SpatChannelSendOp sendOp;
spatial::SpatChannelReceiveOp receiveOp; spatial::SpatChannelReceiveOp receiveOp;
spatial::SpatChannelBroadcastSendOp broadcastSendOp; spatial::SpatChannelBroadcastSendOp broadcastSendOp;

View File

@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
} }
}; };
/// Bufferization external model for PimReceiveOp, built on
/// DstBufferizableOpInterfaceExternalModel.
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
  /// An operand counts as a memory read unless it is one of the op's
  /// destination-style init operands.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isInitOperand = dpsOp.isDpsInit(&opOperand);
    return !isInitOperand;
  }

  /// Replaces the receive op with a PimReceiveOp built on the bufferized output
  /// buffer, carrying over the size and source-core-id attributes unchanged.
  /// Fails if the buffer for the output operand cannot be obtained.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto receive = cast<PimReceiveOp>(op);

    auto bufferOrFailure = getBuffer(rewriter, receive.getOutputBuffer(), options, state);
    if (failed(bufferOrFailure))
      return failure();

    // The new op's result type is the type of the bufferized destination.
    auto destination = *bufferOrFailure;
    replaceOpWithNewBufferizedOp<PimReceiveOp>(
        rewriter, op, destination.getType(), destination, receive.getSizeAttr(), receive.getSourceCoreIdAttr());
    return success();
  }
};
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> { struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
void registerOpBufferizationInterfaces(DialectRegistry& registry) { void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx); PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx); PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);

View File

@@ -3,10 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps add_pim_library(SpatialOps
SpatialOps.cpp SpatialOps.cpp
Transforms/MergeComputeNode/MergeComputeNodePass.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
DCPGraph/Task.cpp Transforms/MergeComputeNodes/DCPGraph/Task.cpp
DCPGraph/DCPAnalysis.cpp Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -8,7 +8,6 @@
#include <iterator> #include <iterator>
#include "../SpatialOps.hpp"
#include "DCPAnalysis.hpp" #include "DCPAnalysis.hpp"
#include "Graph.hpp" #include "Graph.hpp"
#include "src/Support/TypeUtilities.hpp" #include "src/Support/TypeUtilities.hpp"

View File

@@ -6,7 +6,7 @@
#include <vector> #include <vector>
#include "../SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
struct DCPAnalysisResult { struct DCPAnalysisResult {
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute; std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;

View File

@@ -7,11 +7,11 @@
#include <fstream> #include <fstream>
#include <vector> #include <vector>
#include "../../../Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "DCPAnalysis.hpp" #include "DCPAnalysis.hpp"
#include "Graph.hpp" #include "Graph.hpp"
#include "Task.hpp" #include "Task.hpp"
#include "Uniqueworklist.hpp" #include "UniqueWorklist.hpp"
#include "Utils.hpp" #include "Utils.hpp"
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) { std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {

View File

@@ -2,7 +2,7 @@
#include "Graph.hpp" #include "Graph.hpp"
#include "Task.hpp" #include "Task.hpp"
#include "Uniqueworklist.hpp" #include "UniqueWorklist.hpp"
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) { std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
std::optional<Edge_t> oldEdge = std::nullopt; std::optional<Edge_t> oldEdge = std::nullopt;

View File

@@ -5,7 +5,7 @@
#include <optional> #include <optional>
#include <vector> #include <vector>
#include "../SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "Utils.hpp" #include "Utils.hpp"
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight); std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);

View File

@@ -9,7 +9,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Support/TypeUtilities.hpp" #include "src/Support/TypeUtilities.hpp"
using CPU = int; using CPU = int;

View File

@@ -20,7 +20,7 @@
#include <memory> #include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp" #include "DCPGraph/DCPAnalysis.hpp"
using namespace mlir; using namespace mlir;
@@ -81,7 +81,7 @@ public:
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); } ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
}; };
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> { struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
private: private:
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults; DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
@@ -89,11 +89,11 @@ private:
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap; DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
public: public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
StringRef getArgument() const override { return "pim-merge-node-pass"; } StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
StringRef getDescription() const override { StringRef getDescription() const override {
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total " return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
"execution time"; "execution time";
} }
@@ -346,6 +346,6 @@ private:
} // namespace } // namespace
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); } std::unique_ptr<Pass> createMergeComputeNodesPass() { return std::make_unique<MergeComputeNodesPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,13 +1,13 @@
add_pim_library(OMPimPasses add_pim_library(OMPimPasses
CountInstructionPass.cpp CountInstructionPass.cpp
MessagePass.cpp MessagePass.cpp
Pim/ConstantFolding/Common.cpp PimCodegen/HostConstantFolding/Common.cpp
Pim/ConstantFolding/Patterns/Constant.cpp PimCodegen/HostConstantFolding/Patterns/Constant.cpp
Pim/ConstantFolding/ConstantFoldingPass.cpp PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
Pim/ConstantFolding/Patterns/Subview.cpp PimCodegen/HostConstantFolding/Patterns/Subview.cpp
Pim/MaterializeConstantsPass.cpp PimCodegen/MaterializeHostConstantsPass.cpp
Pim/VerificationPass.cpp PimCodegen/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp PimCodegen/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createPimBufferizationPass(); std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createMergeComputeNodePass(); std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass(); std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass(); std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass(); std::unique_ptr<mlir::Pass> createPimVerificationPass();

View File

@@ -11,10 +11,10 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> { struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; } StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; } StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override { LogicalResult initialize(MLIRContext* context) override {
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
} // namespace } // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); } std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8; return type.getNumElements() * type.getElementTypeBitWidth() / 8;
} }
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> { struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; } StringRef getArgument() const override { return "materialize-pim-host-constants"; }
StringRef getDescription() const override { StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops"; return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
} }
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace } // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); } std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
return std::make_unique<MaterializeHostConstantsPass>();
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -75,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createPimBufferizationPass); registerPass(createPimBufferizationPass);
registerPass(createMergeComputeNodePass); registerPass(createMergeComputeNodesPass);
registerPass(createPimConstantFoldingPass); registerPass(createPimHostConstantFoldingPass);
registerPass(createPimMaterializeConstantsPass); registerPass(createPimMaterializeHostConstantsPass);
registerPass(createPimVerificationPass); registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass); registerPass(createEmitPimJsonPass);
} }

View File

@@ -1,5 +1,3 @@
Rimuovere la logica di bufferizzazione da spatial,
Rimuovere la gestione delle send e receive da SpatialToPim (nuovo mergeNode) Rimuovere la gestione delle send e receive da SpatialToPim (nuovo mergeNode)
AnalisiDCP AnalisiDCP