remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
@@ -12,6 +13,16 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
// Fetch a precomputed integer attribute from `op` and repackage its value as
// an i32 IntegerAttr built with `builder`. The attribute must already be
// present (it is asserted, not checked at runtime).
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
  const auto found = op->getAttrOfType<IntegerAttr>(attrName);
  assert(found && "required precomputed channel attr is missing");
  return IntegerAttr::get(builder.getI32Type(), found.getInt());
}
|
||||
|
||||
} // namespace
|
||||
|
||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||
/*
|
||||
EXAMPLE RUN:
|
||||
@@ -54,6 +65,26 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
// Total storage size of `shapedType` in bytes: element count times element
// bit-width, divided by 8 (assumes byte-aligned element types).
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
  const auto totalBits = shapedType.getNumElements() * shapedType.getElementTypeBitWidth();
  return totalBits / 8;
}
|
||||
|
||||
// Wrap the byte size of `value`'s (shaped) type in an i32 attribute, as
// expected by the PIM send/receive ops.
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
  const auto shaped = cast<ShapedType>(value.getType());
  const auto sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(shaped));
  return builder.getI32IntegerAttr(sizeInBytes);
}
|
||||
|
||||
// Core id of the sending endpoint of `channel`, read from the precomputed
// attribute stored on the defining spat.channel_new op.
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
  auto definingOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
  assert(definingOp && "spatial channel value must come from spat.channel_new");
  return getRequiredI32Attr(builder, definingOp, kChannelSourceCoreIdAttrName);
}
|
||||
|
||||
// Core id of the receiving endpoint of `channel`, read from the precomputed
// attribute stored on the defining spat.channel_new op.
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
  auto definingOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
  assert(definingOp && "spatial channel value must come from spat.channel_new");
  return getRequiredI32Attr(builder, definingOp, kChannelTargetCoreIdAttrName);
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
auto users = value.getUsers();
|
||||
|
||||
|
||||
@@ -2,11 +2,16 @@
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
|
||||
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
@@ -21,6 +26,14 @@ namespace onnx_mlir {
|
||||
*/
|
||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
|
||||
@@ -69,4 +69,19 @@ def spatToPimVSoftmax : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
// Lower spat.channel_send to pim.send: forwards the payload tensor, its size
// in bytes, and the receiver's core id (precomputed as an attribute on the
// channel's defining spat.channel_new — see annotateChannelCoreIds).
def spatChannelSendToPimSend : Pat<
  (SpatChannelSendOp $channel, $input),
  (PimSendOp $input,
    (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
    (NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel))
>;
|
||||
|
||||
// Lower spat.channel_receive to pim.receive: obtains (or allocates) an output
// buffer for the result, plus the result's size in bytes and the sender's
// core id (precomputed on the channel's defining spat.channel_new).
def spatChannelReceiveToPimReceive : Pat<
  (SpatChannelReceiveOp:$srcOpRes $channel),
  (PimReceiveOp
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes),
    (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes),
    (NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel))
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
@@ -10,8 +10,10 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
@@ -68,6 +70,8 @@ private:
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void markOpToRemove(Operation* op);
|
||||
void annotateChannelCoreIds(func::FuncOp funcOp);
|
||||
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||
|
||||
@@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
runOnComputeOp(computeOp, rewriter);
|
||||
}
|
||||
|
||||
annotateChannelCoreIds(funcOp);
|
||||
lowerBroadcastChannelOps(funcOp, rewriter);
|
||||
|
||||
RewritePatternSet channelPatterns(ctx);
|
||||
populateWithGenerated(channelPatterns);
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnOpOperands(returnOp, rewriter);
|
||||
|
||||
@@ -623,6 +637,91 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
/// Walk every spat.channel_new in `funcOp`, record the core ids of its
/// endpoints as attributes on the channel op itself, and queue the channel op
/// for later removal. The pattern-based send/receive lowering that runs after
/// this reads those precomputed attributes, because by the time a pattern
/// fires the opposite endpoint may already have been rewritten away.
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
  funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
    // The channel op itself has no PIM equivalent; it is erased at the end.
    markOpToRemove(channelNewOp);

    spatial::SpatChannelSendOp sendOp;
    spatial::SpatChannelReceiveOp receiveOp;
    spatial::SpatChannelBroadcastSendOp broadcastSendOp;

    // Classify the users of the channel value.
    for (Operation* user : channelNewOp->getUsers()) {
      if (auto op = dyn_cast<spatial::SpatChannelSendOp>(user)) {
        sendOp = op;
        continue;
      }
      if (auto op = dyn_cast<spatial::SpatChannelReceiveOp>(user)) {
        receiveOp = op;
        continue;
      }
      if (auto op = dyn_cast<spatial::SpatChannelBroadcastSendOp>(user)) {
        broadcastSendOp = op;
        continue;
      }
      // Broadcast receivers look up the source core id via the channel's
      // attribute themselves; nothing to collect here. (isa<> instead of an
      // unused dyn_cast<> binding, per LLVM coding standards.)
      if (isa<spatial::SpatChannelBroadcastReceiveOp>(user))
        continue;
      llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
    }

    // Broadcast channels have one sender and any number of receivers, so only
    // the source core id is meaningful on the channel.
    if (broadcastSendOp) {
      auto sourceCoreIdAttr = cast<pim::PimCoreOp>(broadcastSendOp->getParentOp()).getCoreIdAttr();
      channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
      return;
    }

    // A point-to-point channel must connect exactly one send and one receive.
    if (!sendOp || !receiveOp)
      llvm_unreachable("spat.channel_new must connect exactly one send and one receive");

    // Each endpoint lives inside a pim.core region; that core's id identifies
    // the endpoint for the generated pim.send / pim.receive.
    auto sourceCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
    auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
    channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
    channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
  });
}
|
||||
|
||||
// Lower broadcast channel ops directly (they fan out and so do not fit the
// 1:1 DRR patterns): each broadcast send becomes one pim.send per receiver,
// and each broadcast receive becomes a pim.receive from the sender's core.
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first, then rewrite: mutating the IR during walk() is unsafe.
  SmallVector<spatial::SpatChannelBroadcastSendOp> broadcastSendOps;
  funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { broadcastSendOps.push_back(op); });

  for (auto sendOp : broadcastSendOps) {
    auto channelNewOp = cast<spatial::SpatChannelNewOp>(sendOp.getChannel().getDefiningOp());
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());

    // Emit one pim.send per broadcast receiver, in place of the send op.
    rewriter.setInsertionPoint(sendOp);
    bool foundReceiver = false;
    for (Operation* user : channelNewOp->getUsers()) {
      auto receiveOp = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user);
      if (!receiveOp)
        continue;

      foundReceiver = true;
      // The receiver's enclosing pim.core determines the send target.
      auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
      PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
    }

    if (!foundReceiver)
      llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");

    rewriter.eraseOp(sendOp);
  }

  SmallVector<spatial::SpatChannelBroadcastReceiveOp> broadcastReceiveOps;
  funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { broadcastReceiveOps.push_back(op); });

  for (auto receiveOp : broadcastReceiveOps) {
    rewriter.setInsertionPoint(receiveOp);
    // Allocate a destination tensor matching the receive's result type.
    auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
    Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
    // Source core id was precomputed on the channel by annotateChannelCoreIds.
    auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel());
    Value receivedValue =
        PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
            .getOutput();
    rewriter.replaceOp(receiveOp, receivedValue);
  }
}
|
||||
|
||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
for (auto it : llvm::enumerate(originalOperands)) {
|
||||
|
||||
@@ -3,7 +3,6 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
|
||||
add_pim_library(SpatialOps
|
||||
SpatialOps.cpp
|
||||
Transforms/SpatialBufferizableOpInterface.cpp
|
||||
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
|
||||
DCPGraph/Graph.cpp
|
||||
DCPGraph/Task.cpp
|
||||
|
||||
@@ -1,532 +0,0 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
// Allocate an uninitialized memref whose shape and element type mirror
// `resultType` (which must be a ShapedType).
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
  const auto shaped = cast<ShapedType>(resultType);
  const auto allocType = MemRefType::get(shaped.getShape(), shaped.getElementType());
  return memref::AllocOp::create(rewriter, loc, allocType);
}
||||
|
||||
// Ensure `memrefValue` is backed by contiguous storage: if its address can
// already be resolved contiguously it is returned as-is; otherwise the data
// is copied into a freshly allocated buffer via pim.memcopy.
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
  if (succeeded(resolveContiguousAddress(memrefValue)))
    return memrefValue;

  auto shapedType = cast<ShapedType>(memrefValue.getType());
  auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
  // Byte size = elements * element bit-width / 8 (assumes byte-aligned types).
  auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;

  // Copy src -> dst with zero offsets on both sides; the result aliases the
  // new contiguous buffer.
  return pim::PimMemCopyOp::create(rewriter,
      loc,
      contiguousBuffer.getType(),
      contiguousBuffer,
      memrefValue,
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(sizeInBytes))
      .getOutput();
}
||||
|
||||
// Attribute caching the core id of a channel's opposite endpoint, so the id
// survives after that endpoint op has been bufferized (and erased).
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
|
||||
// Resolve the op at the opposite end of a point-to-point channel. `op` must
// use a spat.channel_new value as its first operand; that channel must have
// exactly two users (this op and its peer), and the peer must be the
// complementary endpoint kind (send for a receive, receive for a send).
// Emits an error and returns failure() on any violation.
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channelNewOp) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }

  // Walk the use list manually to check for exactly two users.
  auto channelUsers = channelNewOp->getUsers();
  auto usersIterator = channelUsers.begin();
  auto firstUser = *usersIterator;
  ++usersIterator;
  if (usersIterator == channelUsers.end()) {
    op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
    // Dump the offending IR to ease debugging this structural error.
    channelNewOp->dump();
    op->dump();
    channelNewOp->getParentOp()->dump();
    return failure();
  }

  auto secondUser = *usersIterator;
  ++usersIterator;
  if (usersIterator != channelUsers.end()) {
    op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
    return failure();
  }

  // Pick whichever of the two users is not `op` itself.
  Operation* otherUser = nullptr;
  if (firstUser == op)
    otherUser = secondUser;
  else if (secondUser == op)
    otherUser = firstUser;
  else {
    op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
    return failure();
  }

  // The peer must be the complementary endpoint kind.
  if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
    return failure();
  }

  if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
    return failure();
  }

  return otherUser;
}
|
||||
|
||||
// Core id of the op at the other end of `op`'s channel. `rewriter` is unused
// but kept for signature stability with the bufferization call sites.
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {

  // This function requires the existence of ChannelNewOp and the other
  // Receive/Send operation. However, during bufferization, the first of the
  // Receive/Send operation that is processed gets removed. As such, we need to
  // "precompute" the coreId needed for the other op, and save it as attribute
  auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
  if (precomputedOtherCoreId)
    return cast<IntegerAttr>(precomputedOtherCoreId).getInt();

  auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
  if (failed(notOpUserOpt))
    return failure();
  Operation* notOpUser = *notOpUserOpt;

  // Save the coreId for this op into the other op as attribute
  // (so the peer can still resolve it after this op is erased).
  auto opCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
  notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr);

  // The peer's enclosing pim.core provides the id we need.
  return cast<pim::PimCoreOp>(notOpUser->getParentOp()).getCoreId();
}
|
||||
|
||||
// Bufferization model for spat.weighted_compute: the op itself only wraps a
// region, so bufferizing it amounts to bufferizing its block signature.
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {

  // Input tensor to the compute OP are always read into its local memory
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }

  // Input tensor to the compute OP are _never_ written into its local memory
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // In general, no tensor is aliased with any other tensor in the compute OP
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  // Bufferize only the region's entry block signature; ops inside the block
  // are bufferized by their own interfaces.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    // Bufferize its block

    auto& block = op->getRegion(0).front();

    return bufferizeBlockSignature(&block, rewriter, options, state);
  }
};
|
||||
|
||||
/*
 * This can be used for operations that have a single argument, which is a
 * variadic of tensors, and a single output with the same shape.
 * Example: VAdd, VSub, VExp
 */
