simple convolutions now work :)

This commit is contained in:
NiccoloN
2026-03-20 21:17:02 +01:00
parent 6933804003
commit ca2e1645bb

View File

@@ -4,6 +4,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
@@ -52,20 +53,21 @@ private:
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
void addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
bool useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
@@ -76,6 +78,34 @@ private:
} // namespace
// Returns true when `op` is one of the view-like ops that may legally sit
// between a channel source value and its compute consumers (the "use chain").
static bool isChannelUseChainOp(Operation* op) {
  const bool isViewLikeChainOp =
      isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp,
          tensor::CastOp, tosa::ReshapeOp>(op);
  return isViewLikeChainOp;
}
// Counts how many SpatWeightedCompute ops eventually consume `value`,
// following chains of supported view-like ops (see isChannelUseChainOp).
// Any other user along the chain is a hard error (llvm_unreachable).
static size_t countComputeLeafUsers(Value value) {
  size_t computeLeafCount = 0;
  // Iterative traversal with an explicit worklist instead of recursion;
  // visits exactly the same uses, so the resulting count is identical.
  SmallVector<Value, 8> pendingValues;
  pendingValues.push_back(value);
  while (!pendingValues.empty()) {
    Value current = pendingValues.pop_back_val();
    for (OpOperand& use : current.getUses()) {
      Operation* owner = use.getOwner();
      if (isa<spatial::SpatWeightedCompute>(owner)) {
        ++computeLeafCount;
        continue;
      }
      if (!isChannelUseChainOp(owner))
        llvm_unreachable("Channel use chain contains unsupported op");
      assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
      pendingValues.push_back(owner->getResult(0));
    }
  }
  return computeLeafCount;
}
void SpatialToPimPass::runOnOperation() {
coreId = 1;
ModuleOp moduleOp = getOperation();
@@ -103,7 +133,10 @@ void SpatialToPimPass::runOnOperation() {
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addResultBuffer(returnOp, rewriter);
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
signalPassFailure();
return;
}
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
operationsToRemove.push_back(receiveOp);
@@ -233,14 +266,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Receive value through the channel
// If this result is used by more than one user, then use a "Broadcast"
// channel operation. However, there is a special case: we have a single
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
// will detect this case and update `useBroadcastOp` accordingly.
bool useBroadcastOp = (numResultUses > 1);
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
// 2. Receive value through the channel. Broadcast is needed whenever the
// value eventually reaches more than one compute consumer, even through a
// chain of view-like ops.
bool useBroadcastOp = countComputeLeafUsers(result) > 1;
addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
// 3. Send the value through the channel
rewriter.setInsertionPointAfterValue(yieldValue);
@@ -327,7 +357,7 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
}
}
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
@@ -359,7 +389,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front();
@@ -369,7 +400,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp);
funcOp.eraseArgument(i);
if (failed(funcOp.eraseArgument(i)))
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
}
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
@@ -383,6 +415,9 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
@@ -416,12 +451,15 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp);
return success();
}
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter) {
auto& computeBlock = computeOp.getRegion().front();
@@ -434,68 +472,68 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
Value receivedValue;
if (useBroadcastOp)
receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
receivedValue =
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
else
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
blockArg.replaceAllUsesWith(receivedValue);
Value replacementValue = receivedValue;
if (consumerValue != channelSourceOp) {
SmallVector<Operation*> clonedChain;
Value currentValue = consumerValue;
while (currentValue != channelSourceOp) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || !isChannelUseChainOp(definingOp))
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
clonedChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
IRMapping mapping;
mapping.map(channelSourceOp, receivedValue);
for (Operation* op : llvm::reverse(clonedChain)) {
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
markOpToRemove(op);
}
replacementValue = cast<Value>(mapping.lookup(consumerValue));
}
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
blockArg.replaceAllUsesWith(replacementValue);
}
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
bool useBroadcastOp,
IRRewriter& rewriter) {
auto sourceOpUses = channelSourceOp.getUses();
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
if (useBroadcastOp == false) {
// if useBroadcastOp is false, then sourceOp must have only one user
assert(rangeLength(sourceOpUses) == 1);
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
auto reshapeOpUses = reshapeOp.getOutput().getUses();
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
if (reshapeOpUsesCount > 1)
useBroadcastOp = true;
}
}
for (auto& resultUse : sourceOpUses) {
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
if (computeUser) {
replaceBlockArgumentWithRecvOp(
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
continue;
}
if (!computeUser) {
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
if (!reshapeOp) {
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
resultUse.getOwner()->dump();
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
}
// The tensorType now becomes the one of the reshapeOp
channelTensorType = reshapeOp.getResult().getType();
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
if (!computeUser)
llvm_unreachable("ReshapeOp users must be ComputeOps");
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
replaceBlockArgumentWithRecvOp(
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
continue;
}
// Remove the reshapeOp, so that the sourceOp has no users
operationsToRemove.push_back(reshapeOp);
if (!isChannelUseChainOp(owner))
llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
markOpToRemove(owner);
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
self(owner->getResult(0), self);
}
}
};
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
}
// Queues `op` for deferred erasure, ignoring ops that are already queued so
// operationsToRemove never holds duplicates.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  const bool alreadyQueued = llvm::is_contained(operationsToRemove, op);
  if (alreadyQueued)
    return;
  operationsToRemove.push_back(op);
}
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
@@ -527,15 +565,10 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
auto tensorType = receiveOp.getType();
Value receiveRes = receiveOp.getResult();
// Check if the receiveOp value has more than one user
auto receiveUses = receiveRes.getUses();
auto receiveUsesCount = rangeLength(receiveUses);
assert(receiveUsesCount > 0);
bool useBroadcastOp = receiveUsesCount > 1;
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
if (useBroadcastOp) {
// When receiving, we actually noticed that the value has more than one