diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index eaf2b49..7490fb2 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -4,6 +4,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" @@ -52,20 +53,21 @@ private: void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter); - void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); + LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); - void addReceiveOps(Value& channelSourceOp, + void addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, - Type& channelTensorType, - bool& useBroadcastOp, + bool useBroadcastOp, IRRewriter& rewriter); void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, unsigned int argIndex, + Value channelSourceOp, + Value consumerValue, spatial::SpatChannelNewOp& channel, - Type& tensorType, bool useBroadcastOp, IRRewriter& rewriter); + void markOpToRemove(Operation* op); void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); @@ -76,6 +78,34 @@ private: } // namespace +static bool isChannelUseChainOp(Operation* op) { + return isa( + op); +} + +static size_t countComputeLeafUsers(Value value) { + size_t leafUserCount = 0; + + auto walkUses = [&](Value currentValue, auto& self) -> void { + for (OpOperand& use : currentValue.getUses()) { + Operation* owner = use.getOwner(); + if (isa(owner)) { + leafUserCount++; + continue; + } + + if (!isChannelUseChainOp(owner)) + llvm_unreachable("Channel use chain contains unsupported op"); + + assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result"); + self(owner->getResult(0), self); + } + }; + + walkUses(value, walkUses); + return leafUserCount; +} + void SpatialToPimPass::runOnOperation() { coreId = 1; ModuleOp moduleOp = getOperation(); @@ -103,7 +133,10 @@ void SpatialToPimPass::runOnOperation() { auto returnOp = cast(funcOp.front().getTerminator()); addResultBuffer(returnOp, rewriter); - allocateAndInitializeCoreLocalVariables(funcOp, rewriter); + if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { + signalPassFailure(); + return; + } for (auto receiveOp : funcOp.getOps()) { operationsToRemove.push_back(receiveOp); @@ -233,14 +266,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR auto channelType = spatial::SpatChannelType::get(computeOp.getContext()); auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType); - // 2. Receive value through the channel - // If this result is used by more than one user, then use a "Broadcast" - // channel operation. However, there is a special case: we have a single - // user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this - // case, we need to use a "Broadcast" channel operation. `addReceiveOps` - // will detect this case and update `useBroadcastOp` accordingly. - bool useBroadcastOp = (numResultUses > 1); - addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter); + // 2. Receive value through the channel. Broadcast is needed whenever the + // value eventually reaches more than one compute consumer, even through a + // chain of view-like ops. + bool useBroadcastOp = countComputeLeafUsers(result) > 1; + addReceiveOps(result, channelOp, useBroadcastOp, rewriter); // 3. Send the value through the channel rewriter.setInsertionPointAfterValue(yieldValue); @@ -327,7 +357,7 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew } } -void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { +LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { @@ -359,7 +389,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func ShapedType tensorArgType = cast(tensorArg.getType()); MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType()); - funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc); + if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc))) + return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering"); BlockArgument memRefArg = funcOp.getArgument(i + 1); Block& block = funcOp.getBody().front(); @@ -369,7 +400,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func inputTensors.push_back(toTensorOp); tensorArg.replaceAllUsesWith(toTensorOp); - funcOp.eraseArgument(i); + if (failed(funcOp.eraseArgument(i))) + return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering"); } llvm::SmallSet sliceOpsToRemove; @@ -383,6 +415,9 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func if (auto sliceOp = dyn_cast(computeOpInput.getDefiningOp())) { tensorSource = cast>(sliceOp.getSource()); + if (isa(tensorSource.getDefiningOp())) + continue; + ArrayRef sourceShape = tensorSource.getType().getShape(); ArrayRef sliceOffsets = sliceOp.getStaticOffsets(); ArrayRef sliceSizes = sliceOp.getStaticSizes(); @@ -416,12 +451,15 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func for (auto sliceOp : sliceOpsToRemove) if (sliceOp->getUses().empty()) rewriter.eraseOp(sliceOp); + + return success(); } void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, unsigned int argIndex, + Value channelSourceOp, + Value consumerValue, spatial::SpatChannelNewOp& channel, - Type& tensorType, bool useBroadcastOp, IRRewriter& rewriter) { auto& computeBlock = computeOp.getRegion().front(); @@ -434,68 +472,68 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); Value receivedValue; if (useBroadcastOp) - receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel); + receivedValue = + spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); else - receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel); + receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); - blockArg.replaceAllUsesWith(receivedValue); + Value replacementValue = receivedValue; + if (consumerValue != channelSourceOp) { + SmallVector clonedChain; + Value currentValue = consumerValue; + while (currentValue != channelSourceOp) { + Operation* definingOp = currentValue.getDefiningOp(); + if (!definingOp || !isChannelUseChainOp(definingOp)) + llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute"); + + clonedChain.push_back(definingOp); + currentValue = definingOp->getOperand(0); + } + + IRMapping mapping; + mapping.map(channelSourceOp, receivedValue); + for (Operation* op : llvm::reverse(clonedChain)) { + Operation* clonedOp = rewriter.clone(*op, mapping); + for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + markOpToRemove(op); + } + + replacementValue = cast(mapping.lookup(consumerValue)); + } + + assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type"); + blockArg.replaceAllUsesWith(replacementValue); } -void SpatialToPimPass::addReceiveOps(Value& channelSourceOp, +void SpatialToPimPass::addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, - Type& channelTensorType, - bool& useBroadcastOp, + bool useBroadcastOp, IRRewriter& rewriter) { - auto sourceOpUses = channelSourceOp.getUses(); - - // Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users - if (useBroadcastOp == false) { - // if useBroadcastOp is false, then sourceOp must have only one user - assert(rangeLength(sourceOpUses) == 1); - - if (auto reshapeOp = dyn_cast(sourceOpUses.begin()->getOwner())) { - auto reshapeOpUses = reshapeOp.getOutput().getUses(); - auto reshapeOpUsesCount = rangeLength(reshapeOpUses); - if (reshapeOpUsesCount > 1) - useBroadcastOp = true; - } - } - - for (auto& resultUse : sourceOpUses) { - // The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps - spatial::SpatWeightedCompute computeUser = dyn_cast(resultUse.getOwner()); - - if (computeUser) { - replaceBlockArgumentWithRecvOp( - computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter); - continue; - } - - if (!computeUser) { - auto reshapeOp = dyn_cast(resultUse.getOwner()); - if (!reshapeOp) { - channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump(); - resultUse.getOwner()->dump(); - llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp"); - } - - // The tensorType now becomes the one of the reshapeOp - channelTensorType = reshapeOp.getResult().getType(); - - for (auto& reshapeUse : reshapeOp.getOutput().getUses()) { - computeUser = dyn_cast(reshapeUse.getOwner()); - - if (!computeUser) - llvm_unreachable("ReshapeOp users must be ComputeOps"); - + auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void { + for (OpOperand& use : currentValue.getUses()) { + Operation* owner = use.getOwner(); + if (auto computeUser = dyn_cast(owner)) { replaceBlockArgumentWithRecvOp( - computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter); + computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter); + continue; } - // Remove the reshapeOp, so that the sourceOp has no users - operationsToRemove.push_back(reshapeOp); + if (!isChannelUseChainOp(owner)) + llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op"); + + markOpToRemove(owner); + assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result"); + self(owner->getResult(0), self); } - } + }; + + replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers); +} + +void SpatialToPimPass::markOpToRemove(Operation* op) { + if (!llvm::is_contained(operationsToRemove, op)) + operationsToRemove.push_back(op); } void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { @@ -527,15 +565,10 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I auto sendOp = cast(*sendOpOpt); - auto tensorType = receiveOp.getType(); Value receiveRes = receiveOp.getResult(); - // Check if the receiveOp value has more than one user - auto receiveUses = receiveRes.getUses(); - auto receiveUsesCount = rangeLength(receiveUses); - assert(receiveUsesCount > 0); - bool useBroadcastOp = receiveUsesCount > 1; - addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter); + bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1; + addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter); if (useBroadcastOp) { // When receiving, we actually noticed that the value has more than one