template <typename InterfaceName, typename OpTy, typename ToTy>
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {

  // Input tensors to the OP are always read
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }

  // Input tensors to the OP are _never_ written
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // In general, no tensor is aliased with any other tensor in the OP
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }

  // Cast tensor values into memref values, allocate an output buffer, and
  // replace the op with its PIM counterpart `ToTy`.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {

    // Turn Tensor Operands into Memref Operands
    SmallVector<Value> memrefOperands;
    memrefOperands.reserve(op->getNumOperands());
    for (auto operand : op->getOperands()) {
      auto memref = getBuffer(rewriter, operand, options, state);
      if (failed(memref))
        return failure();
      // PIM ops require contiguous buffers; copy if necessary.
      memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
    }

    // TODO: Support addition with more than 2 operands
    if (memrefOperands.size() > 2) {
      op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
                    "with 1 or 2 operands, for now.");
      return failure();
    }

    // Alloc an output memref
    Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);

    // The output buffer is passed as the last operand of the PIM op.
    memrefOperands.push_back(outputTensor);

    Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();

    replaceOpWithBufferizedValues(rewriter, op, newValue);

    return success();
  }
};
|
||||
|
||||
// Bufferization template for weighted multiplication ops (VMM/MVM): the op
// has one tensor input, a weight-index attribute, and one output of the
// result type; it lowers to the PIM op `ToTy`.
template <typename InterfaceName, typename OpTy, typename ToTy>
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {

  // Input tensors to the OP are always read
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }

  // Input tensors to the OP are _never_ written
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // In general, no tensor is aliased with any other tensor in the OP
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }

  // Cast tensor value into memref value
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
    if (failed(memrefOperandOpt))
      return failure();
    auto memrefOperand = *memrefOperandOpt;

    // Alloc an output memref
    Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);

    // The weight-index attribute is forwarded unchanged to the PIM op.
    Value newValue = ToTy::create(rewriter,
        op->getLoc(),
        outputTensor.getType(),
        cast<OpTy>(op).getWeightIndexAttr(),
        memrefOperand,
        outputTensor)
        .getOutput();

    replaceOpWithBufferizedValues(rewriter, op, newValue);

    return success();
  }
};
|
||||
|
||||
// Bufferization model for spat.channel_receive: allocates a destination
// buffer and lowers to pim.receive from the channel's sending core.
struct ChannelReceiveOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {

  // Input value is the channel (not read/written, it's more of an attribute)
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Turn the channel receive to pim.recv
   */
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {

    // Destination buffer sized from the receive's result type.
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);

    auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;

    // The sender's core id comes from the other channel endpoint (possibly
    // precomputed if that endpoint was already bufferized).
    auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
    if (failed(srcCoreId))
      return failure();

    Value newValue = pim::PimReceiveOp::create(rewriter,
        op->getLoc(),
        outputTensor.getType(),
        outputTensor,
        rewriter.getI32IntegerAttr(numElements * elementSize),
        rewriter.getI32IntegerAttr(srcCoreId.value()))
        .getOutput();

    replaceOpWithBufferizedValues(rewriter, op, newValue);

    return success();
  }
};
|
||||
|
||||
// Bufferization model for spat.channel_send: lowers to pim.send targeting the
// core of the channel's receiving endpoint.
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {

  // Operand 0 is the channel (not read/written; it only identifies the
  // endpoints). Operand 1 is the tensor being sent, which IS read.
  // Fix: this compared against operand number 2, which never matches a
  // two-operand spat.channel_send (see the DRR pattern and bufferize() below,
  // which reads getOperand(1)), so the payload was never reported as read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return opOperand.getOperandNumber() == 1;
  }

  // Neither operand is written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // No aliasing between operands and results.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Turn the channel send to pim.send
   */
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto srcTensor = op->getOperand(1);

    auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcTensorOpt))
      return failure();
    auto srcMemRef = *srcTensorOpt;

    auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;

    // The receiver's core id comes from the other channel endpoint (possibly
    // precomputed if that endpoint was already bufferized).
    auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
    if (failed(dstCoreId))
      return failure();

    replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
        op,
        srcMemRef,
        rewriter.getI32IntegerAttr(numElements * elementSize),
        rewriter.getI32IntegerAttr(dstCoreId.value()));

    return success();
  }
};
|
||||
|
||||
// Bufferization model for spat.channel_broadcast_receive: lowers to a regular
// pim.receive from the broadcasting core. The source core id is either read
// from a precomputed attribute (set when the broadcast send was bufferized
// first) or resolved by scanning the channel's users for the broadcast send.
struct ChannelBroadcastReceiveOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {

