Remove the spatial bufferization logic (it no longer made sense) and move the channel lowering into SpatialToPim
This commit is contained in:
@@ -3,7 +3,6 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
|
||||
add_pim_library(SpatialOps
|
||||
SpatialOps.cpp
|
||||
Transforms/SpatialBufferizableOpInterface.cpp
|
||||
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
|
||||
DCPGraph/Graph.cpp
|
||||
DCPGraph/Task.cpp
|
||||
|
||||
@@ -1,532 +0,0 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
// Allocate a fresh, dense memref buffer whose shape and element type mirror
// the given (shaped) result type.
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
  const auto shaped = cast<ShapedType>(resultType);
  const auto allocType = MemRefType::get(shaped.getShape(), shaped.getElementType());
  return memref::AllocOp::create(rewriter, loc, allocType);
}
|
||||
|
||||
// Ensure `memrefValue` is backed by contiguous storage. If its address can
// already be resolved contiguously it is returned unchanged; otherwise a new
// buffer is allocated and a pim.memcopy of the whole payload is emitted, and
// the copy's output is returned.
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
  // Fast path: already contiguous.
  if (succeeded(resolveContiguousAddress(memrefValue)))
    return memrefValue;

  const auto shapedType = cast<ShapedType>(memrefValue.getType());
  auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
  // Total payload size; element bit width is assumed to be byte-divisible.
  const auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;

  auto copyOp = pim::PimMemCopyOp::create(rewriter,
      loc,
      contiguousBuffer.getType(),
      contiguousBuffer,
      memrefValue,
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(sizeInBytes));
  return copyOp.getOutput();
}
|
||||
|
||||
// Name of the attribute used to cache the peer core's id on the surviving end
// of a channel: the first Send/Receive of a pair that gets bufferized is
// erased, so the second one can no longer find it through the channel value.
// `constexpr llvm::StringLiteral` (LLVM idiom) is a guaranteed compile-time
// constant and converts implicitly to StringRef at every use site.
constexpr llvm::StringLiteral PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
|
||||
// Resolve the operation at the opposite end of a spatial channel.
//
// `op` must take a value produced by ChannelNewOp as its first operand, and
// that channel value must have exactly two users (one send, one receive).
// Returns the user that is not `op`, after checking it has the kind expected
// for the given direction (`opIsReceive`). Emits an error and fails otherwise.
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channelNewOp) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }

  // Materialize the user list so it can be counted and indexed directly.
  auto userRange = channelNewOp->getUsers();
  SmallVector<Operation*, 2> users(userRange.begin(), userRange.end());

  if (users.size() < 2) {
    op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
    // Debug dumps to help locate the malformed IR.
    channelNewOp->dump();
    op->dump();
    channelNewOp->getParentOp()->dump();
    return failure();
  }
  if (users.size() > 2) {
    op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
    return failure();
  }

  // Pick whichever of the two users is not `op` itself.
  Operation* otherUser = nullptr;
  if (users[0] == op)
    otherUser = users[1];
  else if (users[1] == op)
    otherUser = users[0];
  else {
    op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
    return failure();
  }

  // A receive must be paired with a send, and vice versa.
  if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
    return failure();
  }
  if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
    return failure();
  }

  return otherUser;
}
|
||||
|
||||
// Determine the core id of the op at the other end of the channel used by
// `op`. The `rewriter` parameter is currently unused but kept for signature
// stability with the bufferize() call sites.
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  // The channel and the peer Receive/Send op are needed below, but the first
  // of the pair to be bufferized gets erased. The peer's core id is therefore
  // "precomputed" and cached as an attribute the first time around.
  if (auto cached = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME))
    return cast<IntegerAttr>(cached).getInt();

  auto peerOrFailure = getOtherEndOfChannel(op, opIsReceive);
  if (failed(peerOrFailure))
    return failure();
  Operation* peer = *peerOrFailure;

  // Record this op's core id on the peer so it survives our later erasure.
  auto myCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
  peer->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, myCoreIdAttr);

  return cast<pim::PimCoreOp>(peer->getParentOp()).getCoreId();
}
|
||||
|
||||
// Bufferization model for SpatWeightedCompute: all operands are read-only
// inputs, and bufferization simply recurses into the op's region.
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {

  // Every tensor operand is read into the compute node's local memory.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }

  // Operands are never written back.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // No tensor aliases any other tensor of the compute op.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  // Bufferize the signature of the op's entry block; the ops inside the
  // region are handled by their own models.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    Block& entryBlock = op->getRegion(0).front();
    return bufferizeBlockSignature(&entryBlock, rewriter, options, state);
  }
};
|
||||
|
||||
/*
|
||||
* This can be used for operation that have a single argument, which is a
|
||||
* variadic of tensors, and a single output with the same same shape
|
||||
* Example: VAdd, VSub, VExp
|
||||
*/
|
||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
||||
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
||||
|
||||
// Input tensors to the OP are always read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
// Input tensors to the OP are _never_ written
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// In general, no tensor is aliased with any other tensor in the OP
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Cast tensor values into memref values
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
// Turn Tensor Operands into Memref Operands
|
||||
SmallVector<Value> memrefOperands;
|
||||
memrefOperands.reserve(op->getNumOperands());
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto memref = getBuffer(rewriter, operand, options, state);
|
||||
if (failed(memref))
|
||||
return failure();
|
||||
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
||||
}
|
||||
|
||||
// TODO: Support addiction with more than 2 operands
|
||||
if (memrefOperands.size() > 2) {
|
||||
op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
|
||||
"with 1 or 2 operands, for now.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Alloc an output memref
|
||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
memrefOperands.push_back(outputTensor);
|
||||
|
||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
||||
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
||||
|
||||
// Input tensors to the OP are always read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
// Input tensors to the OP are _never_ written
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// In general, no tensor is aliased with any other tensor in the OP
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Cast tensor value into memref value
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
|
||||
if (failed(memrefOperandOpt))
|
||||
return failure();
|
||||
auto memrefOperand = *memrefOperandOpt;
|
||||
|
||||
// Alloc an output memref
|
||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
Value newValue = ToTy::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
cast<OpTy>(op).getWeightIndexAttr(),
|
||||
memrefOperand,
|
||||
outputTensor)
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {
|
||||
|
||||
// Input value is the channel (not read/written, its more of an attribute)
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel receive to pim.recv
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
||||
|
||||
auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
|
||||
if (failed(srcCoreId))
|
||||
return failure();
|
||||
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Bufferization model for SpatChannelSendOp.
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {

  // Operand 0 is the channel handle (an identifier, neither read nor
  // written); operand 1 is the tensor payload, which is read.
  // BUGFIX: this previously compared against operand number 2, but
  // bufferize() below takes the payload from getOperand(1), so the read was
  // never reported for the actual data operand.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return opOperand.getOperandNumber() == 1;
  }

  // Nothing is written through any operand.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // No operand aliases a result (the op produces none).
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Lower the channel send to pim.send, targeting the core that hosts the
   * matching ChannelReceiveOp.
   */
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto srcTensor = op->getOperand(1);

    auto srcBufferOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcBufferOpt))
      return failure();
    auto srcMemRef = *srcBufferOpt;

    // Transfer size in bytes (element bit width assumed byte-divisible).
    auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;

    auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, /*opIsReceive=*/false, rewriter);
    if (failed(dstCoreId))
      return failure();

    replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
        op,
        srcMemRef,
        rewriter.getI32IntegerAttr(numElements * elementSize),
        rewriter.getI32IntegerAttr(dstCoreId.value()));

    return success();
  }
};
|
||||
|
||||
struct ChannelBroadcastReceiveOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {
|
||||
|
||||
// Input value is the channel (not read/written, its more of an attribute)
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
||||
|
||||
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
|
||||
if (precomputedOtherCoreId) {
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
cast<IntegerAttr>(precomputedOtherCoreId))
|
||||
.getOutput();
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
|
||||
if (!sendOp)
|
||||
continue;
|
||||
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
|
||||
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
|
||||
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
|
||||
}
|
||||
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
|
||||
return failure();
|
||||
}();
|
||||
if (failed(srcCoreId))
|
||||
return failure();
|
||||
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelBroadcastSendOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
|
||||
|
||||
// First input is channel (not read/writter) second input is Tensor to send,
|
||||
// which is read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return opOperand.getOperandNumber() == 2;
|
||||
}
|
||||
|
||||
// See above (both non-written)
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the broadcast send into one pim.send per broadcast receiver.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto srcTensor = op->getOperand(1);
|
||||
|
||||
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
|
||||
if (failed(srcTensorOpt))
|
||||
return failure();
|
||||
auto srcMemRef = *srcTensorOpt;
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
bool foundReceiver = false;
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
foundReceiver = true;
|
||||
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
|
||||
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(sizeInBytes),
|
||||
rewriter.getI32IntegerAttr(dstCoreId));
|
||||
}
|
||||
|
||||
if (!foundReceiver) {
|
||||
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Concrete bufferization models: each instantiates one of the shared
// templates above, mapping a spatial tensor op to its pim memref op.

// spat.vadd -> pim.vvadd (element-wise add).
struct VAddOpInterfaceFromTemplate
    : VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};

