remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim

This commit is contained in:
NiccoloN
2026-04-14 11:55:19 +02:00
parent 368e340a40
commit 30ee9640d4
8 changed files with 158 additions and 551 deletions

View File

@@ -10,8 +10,10 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_os_ostream.h"
@@ -68,6 +70,8 @@ private:
bool useBroadcastOp,
IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void annotateChannelCoreIds(func::FuncOp funcOp);
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
@@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() {
runOnComputeOp(computeOp, rewriter);
}
annotateChannelCoreIds(funcOp);
lowerBroadcastChannelOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
signalPassFailure();
return;
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnOpOperands(returnOp, rewriter);
@@ -623,6 +637,91 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
operationsToRemove.push_back(op);
}
/// Annotates every `spat.channel_new` found in `funcOp` with the core ids of its endpoints.
///
/// Each visited channel op is first passed to `markOpToRemove`. Its users are then classified:
///  - if a broadcast send is among them, only the source core id attribute is set;
///  - otherwise a send and a receive must both be present, and the source and target
///    core id attributes are set from them.
/// Core ids are read from the `pim::PimCoreOp` that is the direct parent of the relevant user.
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
  funcOp.walk([&](spatial::SpatChannelNewOp channel) {
    markOpToRemove(channel);

    spatial::SpatChannelSendOp pointToPointSend;
    spatial::SpatChannelReceiveOp pointToPointReceive;
    spatial::SpatChannelBroadcastSendOp broadcastSend;

    // Sort the users by kind. If several users of the same kind exist, the last one visited is kept.
    for (Operation* channelUser : channel->getUsers()) {
      if (auto asSend = dyn_cast<spatial::SpatChannelSendOp>(channelUser)) {
        pointToPointSend = asSend;
      }
      else if (auto asReceive = dyn_cast<spatial::SpatChannelReceiveOp>(channelUser)) {
        pointToPointReceive = asReceive;
      }
      else if (auto asBroadcastSend = dyn_cast<spatial::SpatChannelBroadcastSendOp>(channelUser)) {
        broadcastSend = asBroadcastSend;
      }
      else if (!llvm::isa<spatial::SpatChannelBroadcastReceiveOp>(channelUser)) {
        // Broadcast receives are legal users but contribute nothing here; anything else is a bug.
        llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
      }
    }

    // Broadcast channel: only the sending core is recorded on the channel.
    if (broadcastSend) {
      auto broadcasterCoreId = cast<pim::PimCoreOp>(broadcastSend->getParentOp()).getCoreIdAttr();
      channel->setAttr(kChannelSourceCoreIdAttrName, broadcasterCoreId);
      return;
    }

    // Point-to-point channel: both endpoints are required.
    const bool hasBothEndpoints = pointToPointSend && pointToPointReceive;
    if (!hasBothEndpoints)
      llvm_unreachable("spat.channel_new must connect exactly one send and one receive");

    auto senderCoreId = cast<pim::PimCoreOp>(pointToPointSend->getParentOp()).getCoreIdAttr();
    auto receiverCoreId = cast<pim::PimCoreOp>(pointToPointReceive->getParentOp()).getCoreIdAttr();
    channel->setAttr(kChannelSourceCoreIdAttrName, senderCoreId);
    channel->setAttr(kChannelTargetCoreIdAttrName, receiverCoreId);
  });
}
/// Replaces the broadcast channel send/receive ops in `funcOp` with PIM send/receive ops.
///
/// Sends are handled first: for every broadcast send, one `PimSendOp` is created in front of it
/// for each broadcast receive that uses the same channel, targeting the core id of the
/// `pim::PimCoreOp` that is the receive's direct parent; the broadcast send is then erased.
/// Afterwards every broadcast receive is replaced by the output of a `PimReceiveOp` that
/// writes into a freshly created empty tensor of the receive's result type.
///
/// Ops are collected before rewriting so that the IR is not mutated during the walk.
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // --- Broadcast sends: fan out into one pim send per receiver. ---
  SmallVector<spatial::SpatChannelBroadcastSendOp> pendingSends;
  funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { pendingSends.push_back(op); });

  for (auto broadcastSend : pendingSends) {
    auto channel = cast<spatial::SpatChannelNewOp>(broadcastSend.getChannel().getDefiningOp());
    auto payloadSize = getTensorSizeInBytesAttr(rewriter, broadcastSend.getInput());
    rewriter.setInsertionPoint(broadcastSend);

    bool hasReceiver = false;
    for (Operation* channelUser : channel->getUsers()) {
      if (auto broadcastReceive = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(channelUser)) {
        hasReceiver = true;
        auto destinationCoreId = cast<pim::PimCoreOp>(broadcastReceive->getParentOp()).getCoreIdAttr();
        PimSendOp::create(
            rewriter, broadcastSend.getLoc(), broadcastSend.getInput(), payloadSize, destinationCoreId);
      }
    }

    if (!hasReceiver)
      llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");

    rewriter.eraseOp(broadcastSend);
  }

  // --- Broadcast receives: each becomes a pim receive into a new empty tensor. ---
  SmallVector<spatial::SpatChannelBroadcastReceiveOp> pendingReceives;
  funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { pendingReceives.push_back(op); });

  for (auto broadcastReceive : pendingReceives) {
    rewriter.setInsertionPoint(broadcastReceive);

    auto resultType = cast<ShapedType>(broadcastReceive.getResult().getType());
    Value destination = createEmptyTensorFromShaped(rewriter, broadcastReceive.getLoc(), resultType);
    auto payloadSize = getTensorSizeInBytesAttr(rewriter, broadcastReceive.getResult());
    auto originCoreId = getSpatialChannelSourceCoreIdAttr(rewriter, broadcastReceive.getChannel());

    auto pimReceive = PimReceiveOp::create(
        rewriter, broadcastReceive.getLoc(), destination.getType(), destination, payloadSize, originCoreId);
    Value received = pimReceive.getOutput();
    rewriter.replaceOp(broadcastReceive, received);
  }
}
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
for (auto it : llvm::enumerate(originalOperands)) {