  // Input value is the channel (not read/written, it's more of an attribute)
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Turn the broadcast receive into a regular pim.receive from the broadcaster.
   */
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {

    // Destination buffer sized from the receive's result type.
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);

    auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;

    // Fast path: the broadcast send was bufferized first and left the source
    // core id on this op as an attribute.
    auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
    if (precomputedOtherCoreId) {
      Value newValue = pim::PimReceiveOp::create(rewriter,
          op->getLoc(),
          outputTensor.getType(),
          outputTensor,
          rewriter.getI32IntegerAttr(numElements * elementSize),
          cast<IntegerAttr>(precomputedOtherCoreId))
          .getOutput();
      replaceOpWithBufferizedValues(rewriter, op, newValue);
      return success();
    }

    auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
    if (!channelNewOp) {
      op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
      return failure();
    }

    // Slow path: find the broadcast send among the channel's users and cache
    // its core id on this op for any later lookups.
    auto srcCoreId = [&]() -> FailureOr<uint32_t> {
      for (Operation* user : channelNewOp->getUsers()) {
        auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
        if (!sendOp)
          continue;
        auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
        op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
        return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
      }
      op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
      return failure();
    }();
    if (failed(srcCoreId))
      return failure();

    Value newValue = pim::PimReceiveOp::create(rewriter,
        op->getLoc(),
        outputTensor.getType(),
        outputTensor,
        rewriter.getI32IntegerAttr(numElements * elementSize),
        rewriter.getI32IntegerAttr(srcCoreId.value()))
        .getOutput();

