simple convolutions now work :)
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
#include "mlir/IR/BuiltinDialect.h"
|
#include "mlir/IR/BuiltinDialect.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
@@ -52,20 +53,21 @@ private:
|
|||||||
|
|
||||||
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||||
void addReceiveOps(Value& channelSourceOp,
|
void addReceiveOps(Value channelSourceOp,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& channelTensorType,
|
bool useBroadcastOp,
|
||||||
bool& useBroadcastOp,
|
|
||||||
IRRewriter& rewriter);
|
IRRewriter& rewriter);
|
||||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
|
Value channelSourceOp,
|
||||||
|
Value consumerValue,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& tensorType,
|
|
||||||
bool useBroadcastOp,
|
bool useBroadcastOp,
|
||||||
IRRewriter& rewriter);
|
IRRewriter& rewriter);
|
||||||
|
void markOpToRemove(Operation* op);
|
||||||
|
|
||||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
@@ -76,6 +78,34 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
static bool isChannelUseChainOp(Operation* op) {
|
||||||
|
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
|
||||||
|
op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t countComputeLeafUsers(Value value) {
|
||||||
|
size_t leafUserCount = 0;
|
||||||
|
|
||||||
|
auto walkUses = [&](Value currentValue, auto& self) -> void {
|
||||||
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
|
Operation* owner = use.getOwner();
|
||||||
|
if (isa<spatial::SpatWeightedCompute>(owner)) {
|
||||||
|
leafUserCount++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isChannelUseChainOp(owner))
|
||||||
|
llvm_unreachable("Channel use chain contains unsupported op");
|
||||||
|
|
||||||
|
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||||
|
self(owner->getResult(0), self);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
walkUses(value, walkUses);
|
||||||
|
return leafUserCount;
|
||||||
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnOperation() {
|
void SpatialToPimPass::runOnOperation() {
|
||||||
coreId = 1;
|
coreId = 1;
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
@@ -103,7 +133,10 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
|
|
||||||
addResultBuffer(returnOp, rewriter);
|
addResultBuffer(returnOp, rewriter);
|
||||||
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
||||||
operationsToRemove.push_back(receiveOp);
|
operationsToRemove.push_back(receiveOp);
|
||||||
@@ -233,14 +266,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
||||||
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||||
|
|
||||||
// 2. Receive value through the channel
|
// 2. Receive value through the channel. Broadcast is needed whenever the
|
||||||
// If this result is used by more than one user, then use a "Broadcast"
|
// value eventually reaches more than one compute consumer, even through a
|
||||||
// channel operation. However, there is a special case: we have a single
|
// chain of view-like ops.
|
||||||
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
|
bool useBroadcastOp = countComputeLeafUsers(result) > 1;
|
||||||
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
|
addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
|
||||||
// will detect this case and update `useBroadcastOp` accordingly.
|
|
||||||
bool useBroadcastOp = (numResultUses > 1);
|
|
||||||
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
|
|
||||||
|
|
||||||
// 3. Send the value through the channel
|
// 3. Send the value through the channel
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
@@ -327,7 +357,7 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
|
|
||||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
||||||
@@ -359,7 +389,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
||||||
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
||||||
|
|
||||||
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
|
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
|
||||||
|
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
|
||||||
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
||||||
|
|
||||||
Block& block = funcOp.getBody().front();
|
Block& block = funcOp.getBody().front();
|
||||||
@@ -369,7 +400,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
inputTensors.push_back(toTensorOp);
|
inputTensors.push_back(toTensorOp);
|
||||||
|
|
||||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
tensorArg.replaceAllUsesWith(toTensorOp);
|
||||||
funcOp.eraseArgument(i);
|
if (failed(funcOp.eraseArgument(i)))
|
||||||
|
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
||||||
@@ -383,6 +415,9 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
||||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
||||||
|
|
||||||
|
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
|
||||||
|
continue;
|
||||||
|
|
||||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
||||||
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
||||||
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
||||||
@@ -416,12 +451,15 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
for (auto sliceOp : sliceOpsToRemove)
|
for (auto sliceOp : sliceOpsToRemove)
|
||||||
if (sliceOp->getUses().empty())
|
if (sliceOp->getUses().empty())
|
||||||
rewriter.eraseOp(sliceOp);
|
rewriter.eraseOp(sliceOp);
|
||||||
|
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
|
Value channelSourceOp,
|
||||||
|
Value consumerValue,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& tensorType,
|
|
||||||
bool useBroadcastOp,
|
bool useBroadcastOp,
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
auto& computeBlock = computeOp.getRegion().front();
|
auto& computeBlock = computeOp.getRegion().front();
|
||||||
@@ -434,68 +472,68 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
|||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||||
Value receivedValue;
|
Value receivedValue;
|
||||||
if (useBroadcastOp)
|
if (useBroadcastOp)
|
||||||
receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
receivedValue =
|
||||||
|
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
else
|
else
|
||||||
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
|
|
||||||
blockArg.replaceAllUsesWith(receivedValue);
|
Value replacementValue = receivedValue;
|
||||||
|
if (consumerValue != channelSourceOp) {
|
||||||
|
SmallVector<Operation*> clonedChain;
|
||||||
|
Value currentValue = consumerValue;
|
||||||
|
while (currentValue != channelSourceOp) {
|
||||||
|
Operation* definingOp = currentValue.getDefiningOp();
|
||||||
|
if (!definingOp || !isChannelUseChainOp(definingOp))
|
||||||
|
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
|
||||||
|
|
||||||
|
clonedChain.push_back(definingOp);
|
||||||
|
currentValue = definingOp->getOperand(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
|
IRMapping mapping;
|
||||||
|
mapping.map(channelSourceOp, receivedValue);
|
||||||
|
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||||
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
markOpToRemove(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
|
||||||
|
blockArg.replaceAllUsesWith(replacementValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& channelTensorType,
|
bool useBroadcastOp,
|
||||||
bool& useBroadcastOp,
|
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
auto sourceOpUses = channelSourceOp.getUses();
|
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
|
||||||
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
|
Operation* owner = use.getOwner();
|
||||||
if (useBroadcastOp == false) {
|
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
|
||||||
// if useBroadcastOp is false, then sourceOp must have only one user
|
|
||||||
assert(rangeLength(sourceOpUses) == 1);
|
|
||||||
|
|
||||||
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
|
|
||||||
auto reshapeOpUses = reshapeOp.getOutput().getUses();
|
|
||||||
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
|
|
||||||
if (reshapeOpUsesCount > 1)
|
|
||||||
useBroadcastOp = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& resultUse : sourceOpUses) {
|
|
||||||
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
|
|
||||||
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
|
|
||||||
|
|
||||||
if (computeUser) {
|
|
||||||
replaceBlockArgumentWithRecvOp(
|
replaceBlockArgumentWithRecvOp(
|
||||||
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!computeUser) {
|
if (!isChannelUseChainOp(owner))
|
||||||
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
|
llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
|
||||||
if (!reshapeOp) {
|
|
||||||
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
|
markOpToRemove(owner);
|
||||||
resultUse.getOwner()->dump();
|
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||||
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
|
self(owner->getResult(0), self);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The tensorType now becomes the one of the reshapeOp
|
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||||
channelTensorType = reshapeOp.getResult().getType();
|
if (!llvm::is_contained(operationsToRemove, op))
|
||||||
|
operationsToRemove.push_back(op);
|
||||||
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
|
|
||||||
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
|
|
||||||
|
|
||||||
if (!computeUser)
|
|
||||||
llvm_unreachable("ReshapeOp users must be ComputeOps");
|
|
||||||
|
|
||||||
replaceBlockArgumentWithRecvOp(
|
|
||||||
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the reshapeOp, so that the sourceOp has no users
|
|
||||||
operationsToRemove.push_back(reshapeOp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
@@ -527,15 +565,10 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
|||||||
|
|
||||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||||
|
|
||||||
auto tensorType = receiveOp.getType();
|
|
||||||
Value receiveRes = receiveOp.getResult();
|
Value receiveRes = receiveOp.getResult();
|
||||||
|
|
||||||
// Check if the receiveOp value has more than one user
|
bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
|
||||||
auto receiveUses = receiveRes.getUses();
|
addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
|
||||||
auto receiveUsesCount = rangeLength(receiveUses);
|
|
||||||
assert(receiveUsesCount > 0);
|
|
||||||
bool useBroadcastOp = receiveUsesCount > 1;
|
|
||||||
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
|
|
||||||
|
|
||||||
if (useBroadcastOp) {
|
if (useBroadcastOp) {
|
||||||
// When receiving, we actually noticed that the value has more than one
|
// When receiving, we actually noticed that the value has more than one
|
||||||
|
|||||||
Reference in New Issue
Block a user