remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim

This commit is contained in:
NiccoloN
2026-04-14 11:55:19 +02:00
parent 368e340a40
commit 30ee9640d4
8 changed files with 158 additions and 551 deletions

View File

@@ -1,6 +1,7 @@
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include <cassert>
#include <cstddef>
@@ -12,6 +13,16 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
assert(attr && "required precomputed channel attr is missing");
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
}
} // namespace
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/*
EXAMPLE RUN:
@@ -54,6 +65,26 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue;
}
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
}
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
}
Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers();

View File

@@ -2,11 +2,16 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
@@ -21,6 +26,14 @@ namespace onnx_mlir {
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end());

View File

@@ -69,4 +69,19 @@ def spatToPimVSoftmax : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatChannelSendToPimSend : Pat<
(SpatChannelSendOp $channel, $input),
(PimSendOp $input,
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel))
>;
def spatChannelReceiveToPimReceive : Pat<
(SpatChannelReceiveOp:$srcOpRes $channel),
(PimReceiveOp
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes),
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $srcOpRes),
(NativeCodeCall<"onnx_mlir::getSpatialChannelSourceCoreIdAttr($_builder, $0)"> $channel))
>;
#endif // SPATIAL_TO_PIM

View File

@@ -10,8 +10,10 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_os_ostream.h"
@@ -68,6 +70,8 @@ private:
bool useBroadcastOp,
IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void annotateChannelCoreIds(func::FuncOp funcOp);
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
@@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() {
runOnComputeOp(computeOp, rewriter);
}
annotateChannelCoreIds(funcOp);
lowerBroadcastChannelOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
signalPassFailure();
return;
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnOpOperands(returnOp, rewriter);
@@ -623,6 +637,91 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
operationsToRemove.push_back(op);
}
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
markOpToRemove(channelNewOp);
spatial::SpatChannelSendOp sendOp;
spatial::SpatChannelReceiveOp receiveOp;
spatial::SpatChannelBroadcastSendOp broadcastSendOp;
for (Operation* user : channelNewOp->getUsers()) {
if (auto op = dyn_cast<spatial::SpatChannelSendOp>(user)) {
sendOp = op;
continue;
}
if (auto op = dyn_cast<spatial::SpatChannelReceiveOp>(user)) {
receiveOp = op;
continue;
}
if (auto op = dyn_cast<spatial::SpatChannelBroadcastSendOp>(user)) {
broadcastSendOp = op;
continue;
}
if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user)) {
continue;
}
llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
}
if (broadcastSendOp) {
auto sourceCoreIdAttr = cast<pim::PimCoreOp>(broadcastSendOp->getParentOp()).getCoreIdAttr();
channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
return;
}
if (!sendOp || !receiveOp)
llvm_unreachable("spat.channel_new must connect exactly one send and one receive");
auto sourceCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
});
}
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatChannelBroadcastSendOp> broadcastSendOps;
funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { broadcastSendOps.push_back(op); });
for (auto sendOp : broadcastSendOps) {
auto channelNewOp = cast<spatial::SpatChannelNewOp>(sendOp.getChannel().getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
rewriter.setInsertionPoint(sendOp);
bool foundReceiver = false;
for (Operation* user : channelNewOp->getUsers()) {
auto receiveOp = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user);
if (!receiveOp)
continue;
foundReceiver = true;
auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
}
if (!foundReceiver)
llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");
rewriter.eraseOp(sendOp);
}
SmallVector<spatial::SpatChannelBroadcastReceiveOp> broadcastReceiveOps;
funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { broadcastReceiveOps.push_back(op); });
for (auto receiveOp : broadcastReceiveOps) {
rewriter.setInsertionPoint(receiveOp);
auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel());
Value receivedValue =
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
rewriter.replaceOp(receiveOp, receivedValue);
}
}
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
for (auto it : llvm::enumerate(originalOperands)) {