remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim
This commit is contained in:
@@ -10,8 +10,10 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
@@ -68,6 +70,8 @@ private:
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void markOpToRemove(Operation* op);
|
||||
void annotateChannelCoreIds(func::FuncOp funcOp);
|
||||
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||
|
||||
@@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
runOnComputeOp(computeOp, rewriter);
|
||||
}
|
||||
|
||||
annotateChannelCoreIds(funcOp);
|
||||
lowerBroadcastChannelOps(funcOp, rewriter);
|
||||
|
||||
RewritePatternSet channelPatterns(ctx);
|
||||
populateWithGenerated(channelPatterns);
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnOpOperands(returnOp, rewriter);
|
||||
|
||||
@@ -623,6 +637,91 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
|
||||
markOpToRemove(channelNewOp);
|
||||
|
||||
spatial::SpatChannelSendOp sendOp;
|
||||
spatial::SpatChannelReceiveOp receiveOp;
|
||||
spatial::SpatChannelBroadcastSendOp broadcastSendOp;
|
||||
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
if (auto op = dyn_cast<spatial::SpatChannelSendOp>(user)) {
|
||||
sendOp = op;
|
||||
continue;
|
||||
}
|
||||
if (auto op = dyn_cast<spatial::SpatChannelReceiveOp>(user)) {
|
||||
receiveOp = op;
|
||||
continue;
|
||||
}
|
||||
if (auto op = dyn_cast<spatial::SpatChannelBroadcastSendOp>(user)) {
|
||||
broadcastSendOp = op;
|
||||
continue;
|
||||
}
|
||||
if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user)) {
|
||||
continue;
|
||||
}
|
||||
llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
|
||||
}
|
||||
|
||||
if (broadcastSendOp) {
|
||||
auto sourceCoreIdAttr = cast<pim::PimCoreOp>(broadcastSendOp->getParentOp()).getCoreIdAttr();
|
||||
channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!sendOp || !receiveOp)
|
||||
llvm_unreachable("spat.channel_new must connect exactly one send and one receive");
|
||||
|
||||
auto sourceCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
|
||||
auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
|
||||
channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
|
||||
channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
|
||||
});
|
||||
}
|
||||
|
||||
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatChannelBroadcastSendOp> broadcastSendOps;
|
||||
funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { broadcastSendOps.push_back(op); });
|
||||
|
||||
for (auto sendOp : broadcastSendOps) {
|
||||
auto channelNewOp = cast<spatial::SpatChannelNewOp>(sendOp.getChannel().getDefiningOp());
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
|
||||
|
||||
rewriter.setInsertionPoint(sendOp);
|
||||
bool foundReceiver = false;
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user);
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
foundReceiver = true;
|
||||
auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
|
||||
PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
|
||||
}
|
||||
|
||||
if (!foundReceiver)
|
||||
llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");
|
||||
|
||||
rewriter.eraseOp(sendOp);
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelBroadcastReceiveOp> broadcastReceiveOps;
|
||||
funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { broadcastReceiveOps.push_back(op); });
|
||||
|
||||
for (auto receiveOp : broadcastReceiveOps) {
|
||||
rewriter.setInsertionPoint(receiveOp);
|
||||
auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
|
||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
|
||||
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel());
|
||||
Value receivedValue =
|
||||
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
.getOutput();
|
||||
rewriter.replaceOp(receiveOp, receivedValue);
|
||||
}
|
||||
}
|
||||
|
||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
for (auto it : llvm::enumerate(originalOperands)) {
|
||||
|
||||
Reference in New Issue
Block a user