    replaceOpWithBufferizedValues(rewriter, op, newValue);

    return success();
  }
};
|
||||
|
||||
struct ChannelBroadcastSendOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
|
||||
|
||||
// First input is channel (not read/writter) second input is Tensor to send,
|
||||
// which is read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return opOperand.getOperandNumber() == 2;
|
||||
}
|
||||
|
||||
// See above (both non-written)
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the broadcast send into one pim.send per broadcast receiver.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto srcTensor = op->getOperand(1);
|
||||
|
||||
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
|
||||
if (failed(srcTensorOpt))
|
||||
return failure();
|
||||
auto srcMemRef = *srcTensorOpt;
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
bool foundReceiver = false;
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
foundReceiver = true;
|
||||
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
|
||||
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(sizeInBytes),
|
||||
rewriter.getI32IntegerAttr(dstCoreId));
|
||||
}
|
||||
|
||||
if (!foundReceiver) {
|
||||
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Concrete interface instantiations mapping each Spatial op to its PIM
// counterpart via the templates above.
struct VAddOpInterfaceFromTemplate
    : VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};

struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};

struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};

struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
// Attach the bufferization external models for all Spatial-dialect ops.
// Called from PimAccelerator::registerDialects during dialect registration.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
    SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
    SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
    SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
    SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
    SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
    SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
    SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
    SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
    SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
  });
}
|
||||
|
||||
// ONNX elementwise activations reuse the variadic elementwise template to
// lower directly to their PIM vector counterparts.
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};

struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};

struct ONNXSigmoidInterface
    : VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
||||
|
||||
// Attach bufferization external models for the ONNX activation ops that are
// lowered to PIM vector ops (Relu/Tanh/Sigmoid).
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
    ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
    ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
    ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
  });
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,15 +0,0 @@
|
||||
#pragma once

#include "mlir/IR/DialectRegistry.h"

#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

namespace onnx_mlir {
namespace spatial {

/// Attach bufferization external models for the Spatial-dialect ops
/// (weighted compute, elementwise, channel send/receive and broadcasts).
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);

/// Attach bufferization external models for the ONNX activation ops
/// (Relu/Tanh/Sigmoid) that lower to PIM vector ops.
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);

} // namespace spatial
} // namespace onnx_mlir
|
||||
@@ -18,7 +18,6 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
@@ -67,8 +66,6 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
||||
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
|
||||
pim::registerOpBufferizationInterfaces(registry);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user