simple convolutions now work :)

This commit is contained in:
NiccoloN
2026-03-20 21:17:02 +01:00
parent 6933804003
commit ca2e1645bb

View File

@@ -4,6 +4,7 @@
#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@@ -52,20 +53,21 @@ private:
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter); void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp, void addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& channelTensorType, bool useBroadcastOp,
bool& useBroadcastOp,
IRRewriter& rewriter); IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp, bool useBroadcastOp,
IRRewriter& rewriter); IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
@@ -76,6 +78,34 @@ private:
} // namespace } // namespace
// Returns true if `op` is one of the view-like ops that are allowed to sit
// between a channel-carried value and its eventual compute consumer(s).
// The walk in countComputeLeafUsers and the replay logic treat exactly this
// set of ops as a transparent "use chain".
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
op);
}
// Counts how many spatial::SpatWeightedCompute ops ultimately consume `value`,
// looking through chains of view-like ops (see isChannelUseChainOp). Used to
// decide whether a channel needs a broadcast receive (more than one leaf
// consumer) or a plain one.
// NOTE: aborts (llvm_unreachable) if any transitive user is neither a compute
// op nor a supported view-like op, and asserts each chain op has one result.
static size_t countComputeLeafUsers(Value value) {
size_t leafUserCount = 0;
// Recursive lambda: `self` is the lambda itself, passed explicitly so it can
// recurse without std::function.
auto walkUses = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (isa<spatial::SpatWeightedCompute>(owner)) {
// Reached a leaf consumer; count it and stop descending on this use.
leafUserCount++;
continue;
}
if (!isChannelUseChainOp(owner))
llvm_unreachable("Channel use chain contains unsupported op");
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
// Follow the single result of the view-like op further down the chain.
self(owner->getResult(0), self);
}
};
walkUses(value, walkUses);
return leafUserCount;
}
void SpatialToPimPass::runOnOperation() { void SpatialToPimPass::runOnOperation() {
coreId = 1; coreId = 1;
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
@@ -103,7 +133,10 @@ void SpatialToPimPass::runOnOperation() {
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addResultBuffer(returnOp, rewriter); addResultBuffer(returnOp, rewriter);
allocateAndInitializeCoreLocalVariables(funcOp, rewriter); if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
signalPassFailure();
return;
}
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) { for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
operationsToRemove.push_back(receiveOp); operationsToRemove.push_back(receiveOp);
@@ -233,14 +266,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
auto channelType = spatial::SpatChannelType::get(computeOp.getContext()); auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType); auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Receive value through the channel // 2. Receive value through the channel. Broadcast is needed whenever the
// If this result is used by more than one user, then use a "Broadcast" // value eventually reaches more than one compute consumer, even through a
// channel operation. However, there is a special case: we have a single // chain of view-like ops.
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this bool useBroadcastOp = countComputeLeafUsers(result) > 1;
// case, we need to use a "Broadcast" channel operation. `addReceiveOps` addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
// will detect this case and update `useBroadcastOp` accordingly.
bool useBroadcastOp = (numResultUses > 1);
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
// 3. Send the value through the channel // 3. Send the value through the channel
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
@@ -327,7 +357,7 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
} }
} }
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
@@ -359,7 +389,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType()); ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType()); MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc); if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
BlockArgument memRefArg = funcOp.getArgument(i + 1); BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front(); Block& block = funcOp.getBody().front();
@@ -369,7 +400,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
inputTensors.push_back(toTensorOp); inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp); tensorArg.replaceAllUsesWith(toTensorOp);
funcOp.eraseArgument(i); if (failed(funcOp.eraseArgument(i)))
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
} }
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove; llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
@@ -383,6 +415,9 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) { if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource()); tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape(); ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets(); ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes(); ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
@@ -416,12 +451,15 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
for (auto sliceOp : sliceOpsToRemove) for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty()) if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp); rewriter.eraseOp(sliceOp);
return success();
} }
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp, bool useBroadcastOp,
IRRewriter& rewriter) { IRRewriter& rewriter) {
auto& computeBlock = computeOp.getRegion().front(); auto& computeBlock = computeOp.getRegion().front();
@@ -434,68 +472,68 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
Value receivedValue; Value receivedValue;
if (useBroadcastOp) if (useBroadcastOp)
receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel); receivedValue =
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
else else
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel); receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
blockArg.replaceAllUsesWith(receivedValue); Value replacementValue = receivedValue;
if (consumerValue != channelSourceOp) {
SmallVector<Operation*> clonedChain;
Value currentValue = consumerValue;
while (currentValue != channelSourceOp) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || !isChannelUseChainOp(definingOp))
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
clonedChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
} }
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp, IRMapping mapping;
mapping.map(channelSourceOp, receivedValue);
for (Operation* op : llvm::reverse(clonedChain)) {
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
markOpToRemove(op);
}
replacementValue = cast<Value>(mapping.lookup(consumerValue));
}
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
blockArg.replaceAllUsesWith(replacementValue);
}
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& channelTensorType, bool useBroadcastOp,
bool& useBroadcastOp,
IRRewriter& rewriter) { IRRewriter& rewriter) {
auto sourceOpUses = channelSourceOp.getUses(); auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users Operation* owner = use.getOwner();
if (useBroadcastOp == false) { if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
// if useBroadcastOp is false, then sourceOp must have only one user
assert(rangeLength(sourceOpUses) == 1);
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
auto reshapeOpUses = reshapeOp.getOutput().getUses();
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
if (reshapeOpUsesCount > 1)
useBroadcastOp = true;
}
}
for (auto& resultUse : sourceOpUses) {
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
if (computeUser) {
replaceBlockArgumentWithRecvOp( replaceBlockArgumentWithRecvOp(
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter); computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
continue; continue;
} }
if (!computeUser) { if (!isChannelUseChainOp(owner))
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner()); llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
if (!reshapeOp) {
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump(); markOpToRemove(owner);
resultUse.getOwner()->dump(); assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp"); self(owner->getResult(0), self);
}
};
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
} }
// The tensorType now becomes the one of the reshapeOp void SpatialToPimPass::markOpToRemove(Operation* op) {
channelTensorType = reshapeOp.getResult().getType(); if (!llvm::is_contained(operationsToRemove, op))
operationsToRemove.push_back(op);
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
if (!computeUser)
llvm_unreachable("ReshapeOp users must be ComputeOps");
replaceBlockArgumentWithRecvOp(
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
}
// Remove the reshapeOp, so that the sourceOp has no users
operationsToRemove.push_back(reshapeOp);
}
}
} }
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
@@ -527,15 +565,10 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt); auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
auto tensorType = receiveOp.getType();
Value receiveRes = receiveOp.getResult(); Value receiveRes = receiveOp.getResult();
// Check if the receiveOp value has more than one user bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
auto receiveUses = receiveRes.getUses(); addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
auto receiveUsesCount = rangeLength(receiveUses);
assert(receiveUsesCount > 0);
bool useBroadcastOp = receiveUsesCount > 1;
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
if (useBroadcastOp) { if (useBroadcastOp) {
// When receiving, we actually noticed that the value has more than one // When receiving, we actually noticed that the value has more than one