simple convolutions now work :)
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -52,20 +53,21 @@ private:
|
||||
|
||||
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||
|
||||
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||
void addReceiveOps(Value& channelSourceOp,
|
||||
void addReceiveOps(Value channelSourceOp,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& channelTensorType,
|
||||
bool& useBroadcastOp,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||
unsigned int argIndex,
|
||||
Value channelSourceOp,
|
||||
Value consumerValue,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& tensorType,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void markOpToRemove(Operation* op);
|
||||
|
||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||
|
||||
@@ -76,6 +78,34 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Reports whether `op` is one of the view-like ops accepted between a
/// channel-carried value and its compute consumers: tensor extract-slice,
/// collapse-shape, expand-shape, cast, or a TOSA reshape.
static bool isChannelUseChainOp(Operation* op) {
  const bool isSupportedViewLikeOp =
      isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(op);
  return isSupportedViewLikeOp;
}
|
||||
|
||||
/// Counts the operand uses of `value` that end in a
/// `spatial::SpatWeightedCompute`, following the single result of every
/// intermediate op accepted by `isChannelUseChainOp`.
///
/// Each compute use is counted separately, so one compute op consuming the
/// value through two operands contributes two. Any other kind of user hits
/// `llvm_unreachable`; chain ops are asserted to have exactly one result.
static size_t countComputeLeafUsers(Value value) {
  size_t computeUseCount = 0;

  // Explicit worklist of values whose uses still have to be inspected.
  SmallVector<Value> pendingValues;
  pendingValues.push_back(value);

  while (!pendingValues.empty()) {
    Value current = pendingValues.pop_back_val();

    for (OpOperand& operand : current.getUses()) {
      Operation* user = operand.getOwner();

      // A compute op is a leaf: count it and do not look any further.
      if (isa<spatial::SpatWeightedCompute>(user)) {
        ++computeUseCount;
        continue;
      }

      if (!isChannelUseChainOp(user))
        llvm_unreachable("Channel use chain contains unsupported op");

      assert(user->getNumResults() == 1 && "Channel use chain op must have a single result");

      // Keep following the chain through the op's only result.
      pendingValues.push_back(user->getResult(0));
    }
  }

  return computeUseCount;
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
coreId = 1;
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -103,7 +133,10 @@ void SpatialToPimPass::runOnOperation() {
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
addResultBuffer(returnOp, rewriter);
|
||||
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
||||
operationsToRemove.push_back(receiveOp);
|
||||
@@ -233,14 +266,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
||||
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||
|
||||
// 2. Receive value through the channel
|
||||
// If this result is used by more than one user, then use a "Broadcast"
|
||||
// channel operation. However, there is a special case: we have a single
|
||||
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
|
||||
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
|
||||
// will detect this case and update `useBroadcastOp` accordingly.
|
||||
bool useBroadcastOp = (numResultUses > 1);
|
||||
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
|
||||
// 2. Receive value through the channel. Broadcast is needed whenever the
|
||||
// value eventually reaches more than one compute consumer, even through a
|
||||
// chain of view-like ops.
|
||||
bool useBroadcastOp = countComputeLeafUsers(result) > 1;
|
||||
addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
|
||||
|
||||
// 3. Send the value through the channel
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
@@ -327,7 +357,7 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
||||
}
|
||||
}
|
||||
|
||||
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
|
||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
||||
@@ -359,7 +389,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
||||
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
||||
|
||||
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
|
||||
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
|
||||
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
|
||||
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
||||
|
||||
Block& block = funcOp.getBody().front();
|
||||
@@ -369,7 +400,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
inputTensors.push_back(toTensorOp);
|
||||
|
||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
||||
funcOp.eraseArgument(i);
|
||||
if (failed(funcOp.eraseArgument(i)))
|
||||
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
|
||||
}
|
||||
|
||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
||||
@@ -383,6 +415,9 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
||||
|
||||
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
|
||||
continue;
|
||||
|
||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
||||
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
||||
@@ -416,12 +451,15 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
for (auto sliceOp : sliceOpsToRemove)
|
||||
if (sliceOp->getUses().empty())
|
||||
rewriter.eraseOp(sliceOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||
unsigned int argIndex,
|
||||
Value channelSourceOp,
|
||||
Value consumerValue,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& tensorType,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter) {
|
||||
auto& computeBlock = computeOp.getRegion().front();
|
||||
@@ -434,68 +472,68 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
Value receivedValue;
|
||||
if (useBroadcastOp)
|
||||
receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
||||
receivedValue =
|
||||
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||
else
|
||||
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
||||
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||
|
||||
blockArg.replaceAllUsesWith(receivedValue);
|
||||
Value replacementValue = receivedValue;
|
||||
if (consumerValue != channelSourceOp) {
|
||||
SmallVector<Operation*> clonedChain;
|
||||
Value currentValue = consumerValue;
|
||||
while (currentValue != channelSourceOp) {
|
||||
Operation* definingOp = currentValue.getDefiningOp();
|
||||
if (!definingOp || !isChannelUseChainOp(definingOp))
|
||||
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
|
||||
|
||||
clonedChain.push_back(definingOp);
|
||||
currentValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
IRMapping mapping;
|
||||
mapping.map(channelSourceOp, receivedValue);
|
||||
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
markOpToRemove(op);
|
||||
}
|
||||
|
||||
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
||||
}
|
||||
|
||||
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
|
||||
blockArg.replaceAllUsesWith(replacementValue);
|
||||
}
|
||||
|
||||
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
|
||||
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& channelTensorType,
|
||||
bool& useBroadcastOp,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter) {
|
||||
auto sourceOpUses = channelSourceOp.getUses();
|
||||
|
||||
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
|
||||
if (useBroadcastOp == false) {
|
||||
// if useBroadcastOp is false, then sourceOp must have only one user
|
||||
assert(rangeLength(sourceOpUses) == 1);
|
||||
|
||||
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
|
||||
auto reshapeOpUses = reshapeOp.getOutput().getUses();
|
||||
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
|
||||
if (reshapeOpUsesCount > 1)
|
||||
useBroadcastOp = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& resultUse : sourceOpUses) {
|
||||
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
|
||||
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
|
||||
|
||||
if (computeUser) {
|
||||
replaceBlockArgumentWithRecvOp(
|
||||
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!computeUser) {
|
||||
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
|
||||
if (!reshapeOp) {
|
||||
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
|
||||
resultUse.getOwner()->dump();
|
||||
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
|
||||
}
|
||||
|
||||
// The tensorType now becomes the one of the reshapeOp
|
||||
channelTensorType = reshapeOp.getResult().getType();
|
||||
|
||||
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
|
||||
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
|
||||
|
||||
if (!computeUser)
|
||||
llvm_unreachable("ReshapeOp users must be ComputeOps");
|
||||
|
||||
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
|
||||
for (OpOperand& use : currentValue.getUses()) {
|
||||
Operation* owner = use.getOwner();
|
||||
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
|
||||
replaceBlockArgumentWithRecvOp(
|
||||
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
||||
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Remove the reshapeOp, so that the sourceOp has no users
|
||||
operationsToRemove.push_back(reshapeOp);
|
||||
if (!isChannelUseChainOp(owner))
|
||||
llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
|
||||
|
||||
markOpToRemove(owner);
|
||||
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||
self(owner->getResult(0), self);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
|
||||
}
|
||||
|
||||
/// Appends `op` to `operationsToRemove` unless it is already listed there,
/// so the same operation is never queued twice.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  const bool alreadyQueued = llvm::is_contained(operationsToRemove, op);
  if (alreadyQueued)
    return;

  operationsToRemove.push_back(op);
}
|
||||
|
||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
@@ -527,15 +565,10 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
||||
|
||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||
|
||||
auto tensorType = receiveOp.getType();
|
||||
Value receiveRes = receiveOp.getResult();
|
||||
|
||||
// Check if the receiveOp value has more than one user
|
||||
auto receiveUses = receiveRes.getUses();
|
||||
auto receiveUsesCount = rangeLength(receiveUses);
|
||||
assert(receiveUsesCount > 0);
|
||||
bool useBroadcastOp = receiveUsesCount > 1;
|
||||
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
|
||||
bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
|
||||
addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
|
||||
|
||||
if (useBroadcastOp) {
|
||||
// When receiving, we actually noticed that the value has more than one
|
||||
|
||||
Reference in New Issue
Block a user