// spat.weighted_vmm -> pim.vmm (weighted vector-matrix multiply).
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};

// spat.weighted_mvm -> pim.mvm (weighted matrix-vector multiply).
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};

// spat.vmax -> pim.vvmax (element-wise maximum).
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
||||
// Attach the bufferization external models above to the Spatial-dialect ops.
// Must be called on the DialectRegistry before running one-shot bufferization.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
    SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
    SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
    SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
    SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
    SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
    SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
    SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
    SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
    SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
  });
}
|
||||
|
||||
// ONNX element-wise activations reuse the variadic element-wise template and
// bufferize directly to their pim vector ops.

// onnx.Relu -> pim.vrelu.
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};

// onnx.Tanh -> pim.vtanh.
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};

// onnx.Sigmoid -> pim.vsigm.
struct ONNXSigmoidInterface
    : VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
||||
// Attach bufferization external models to the ONNX activation ops handled by
// this accelerator (Relu, Tanh, Sigmoid).
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
    ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
    ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
    ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
  });
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,15 +0,0 @@
|
||||
#pragma once

#include "mlir/IR/DialectRegistry.h"

#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

namespace onnx_mlir {
namespace spatial {

// Registers bufferization external models for the Spatial-dialect ops
// (compute, element-wise, and channel send/receive).
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);

// Registers bufferization external models for the ONNX activation ops
// handled by the PIM accelerator (Relu, Tanh, Sigmoid).
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);

} // namespace spatial
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user