From 30ee9640d4905640d167439bc8c015b08c2f8c77 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Tue, 14 Apr 2026 11:55:19 +0200 Subject: [PATCH] remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim --- src/PIM/Conversion/SpatialToPim/Common.cpp | 31 + src/PIM/Conversion/SpatialToPim/Common.hpp | 13 + .../Conversion/SpatialToPim/SpatialToPim.td | 15 + .../SpatialToPim/SpatialToPimPass.cpp | 99 ++++ src/PIM/Dialect/Spatial/CMakeLists.txt | 1 - .../SpatialBufferizableOpInterface.cpp | 532 ------------------ .../SpatialBufferizableOpInterface.hpp | 15 - src/PIM/PimAccelerator.cpp | 3 - 8 files changed, 158 insertions(+), 551 deletions(-) delete mode 100644 src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp delete mode 100644 src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 96aa454..dd146ca 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -1,6 +1,7 @@ #include "mlir/IR/ValueRange.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" #include #include @@ -12,6 +13,16 @@ using namespace mlir; namespace onnx_mlir { +namespace { + +IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) { + auto attr = op->getAttrOfType(attrName); + assert(attr && "required precomputed channel attr is missing"); + return IntegerAttr::get(builder.getI32Type(), attr.getInt()); +} + +} // namespace + size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) { /* EXAMPLE RUN: @@ -54,6 +65,26 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh return returnValue; } +size_t getShapedTypeSizeInBytes(ShapedType shapedType) { + return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; +} + +IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { + return builder.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(cast(value.getType())))); +} + +IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) { + auto channelNewOp = channel.getDefiningOp(); + assert(channelNewOp && "spatial channel value must come from spat.channel_new"); + return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName); +} + +IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) { + auto channelNewOp = channel.getDefiningOp(); + assert(channelNewOp && "spatial channel value must come from spat.channel_new"); + return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName); +} + Operation* getEarliestUserWithinBlock(mlir::Value value) { auto users = value.getUsers(); diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index be8fe5a..99819d9 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -2,11 +2,16 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "llvm/ADT/StringRef.h" + #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { +inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id"; +inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id"; + /** * \brief Get the offset of the ExtractSliceOp based on its static offsets and * its static tensor input. @@ -21,6 +26,14 @@ namespace onnx_mlir { */ size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape); +size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType); + +mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); + +mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel); + +mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel); + template size_t rangeLength(const mlir::iterator_range range) { return std::distance(range.begin(), range.end()); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index a0fbce5..7bcb91d 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -69,4 +69,19 @@ def spatToPimVSoftmax : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; +def spatChannelSendToPimSend : Pat< + (SpatChannelSendOp $channel, $input), + (PimSendOp $input, + (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input), + (NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)) +>; + +def spatChannelReceiveToPimReceive : Pat< + (SpatChannelReceiveOp:$srcOpRes $channel), + (PimReceiveOp + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes), + (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes), + (NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel)) +>; + #endif // SPATIAL_TO_PIM diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 0ea92c4..f9beb4c 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -10,8 +10,10 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_os_ostream.h" @@ -68,6 +70,8 @@ private: bool useBroadcastOp, IRRewriter& rewriter); void markOpToRemove(Operation* op); + void annotateChannelCoreIds(func::FuncOp funcOp); + void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter); void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); @@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() { runOnComputeOp(computeOp, rewriter); } + annotateChannelCoreIds(funcOp); + lowerBroadcastChannelOps(funcOp, rewriter); + + RewritePatternSet channelPatterns(ctx); + populateWithGenerated(channelPatterns); + if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) { + signalPassFailure(); + return; + } + enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); replaceReturnOpOperands(returnOp, rewriter); @@ -623,6 +637,91 @@ void SpatialToPimPass::markOpToRemove(Operation* op) { operationsToRemove.push_back(op); } +void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) { + funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) { + markOpToRemove(channelNewOp); + + spatial::SpatChannelSendOp sendOp; + spatial::SpatChannelReceiveOp receiveOp; + spatial::SpatChannelBroadcastSendOp broadcastSendOp; + + for (Operation* user : channelNewOp->getUsers()) { + if (auto op = dyn_cast(user)) { + sendOp = op; + continue; + } + if (auto op = dyn_cast(user)) { + receiveOp = op; + continue; + } + if (auto op = dyn_cast(user)) { + broadcastSendOp = op; + continue; + } + if (auto op = dyn_cast(user)) { + continue; + } + llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering"); + } + + if (broadcastSendOp) { + auto sourceCoreIdAttr = cast(broadcastSendOp->getParentOp()).getCoreIdAttr(); + channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr); + return; + } + + if (!sendOp || !receiveOp) + llvm_unreachable("spat.channel_new must connect exactly one send and one receive"); + + auto sourceCoreIdAttr = cast(sendOp->getParentOp()).getCoreIdAttr(); + auto targetCoreIdAttr = cast(receiveOp->getParentOp()).getCoreIdAttr(); + channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr); + channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr); + }); +} + +void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) { + SmallVector broadcastSendOps; + funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { broadcastSendOps.push_back(op); }); + + for (auto sendOp : broadcastSendOps) { + auto channelNewOp = cast(sendOp.getChannel().getDefiningOp()); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput()); + + rewriter.setInsertionPoint(sendOp); + bool foundReceiver = false; + for (Operation* user : channelNewOp->getUsers()) { + auto receiveOp = dyn_cast(user); + if (!receiveOp) + continue; + + foundReceiver = true; + auto targetCoreIdAttr = cast(receiveOp->getParentOp()).getCoreIdAttr(); + PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr); + } + + if (!foundReceiver) + llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive"); + + rewriter.eraseOp(sendOp); + } + + SmallVector broadcastReceiveOps; + funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { broadcastReceiveOps.push_back(op); }); + + for (auto receiveOp : broadcastReceiveOps) { + rewriter.setInsertionPoint(receiveOp); + auto outputType = cast(receiveOp.getResult().getType()); + Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult()); + auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel()); + Value receivedValue = + PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) + .getOutput(); + rewriter.replaceOp(receiveOp, receivedValue); + } +} + void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); for (auto it : llvm::enumerate(originalOperands)) { diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 1a412a6..3efa16b 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -3,7 +3,6 @@ add_onnx_mlir_dialect_doc(spat Spatial.td) add_pim_library(SpatialOps SpatialOps.cpp - Transforms/SpatialBufferizableOpInterface.cpp Transforms/MergeComputeNode/MergeComputeNodePass.cpp DCPGraph/Graph.cpp DCPGraph/Task.cpp diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp deleted file mode 100644 index 3fa2413..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ /dev/null @@ -1,532 +0,0 @@ -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LLVM.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/raw_ostream.h" - -#include - -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; -using namespace bufferization; - -namespace onnx_mlir { -namespace spatial { - -memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) { - auto resultShape = cast(resultType); - auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType()); - - // Alloc an output memref - return memref::AllocOp::create(rewriter, loc, memrefResultType); -} - -Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { - if (succeeded(resolveContiguousAddress(memrefValue))) - return memrefValue; - - auto shapedType = cast(memrefValue.getType()); - auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter); - auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; - - return pim::PimMemCopyOp::create(rewriter, - loc, - contiguousBuffer.getType(), - contiguousBuffer, - memrefValue, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(sizeInBytes)) - .getOutput(); -} - -const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); - -static FailureOr getOtherEndOfChannel(Operation* op, bool opIsReceive) { - auto channelNewOp = op->getOperand(0).getDefiningOp(); - if (!channelNewOp) { - op->emitError("User of Channel must have the first operand created by ChannelNewOp."); - return failure(); - } - - auto channelUsers = channelNewOp->getUsers(); - auto usersIterator = channelUsers.begin(); - auto firstUser = *usersIterator; - ++usersIterator; - if (usersIterator == channelUsers.end()) { - op->emitError("Operand generated by ChannelNewOp must have two users, only one found."); - channelNewOp->dump(); - op->dump(); - channelNewOp->getParentOp()->dump(); - return failure(); - } - - auto secondUser = *usersIterator; - ++usersIterator; - if (usersIterator != channelUsers.end()) { - op->emitError("Operand generated by ChannelNewOp must have two users, more than two found."); - return failure(); - } - - Operation* otherUser = nullptr; - if (firstUser == op) - otherUser = secondUser; - else if (secondUser == op) - otherUser = firstUser; - else { - op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op."); - return failure(); - } - - if (opIsReceive && !isa(otherUser)) { - op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp."); - return failure(); - } - - if (!opIsReceive && !isa(otherUser)) { - op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp."); - return failure(); - } - - return otherUser; -} - -llvm::FailureOr getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) { - - // This function requires the existence of ChannelNewOp and the other - // Receive/Send operation. However, during bufferization, the first of the - // Receive/Send operation that is processed gets removed. As such, we need to - // "precompute" the coreId needed for the other op, and save it as attribute - auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME); - if (precomputedOtherCoreId) - return cast(precomputedOtherCoreId).getInt(); - - auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive); - if (failed(notOpUserOpt)) - return failure(); - Operation* notOpUser = *notOpUserOpt; - - // Save the coreId for this op into the other op as attribute - auto opCoreIdAttr = cast(op->getParentOp()).getCoreIdAttr(); - notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr); - - return cast(notOpUser->getParentOp()).getCoreId(); -} - -struct WComputeOpInterface : BufferizableOpInterface::ExternalModel { - - // Input tensor to the compute OP are always read into its local memory - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } - - // Input tensor to the compute OP are _never_ written into its local memory - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // In general, no tensor is aliased with any other tensor in the compute OP - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - // TODO: Is it an empty list or a list of "UNKNOWN" values? - return {}; - } - - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - // Bufferize its block - - auto& block = op->getRegion(0).front(); - - return bufferizeBlockSignature(&block, rewriter, options, state); - } -}; - -/* - * This can be used for operation that have a single argument, which is a - * variadic of tensors, and a single output with the same same shape - * Example: VAdd, VSub, VExp - */ -template -struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel { - - // Input tensors to the OP are always read - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } - - // Input tensors to the OP are _never_ written - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // In general, no tensor is aliased with any other tensor in the OP - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return {}; - } - - // Cast tensor values into memref values - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - - // Turn Tensor Operands into Memref Operands - SmallVector memrefOperands; - memrefOperands.reserve(op->getNumOperands()); - for (auto operand : op->getOperands()) { - auto memref = getBuffer(rewriter, operand, options, state); - if (failed(memref)) - return failure(); - memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter)); - } - - // TODO: Support addiction with more than 2 operands - if (memrefOperands.size() > 2) { - op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs " - "with 1 or 2 operands, for now."); - return failure(); - } - - // Alloc an output memref - Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); - - memrefOperands.push_back(outputTensor); - - Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput(); - - replaceOpWithBufferizedValues(rewriter, op, newValue); - - return success(); - } -}; - -template -struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel { - - // Input tensors to the OP are always read - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } - - // Input tensors to the OP are _never_ written - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // In general, no tensor is aliased with any other tensor in the OP - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return {}; - } - - // Cast tensor value into memref value - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state); - if (failed(memrefOperandOpt)) - return failure(); - auto memrefOperand = *memrefOperandOpt; - - // Alloc an output memref - Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); - - Value newValue = ToTy::create(rewriter, - op->getLoc(), - outputTensor.getType(), - cast(op).getWeightIndexAttr(), - memrefOperand, - outputTensor) - .getOutput(); - - replaceOpWithBufferizedValues(rewriter, op, newValue); - - return success(); - } -}; - -struct ChannelReceiveOpInterface -: BufferizableOpInterface::ExternalModel { - - // Input value is the channel (not read/written, its more of an attribute) - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // See above - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // See above - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - // TODO: Is it an empty list or a list of "UNKNOWN" values? - return {}; - } - - /* - * Turn the channel receive to pim.recv - */ - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - - auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); - - auto numElements = cast(outputTensor.getType()).getNumElements(); - auto elementSize = cast(outputTensor.getType()).getElementTypeBitWidth() / 8; - - auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter); - if (failed(srcCoreId)) - return failure(); - - Value newValue = pim::PimReceiveOp::create(rewriter, - op->getLoc(), - outputTensor.getType(), - outputTensor, - rewriter.getI32IntegerAttr(numElements * elementSize), - rewriter.getI32IntegerAttr(srcCoreId.value())) - .getOutput(); - - replaceOpWithBufferizedValues(rewriter, op, newValue); - - return success(); - } -}; - -struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel { - - // First input is channel (not read/writter) second input is Tensor to send, - // which is read - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return opOperand.getOperandNumber() == 2; - } - - // See above (both non-written) - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // See above - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - // TODO: Is it an empty list or a list of "UNKNOWN" values? - return {}; - } - - /* - * Turn the channel send to pim.send - */ - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto srcTensor = op->getOperand(1); - - auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); - if (failed(srcTensorOpt)) - return failure(); - auto srcMemRef = *srcTensorOpt; - - auto numElements = cast(srcTensor.getType()).getNumElements(); - auto elementSize = cast(srcTensor.getType()).getElementTypeBitWidth() / 8; - - auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter); - if (failed(dstCoreId)) - return failure(); - - replaceOpWithNewBufferizedOp(rewriter, - op, - srcMemRef, - rewriter.getI32IntegerAttr(numElements * elementSize), - rewriter.getI32IntegerAttr(dstCoreId.value())); - - return success(); - } -}; - -struct ChannelBroadcastReceiveOpInterface -: BufferizableOpInterface::ExternalModel { - - // Input value is the channel (not read/written, its more of an attribute) - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // See above - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // See above - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - // TODO: Is it an empty list or a list of "UNKNOWN" values? - return {}; - } - - /* - * Turn the broadcast receive into a regular pim.receive from the broadcaster. - */ - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - - auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); - - auto numElements = cast(outputTensor.getType()).getNumElements(); - auto elementSize = cast(outputTensor.getType()).getElementTypeBitWidth() / 8; - - auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME); - if (precomputedOtherCoreId) { - Value newValue = pim::PimReceiveOp::create(rewriter, - op->getLoc(), - outputTensor.getType(), - outputTensor, - rewriter.getI32IntegerAttr(numElements * elementSize), - cast(precomputedOtherCoreId)) - .getOutput(); - replaceOpWithBufferizedValues(rewriter, op, newValue); - return success(); - } - - auto channelNewOp = op->getOperand(0).getDefiningOp(); - if (!channelNewOp) { - op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand"); - return failure(); - } - - auto srcCoreId = [&]() -> FailureOr { - for (Operation* user : channelNewOp->getUsers()) { - auto sendOp = dyn_cast(user); - if (!sendOp) - continue; - auto sendCoreIdAttr = cast(sendOp->getParentOp()).getCoreIdAttr(); - op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr); - return cast(sendOp->getParentOp()).getCoreId(); - } - op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp"); - return failure(); - }(); - if (failed(srcCoreId)) - return failure(); - - Value newValue = pim::PimReceiveOp::create(rewriter, - op->getLoc(), - outputTensor.getType(), - outputTensor, - rewriter.getI32IntegerAttr(numElements * elementSize), - rewriter.getI32IntegerAttr(srcCoreId.value())) - .getOutput(); - - replaceOpWithBufferizedValues(rewriter, op, newValue); - - return success(); - } -}; - -struct ChannelBroadcastSendOpInterface -: BufferizableOpInterface::ExternalModel { - - // First input is channel (not read/writter) second input is Tensor to send, - // which is read - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return opOperand.getOperandNumber() == 2; - } - - // See above (both non-written) - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } - - // See above - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - // TODO: Is it an empty list or a list of "UNKNOWN" values? - return {}; - } - - /* - * Turn the broadcast send into one pim.send per broadcast receiver. - */ - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto srcTensor = op->getOperand(1); - - auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); - if (failed(srcTensorOpt)) - return failure(); - auto srcMemRef = *srcTensorOpt; - - auto channelNewOp = op->getOperand(0).getDefiningOp(); - if (!channelNewOp) { - op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand"); - return failure(); - } - - auto srcType = cast(srcTensor.getType()); - auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8; - auto srcCoreIdAttr = cast(op->getParentOp()).getCoreIdAttr(); - - rewriter.setInsertionPoint(op); - bool foundReceiver = false; - for (Operation* user : channelNewOp->getUsers()) { - auto receiveOp = dyn_cast(user); - if (!receiveOp) - continue; - - foundReceiver = true; - auto dstCoreId = cast(receiveOp->getParentOp()).getCoreId(); - receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr); - pim::PimSendOp::create(rewriter, - op->getLoc(), - srcMemRef, - rewriter.getI32IntegerAttr(sizeInBytes), - rewriter.getI32IntegerAttr(dstCoreId)); - } - - if (!foundReceiver) { - op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp"); - return failure(); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -struct VAddOpInterfaceFromTemplate -: VariadicArgumentElementWiseOpInterface {}; - -struct WVMMOpInterface : WeightedMultiplicationsOpInterface {}; - -struct WMVMOpInterface : WeightedMultiplicationsOpInterface {}; - -struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface {}; - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { - registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) { - SpatWeightedCompute::attachInterface(*ctx); - SpatVAddOp::attachInterface(*ctx); - SpatWeightedVMMOp::attachInterface(*ctx); - SpatWeightedMVMOp::attachInterface(*ctx); - SpatVMaxOp::attachInterface(*ctx); - SpatChannelReceiveOp::attachInterface(*ctx); - SpatChannelSendOp::attachInterface(*ctx); - SpatChannelBroadcastReceiveOp::attachInterface(*ctx); - SpatChannelBroadcastSendOp::attachInterface(*ctx); - }); -} - -struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface {}; - -struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface {}; - -struct ONNXSigmoidInterface -: VariadicArgumentElementWiseOpInterface {}; - -void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { - registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) { - ONNXReluOp::attachInterface(*ctx); - ONNXTanhOp::attachInterface(*ctx); - ONNXSigmoidOp::attachInterface(*ctx); - }); -} - -} // namespace spatial -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp deleted file mode 100644 index 9013e74..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "mlir/IR/DialectRegistry.h" - -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -namespace onnx_mlir { -namespace spatial { - -void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry); - -void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry); - -} // namespace spatial -} // namespace onnx_mlir diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 29ceb15..6b9ce54 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -18,7 +18,6 @@ #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/PimAccelerator.hpp" #include "src/Compiler/CompilerUtils.hpp" @@ -67,8 +66,6 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const { mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); mlir::scf::registerBufferizableOpInterfaceExternalModels(registry); - spatial::registerBufferizableOpInterfaceExternalModels(registry); - spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); pim::registerOpBufferizationInterfaces(registry); }