fix missed failing tests for channels

moderate refactor
This commit is contained in:
NiccoloN
2026-04-14 12:26:41 +02:00
parent 30ee9640d4
commit eade488d13
30 changed files with 115 additions and 50 deletions

View File

@@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createMergeComputeNodePass());
pm.addPass(createMergeComputeNodesPass());
pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim"));
@@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());

View File

@@ -19,7 +19,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"

View File

@@ -7,6 +7,7 @@
#include <cstddef>
#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
@@ -85,6 +86,25 @@ IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value chan
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
}
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
}
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
}
mlir::Value createPimReceiveFromSpatialChannel(
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
}
Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers();

View File

@@ -34,6 +34,13 @@ mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
mlir::Value createPimReceiveFromSpatialChannel(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end());

View File

@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def HasSpatialChannelSourceCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
"spatial channel has precomputed source core id">;
def HasSpatialChannelTargetCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
"spatial channel has precomputed target core id">;
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
@@ -73,15 +84,14 @@ def spatChannelSendToPimSend : Pat<
(SpatChannelSendOp $channel, $input),
(PimSendOp $input,
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel))
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
[(HasSpatialChannelTargetCoreIdAttr $channel)]
>;
def spatChannelReceiveToPimReceive : Pat<
(SpatChannelReceiveOp:$srcOpRes $channel),
(PimReceiveOp
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes),
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes),
(NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel))
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
[(HasSpatialChannelSourceCoreIdAttr $channel)]
>;
#endif // SPATIAL_TO_PIM

View File

@@ -641,6 +641,9 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
markOpToRemove(channelNewOp);
if (channelNewOp->use_empty())
return;
spatial::SpatChannelSendOp sendOp;
spatial::SpatChannelReceiveOp receiveOp;
spatial::SpatChannelBroadcastSendOp broadcastSendOp;

View File

@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
}
};
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveOp>(op);
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdAttr());
return success();
}
};
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);

View File

@@ -3,10 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps
SpatialOps.cpp
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
DCPGraph/Graph.cpp
DCPGraph/Task.cpp
DCPGraph/DCPAnalysis.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -8,7 +8,6 @@
#include <iterator>
#include "../SpatialOps.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "src/Support/TypeUtilities.hpp"

View File

@@ -6,7 +6,7 @@
#include <vector>
#include "../SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
struct DCPAnalysisResult {
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;

View File

@@ -7,11 +7,11 @@
#include <fstream>
#include <vector>
#include "../../../Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "Task.hpp"
#include "Uniqueworklist.hpp"
#include "UniqueWorklist.hpp"
#include "Utils.hpp"
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {

View File

@@ -2,7 +2,7 @@
#include "Graph.hpp"
#include "Task.hpp"
#include "Uniqueworklist.hpp"
#include "UniqueWorklist.hpp"
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
std::optional<Edge_t> oldEdge = std::nullopt;

View File

@@ -5,7 +5,7 @@
#include <optional>
#include <vector>
#include "../SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "Utils.hpp"
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);

View File

@@ -9,7 +9,7 @@
#include <utility>
#include <vector>
#include "../SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Support/TypeUtilities.hpp"
using CPU = int;

View File

@@ -20,7 +20,7 @@
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
#include "DCPGraph/DCPAnalysis.hpp"
using namespace mlir;
@@ -81,7 +81,7 @@ public:
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
};
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
private:
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
@@ -89,11 +89,11 @@ private:
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
StringRef getArgument() const override { return "pim-merge-node-pass"; }
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
StringRef getDescription() const override {
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
"execution time";
}
@@ -346,6 +346,6 @@ private:
} // namespace
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
std::unique_ptr<Pass> createMergeComputeNodesPass() { return std::make_unique<MergeComputeNodesPass>(); }
} // namespace onnx_mlir

View File

@@ -1,13 +1,13 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
Pim/ConstantFolding/Common.cpp
Pim/ConstantFolding/Patterns/Constant.cpp
Pim/ConstantFolding/ConstantFoldingPass.cpp
Pim/ConstantFolding/Patterns/Subview.cpp
Pim/MaterializeConstantsPass.cpp
Pim/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp
PimCodegen/HostConstantFolding/Common.cpp
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
PimCodegen/MaterializeHostConstantsPass.cpp
PimCodegen/VerificationPass.cpp
PimCodegen/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();

View File

@@ -11,10 +11,10 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
return std::make_unique<MaterializeHostConstantsPass>();
}
} // namespace onnx_mlir

View File

@@ -75,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass);
registerPass(createPimBufferizationPass);
registerPass(createMergeComputeNodePass);
registerPass(createPimConstantFoldingPass);
registerPass(createPimMaterializeConstantsPass);
registerPass(createMergeComputeNodesPass);
registerPass(createPimHostConstantFoldingPass);
registerPass(createPimMaterializeHostConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass);
}

View File

@@ -1,5 +1,3 @@
Rimuovere la logica di bufferizazione da spatial,
Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode)
AnalisiDCP