fix missed failing tests for channels
moderate refactor
This commit is contained in:
@@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPim) {
|
||||
pm.addPass(createMergeComputeNodePass());
|
||||
pm.addPass(createMergeComputeNodesPass());
|
||||
pm.addPass(createSpatialToPimPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||
@@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim constants folded"));
|
||||
pm.addPass(createPimMaterializeConstantsPass());
|
||||
pm.addPass(createPimHostConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <cstddef>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -85,6 +86,25 @@ IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value chan
|
||||
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
/// Reports whether `channel` is defined by a spatial::SpatChannelNewOp that
/// carries the attribute named by kChannelSourceCoreIdAttrName.
///
/// Returns false when the value is not defined by such an op, so the check
/// can be used as a precondition before reading the attribute.
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
  if (auto definingChannelOp = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return definingChannelOp->hasAttr(kChannelSourceCoreIdAttrName);
  return false;
}
|
||||
|
||||
/// Reports whether `channel` is defined by a spatial::SpatChannelNewOp that
/// carries the attribute named by kChannelTargetCoreIdAttrName.
///
/// Returns false when the value is not defined by such an op, so the check
/// can be used as a precondition before reading the attribute.
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
  if (auto definingChannelOp = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return definingChannelOp->hasAttr(kChannelTargetCoreIdAttrName);
  return false;
}
|
||||
|
||||
/// Creates a pim::PimReceiveOp for `output` and returns that op's output value.
///
/// The op is built from three pieces, computed in this order:
///   1. a destination buffer obtained from
///      getBestOutputTensorFromOperandsOrAllocate for the op defining `output`;
///   2. the size attribute from getTensorSizeInBytesAttr(output);
///   3. the source core id attribute read from `channel`.
///
/// NOTE(review): `output.getDefiningOp()` is null for block arguments;
/// presumably callers only pass op results here -- confirm against the helper.
mlir::Value createPimReceiveFromSpatialChannel(
    PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
  mlir::Value destination = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, output);
  auto senderCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);

  auto receiveOp =
      pim::PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteSizeAttr, senderCoreIdAttr);
  return receiveOp.getOutput();
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
auto users = value.getUsers();
|
||||
|
||||
|
||||
@@ -34,6 +34,13 @@ mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||
|
||||
mlir::Value createPimReceiveFromSpatialChannel(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
|
||||
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed source core id">;
|
||||
|
||||
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed target core id">;
|
||||
|
||||
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def onnxToPimTranspose : Pat<
|
||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||
(PimTransposeOp $data, $perms,
|
||||
@@ -73,15 +84,14 @@ def spatChannelSendToPimSend : Pat<
|
||||
(SpatChannelSendOp $channel, $input),
|
||||
(PimSendOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel))
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
def spatChannelReceiveToPimReceive : Pat<
|
||||
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||
(PimReceiveOp
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes),
|
||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes),
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel))
|
||||
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
@@ -641,6 +641,9 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
|
||||
markOpToRemove(channelNewOp);
|
||||
|
||||
if (channelNewOp->use_empty())
|
||||
return;
|
||||
|
||||
spatial::SpatChannelSendOp sendOp;
|
||||
spatial::SpatChannelReceiveOp receiveOp;
|
||||
spatial::SpatChannelBroadcastSendOp broadcastSendOp;
|
||||
|
||||
@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization external model for PimReceiveOp.
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
  /// An operand is reported as a memory read unless it is one of the op's
  /// destination-style (DPS) init operands.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return !dpsOp.isDpsInit(&opOperand);
  }

  /// Replaces `op` with a PimReceiveOp built on the buffer obtained for the
  /// original output buffer operand. The size and source-core-id attributes
  /// are carried over unchanged. Fails if the buffer lookup fails.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto tensorReceiveOp = cast<PimReceiveOp>(op);

    auto bufferOrFailure = getBuffer(rewriter, tensorReceiveOp.getOutputBuffer(), options, state);
    if (failed(bufferOrFailure))
      return failure();

    replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
                                               op,
                                               bufferOrFailure->getType(),
                                               *bufferOrFailure,
                                               tensorReceiveOp.getSizeAttr(),
                                               tensorReceiveOp.getSourceCoreIdAttr());
    return success();
  }
};
|
||||
|
||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
|
||||
@@ -3,10 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
|
||||
add_pim_library(SpatialOps
|
||||
SpatialOps.cpp
|
||||
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
|
||||
DCPGraph/Graph.cpp
|
||||
DCPGraph/Task.cpp
|
||||
DCPGraph/DCPAnalysis.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
#include <iterator>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
||||
@@ -7,11 +7,11 @@
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "Task.hpp"
|
||||
#include "Uniqueworklist.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "Graph.hpp"
|
||||
#include "Task.hpp"
|
||||
#include "Uniqueworklist.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
|
||||
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
||||
std::optional<Edge_t> oldEdge = std::nullopt;
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||
@@ -9,7 +9,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
using CPU = int;
|
||||
@@ -20,7 +20,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "DCPGraph/DCPAnalysis.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -81,7 +81,7 @@ public:
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||
};
|
||||
|
||||
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
|
||||
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
private:
|
||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
@@ -89,11 +89,11 @@ private:
|
||||
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
||||
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-merge-node-pass"; }
|
||||
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
|
||||
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
||||
"execution time";
|
||||
}
|
||||
|
||||
@@ -346,6 +346,6 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
|
||||
/// Factory returning a new MergeComputeNodesPass instance as a generic Pass.
std::unique_ptr<Pass> createMergeComputeNodesPass() { return std::make_unique<MergeComputeNodesPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,13 +1,13 @@
|
||||
add_pim_library(OMPimPasses
|
||||
CountInstructionPass.cpp
|
||||
MessagePass.cpp
|
||||
Pim/ConstantFolding/Common.cpp
|
||||
Pim/ConstantFolding/Patterns/Constant.cpp
|
||||
Pim/ConstantFolding/ConstantFoldingPass.cpp
|
||||
Pim/ConstantFolding/Patterns/Subview.cpp
|
||||
Pim/MaterializeConstantsPass.cpp
|
||||
Pim/VerificationPass.cpp
|
||||
Pim/EmitPimJsonPass.cpp
|
||||
PimCodegen/HostConstantFolding/Common.cpp
|
||||
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
|
||||
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
|
||||
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
|
||||
PimCodegen/MaterializeHostConstantsPass.cpp
|
||||
PimCodegen/VerificationPass.cpp
|
||||
PimCodegen/EmitPimJsonPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
|
||||
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
|
||||
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
||||
/// Factory returning a new HostConstantFoldingPass instance as a generic Pass.
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
||||
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
||||
/// Factory returning a new MaterializeHostConstantsPass instance as a generic Pass.
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||
return std::make_unique<MaterializeHostConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -75,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
||||
registerPass(createSpatialToGraphvizPass);
|
||||
registerPass(createSpatialToPimPass);
|
||||
registerPass(createPimBufferizationPass);
|
||||
registerPass(createMergeComputeNodePass);
|
||||
registerPass(createPimConstantFoldingPass);
|
||||
registerPass(createPimMaterializeConstantsPass);
|
||||
registerPass(createMergeComputeNodesPass);
|
||||
registerPass(createPimHostConstantFoldingPass);
|
||||
registerPass(createPimMaterializeHostConstantsPass);
|
||||
registerPass(createPimVerificationPass);
|
||||
registerPass(createEmitPimJsonPass);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
Rimuovere la logica di bufferizzazione da Spatial,
|
||||
Rimuovere la gestione delle send e receive da SpatialToPim (nuovo mergeNode)
|
||||
|
||||
AnalisiDCP
|
||||
|
||||
Reference in New Issue
Block a user