fix missed failing tests for channels
moderate refactor
This commit is contained in:
@@ -19,7 +19,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <cstddef>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -85,6 +86,25 @@ IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value chan
|
||||
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
|
||||
}
|
||||
|
||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
mlir::Value createPimReceiveFromSpatialChannel(
|
||||
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
||||
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
|
||||
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
|
||||
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
auto users = value.getUsers();
|
||||
|
||||
|
||||
@@ -34,6 +34,13 @@ mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||
|
||||
mlir::Value createPimReceiveFromSpatialChannel(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
|
||||
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed source core id">;
|
||||
|
||||
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed target core id">;
|
||||
|
||||
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def onnxToPimTranspose : Pat<
|
||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||
(PimTransposeOp $data, $perms,
|
||||
@@ -73,15 +84,14 @@ def spatChannelSendToPimSend : Pat<
|
||||
(SpatChannelSendOp $channel, $input),
|
||||
(PimSendOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel))
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
def spatChannelReceiveToPimReceive : Pat<
|
||||
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||
(PimReceiveOp
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes),
|
||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes),
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel))
|
||||
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
@@ -641,6 +641,9 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
|
||||
markOpToRemove(channelNewOp);
|
||||
|
||||
if (channelNewOp->use_empty())
|
||||
return;
|
||||
|
||||
spatial::SpatChannelSendOp sendOp;
|
||||
spatial::SpatChannelReceiveOp receiveOp;
|
||||
spatial::SpatChannelBroadcastSendOp broadcastSendOp;
|
||||
|
||||
Reference in New Issue
Block a user