//===----------------------------------------------------------------------===//
// Spatial-to-PIM lowering pass.
//
// Converts `spat.compute` regions in the PIM entry function into `pim.core`
// ops, wires inter-core data flow through Spatial channels (point-to-point
// send/receive or broadcast), inserts host<->device memory copies for
// function inputs and results, and widens VMM output tensors to the fixed
// crossbar width.
//
// NOTE(review): throughout this file the angle-bracket/template arguments
// appear to have been stripped by a text-extraction step (e.g. `isa(...)`,
// `cast(...)`, `getOps()`, `SmallVector outputTensors`, `PassWrapper>`,
// and the four bare `#include` lines below). Restore them from version
// control before building; the surrounding logic is documented as-is.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_os_ostream.h"
// NOTE(review): the four headers below lost their <...> names in extraction.
#include
#include
#include
#include
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"

using namespace mlir;
using namespace onnx_mlir;
using namespace pim;

namespace onnx_mlir {

namespace {

// TableGen-generated DRR patterns; provides populateWithGenerated().
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"

/// Pass that lowers Spatial ops in the PIM entry function to PIM-ready form.
/// Registered under the command-line argument "convert-spatial-to-pim".
struct SpatialToPimPass : PassWrapper> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)

  StringRef getArgument() const override { return "convert-spatial-to-pim"; }

  StringRef getDescription() const override {
    return "Lower Spatial ops to PIM-ready format";
  }

  SpatialToPimPass() = default;
  // Copy ctor deliberately does not copy per-run state (pass cloning).
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  void runOnOperation() final;

private:
  // Host-side output tensor per func.return operand; filled by
  // addResultBuffer() and consumed by runOnComputeOp() /
  // replaceReturnOpOperands().
  SmallVector outputTensors;
  // Next core id to assign to a pim.core; runOnOperation() resets it to 1.
  size_t coreId = 0;
  // Ops queued for deferred erasure once all of their uses disappear
  // (drained at the end of runOnOperation()).
  SmallVector operationsToRemove;

  // Collects/creates one host output tensor per return operand.
  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
  // Rewrites tensor func arguments to memrefs and copies compute inputs
  // host->device.
  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
      IRRewriter& rewriter);
  // Lowers one pre-existing spat.channel_receive (upgrading to broadcast when
  // the value fans out to several computes).
  void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp,
      IRRewriter& rewriter);
  // Inserts receive ops in every compute consumer reachable from
  // channelSourceOp through view-like "channel use chain" ops.
  void addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel,
      bool useBroadcastOp, IRRewriter& rewriter);
  // Replaces one compute block argument with a (broadcast-)receive, replaying
  // any intervening view-like chain inside the consumer region.
  void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
      unsigned int argIndex, Value channelSourceOp, Value consumerValue,
      spatial::SpatChannelNewOp& channel, bool useBroadcastOp,
      IRRewriter& rewriter);
  // Queues op for deferred erasure (idempotent).
  void markOpToRemove(Operation* op);
  // Stamps each spat.channel_new with source/target core-id attributes.
  void annotateChannelCoreIds(func::FuncOp funcOp);
  // Lowers broadcast send/receive pairs to pim.send / pim.receive.
  void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
  // Lowers one spat.compute: stores results to host buffers or channels, then
  // converts the region into a pim.core.
  void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);
  // Pads VMM output buffers to crossbar width and slices back the old shape.
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp,
      IRRewriter& rewriter);
  // Redirects func.return to the host output tensors and queues the old
  // producer chains for removal.
  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};

} // namespace

/// Returns true for view-like ops that may sit between a channel-carried
/// value and its consumers ("channel use chain" ops).
/// NOTE(review): the op list inside isa<...> was lost in extraction.
static bool isChannelUseChainOp(Operation* op) { return isa(op); }

/// Clones, right before the current insertion point, any not-yet-mapped
/// helper ops (e.g. constant-like producers -- NOTE(review): the isa<...>
/// filter was lost in extraction) that feed `op`'s operands, recording the
/// clones in `mapping` so a subsequent clone of `op` resolves its operands.
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping,
    IRRewriter& rewriter) {
  for (Value operand : op->getOperands()) {
    // Already replayed into the destination region.
    if (mapping.lookupOrNull(operand))
      continue;
    Operation* definingOp = operand.getDefiningOp();
    if (!definingOp)
      continue;
    if (!isa(definingOp))
      continue;
    Operation* clonedOp = rewriter.clone(*definingOp, mapping);
    for (auto [originalResult, newResult] :
        llvm::zip(definingOp->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    // Keep subsequent clones after this one to preserve dominance.
    rewriter.setInsertionPointAfter(clonedOp);
  }
}

/// Counts how many compute ops ultimately consume `value`, walking through
/// chains of channel-use (view-like) ops. Used to decide whether a channel
/// needs broadcast semantics (more than one leaf consumer).
static size_t countComputeLeafUsers(Value value) {
  size_t leafUserCount = 0;
  // Recursive lambda via explicit self-reference (no std::function).
  auto walkUses = [&](Value currentValue, auto& self) -> void {
    for (OpOperand& use : currentValue.getUses()) {
      Operation* owner = use.getOwner();
      if (isa(owner)) {
        leafUserCount++;
        continue;
      }
      if (!isChannelUseChainOp(owner))
        llvm_unreachable("Channel use chain contains unsupported op");
      assert(owner->getNumResults() == 1 &&
          "Channel use chain op must have a single result");
      self(owner->getResult(0), self);
    }
  };
  walkUses(value, walkUses);
  return leafUserCount;
}

/// Pass driver: runs the generated conversion patterns, then lowers the PIM
/// entry function step by step, and finally erases all queued ops in
/// dependency order.
void SpatialToPimPass::runOnOperation() {
  // Core ids start at 1 on every run (0 presumably reserved -- TODO confirm).
  coreId = 1;
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();
  ConversionTarget target(*ctx);
  // NOTE(review): legal dialect list lost in extraction.
  target.addLegalDialect();
  RewritePatternSet patterns(ctx);
  populateWithGenerated(patterns);
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;
  IRRewriter rewriter(&getContext());
  auto returnOp = cast(funcOp.front().getTerminator());
  // 1. Materialize host output tensors for each returned value.
  addResultBuffer(returnOp, rewriter);
  // 2. Turn tensor arguments into memrefs + host->device copies.
  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    signalPassFailure();
    return;
  }
  // 3. Lower pre-existing channel receives, then each compute region.
  for (auto receiveOp : funcOp.getOps()) {
    markOpToRemove(receiveOp);
    runOnReceiveOp(receiveOp, rewriter);
  }
  for (auto computeOp : funcOp.getOps()) {
    markOpToRemove(computeOp);
    runOnComputeOp(computeOp, rewriter);
  }
  // 4. Channel bookkeeping and broadcast lowering.
  annotateChannelCoreIds(funcOp);
  lowerBroadcastChannelOps(funcOp, rewriter);
  RewritePatternSet channelPatterns(ctx);
  populateWithGenerated(channelPatterns);
  if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
    signalPassFailure();
    return;
  }
  // 5. Shape fix-ups and final return rewiring.
  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  replaceReturnOpOperands(returnOp, rewriter);
  // Erase queued ops, repeatedly sweeping so producers are erased only after
  // their consumers; a sweep with no progress means a cycle or a missed
  // dependency, which is a pass bug.
  SmallVector pendingRemovals(operationsToRemove.begin(),
      operationsToRemove.end());
  while (!pendingRemovals.empty()) {
    bool erasedAnyOp = false;
    for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
      Operation* opToRemove = *it;
      if (!opToRemove->use_empty()) {
        ++it;
        continue;
      }
      rewriter.eraseOp(opToRemove);
      it = pendingRemovals.erase(it);
      erasedAnyOp = true;
    }
    if (erasedAnyOp)
      continue;
    // No progress: dump the stuck ops and their users for diagnosis.
    for (auto opToRemove : pendingRemovals) {
      opToRemove->dump();
      for (auto user : opToRemove->getUsers())
        user->dump();
    }
    assert(false && "tracked op removal reached a cycle or missed dependency");
  }
  // Dump to file for debug
  dumpModule(moduleOp, "pim0");
}

/// Lowers one spat.compute. For each (result, yielded value) pair, tries in
/// order: (a) result flows through a view chain straight to func.return ->
/// replay the chain inside the region and copy device->host; (b) result goes
/// directly to func.return -> copy device->host; (c) result feeds a concat
/// whose output reaches func.return -> copy into the right byte offset of the
/// concatenated host buffer; otherwise (d) create a channel and send the
/// value (broadcast if it fans out to several computes). Finally the region's
/// yield becomes a halt and the compute is wrapped into a pim.core.
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp,
    IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();
  auto& block = computeOp.getRegion().front();
  auto yieldOp = cast(block.getTerminator());
  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable(
        "ComputeOp must have same number of results as yieldOp operands");
  for (auto [result, yieldValue] :
      llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    if (result.use_empty())
      continue;
    auto yieldType = cast(yieldValue.getType());
    auto resultUses = result.getUses();
    auto numResultUses = rangeLength(resultUses);
    if (numResultUses == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();
      // Case (a): single use into a view-like chain ending at func.return.
      if (isChannelUseChainOp(resultUser)) {
        SmallVector returnChain;
        Value chainedValue = result;
        Operation* chainUser = resultUser;
        // Follow the single-use chain, collecting ops to replay.
        while (isChannelUseChainOp(chainUser)) {
          returnChain.push_back(chainUser);
          auto chainUses = chainUser->getResult(0).getUses();
          if (rangeLength(chainUses) != 1)
            break;
          chainedValue = chainUser->getResult(0);
          chainUser = chainUses.begin()->getOwner();
        }
        if (isa(chainUser)) {
          size_t resultIndexInReturn =
              chainedValue.getUses().begin()->getOperandNumber();
          rewriter.setInsertionPoint(yieldOp);
          // Replay the chain inside the compute region, rooted at the
          // yielded value instead of the compute result.
          IRMapping mapping;
          mapping.map(result, yieldValue);
          Value storedValue = yieldValue;
          for (Operation* op : returnChain) {
            cloneMappedHelperOperands(op, mapping, rewriter);
            Operation* clonedOp = rewriter.clone(*op, mapping);
            for (auto [originalResult, newResult] :
                llvm::zip(op->getResults(), clonedOp->getResults()))
              mapping.map(originalResult, newResult);
            storedValue = clonedOp->getResult(0);
            rewriter.setInsertionPointAfter(clonedOp);
            markOpToRemove(op);
          }
          auto storedType = cast(storedValue.getType());
          size_t elementSize = storedType.getElementTypeBitWidth() / 8;
          Value outputTensor = outputTensors[resultIndexInReturn];
          if (auto storedOp = storedValue.getDefiningOp())
            rewriter.setInsertionPointAfter(storedOp);
          // Copy the fully-replayed value from device to the host buffer.
          PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(),
              outputTensor, storedValue, rewriter.getI32IntegerAttr(0),
              rewriter.getI32IntegerAttr(0),
              rewriter.getI32IntegerAttr(
                  storedType.getNumElements() * elementSize));
          continue;
        }
      }
      // Case (b): result is returned directly from the function.
      if (isa(resultUser)) {
        size_t resultIndexInReturn = resultUse.getOperandNumber();
        size_t offset = 0;
        size_t numElements = yieldType.getNumElements();
        size_t elementSize =
            yieldType.getElementType().getIntOrFloatBitWidth() / 8;
        // Store to global memory
        Value outputTensor = outputTensors[resultIndexInReturn];
        rewriter.setInsertionPointAfterValue(yieldValue);
        PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(),
            outputTensor, yieldValue, rewriter.getI32IntegerAttr(offset),
            rewriter.getI32IntegerAttr(0),
            rewriter.getI32IntegerAttr(numElements * elementSize));
        continue;
      }
      // Case (c): result feeds a concat (NOTE(review): concat op type lost in
      // extraction) whose value flows to func.return; write this result at
      // its byte offset within the concatenated host buffer.
      if (isa(resultUser)) {
        auto concatOp = resultUser;
        auto concatValue = concatOp->getResult(0);
        auto concatUses = concatValue.getUses();
        auto numConcatUses = rangeLength(concatUses);
        if (numConcatUses == 1) {
          Value chainedValue = concatValue;
          Operation* concatUser = concatUses.begin()->getOwner();
          while (isChannelUseChainOp(concatUser)) {
            auto chainUses = concatUser->getResult(0).getUses();
            if (rangeLength(chainUses) != 1)
              break;
            chainedValue = concatUser->getResult(0);
            concatUser = chainUses.begin()->getOwner();
          }
          if (isa(concatUser)) {
            size_t concatIndexInReturn =
                chainedValue.getUses().begin()->getOperandNumber();
            size_t resultIndexInConcat =
                resultUses.begin()->getOperandNumber();
            // Byte offset = total size of all concat operands before ours
            // (assumes concatenation along a memory-contiguous layout --
            // TODO confirm against the concat op's axis semantics).
            size_t offset = 0;
            for (auto operand :
                concatOp->getOperands().take_front(resultIndexInConcat))
              offset += cast(operand.getType()).getNumElements() *
                  cast(operand.getType()).getElementTypeBitWidth() / 8;
            size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
            // Store to global memory
            Value outputTensor = outputTensors[concatIndexInReturn];
            rewriter.setInsertionPointAfterValue(yieldValue);
            PimMemCopyDevToHostOp::create(rewriter, loc,
                outputTensor.getType(), outputTensor, yieldValue,
                rewriter.getI32IntegerAttr(offset),
                rewriter.getI32IntegerAttr(0),
                rewriter.getI32IntegerAttr(
                    yieldType.getNumElements() * elementSize));
            continue;
          }
        }
      }
    }
    // If this pattern was not found, then create a channel and send the value
    // 1. Create a new ChannelOp
    rewriter.setInsertionPoint(computeOp);
    auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
    auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc,
        channelType);
    // 2. Receive value through the channel. Broadcast is needed whenever the
    // value eventually reaches more than one compute consumer, even through a
    // chain of view-like ops.
    bool useBroadcastOp = countComputeLeafUsers(result) > 1;
    addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
    // 3. Send the value through the channel
    rewriter.setInsertionPointAfterValue(yieldValue);
    if (useBroadcastOp)
      spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp,
          yieldValue);
    else
      spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
  }
  // Use `HaltOp` instead of `YieldOp`
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp(yieldOp);
  // Replace `spat.compute` with `pim.core`
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(),
      rewriter.getI32IntegerAttr(coreId++));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  // Block args were already rewired to receives/copies; drop them before
  // moving the whole region into the core op.
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
  // Leave a trivial halt-only block behind so the soon-to-be-erased compute
  // op still has a syntactically valid region.
  Block* tempComputeBlock = new Block();
  computeOp.getBody().push_back(tempComputeBlock);
  rewriter.setInsertionPointToEnd(tempComputeBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}

/// For every pim.vmm whose output buffer is narrower than the crossbar,
/// widens the buffer (and the tied DPS producer chain behind it) to
/// [rows, crossbarSize], then extracts the original-shaped slice from the
/// widened result and redirects all other users to that slice.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp,
    IRRewriter& rewriter) {
  // Recursively retypes the chain of destination-style ops whose tied init
  // operand produces `value`, so every producer agrees on the new shape.
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType,
                                 auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() &&
        "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };
  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    // Expected to be a horizontal vector: shape [1, N] -- per isHVectorShape.
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast(crossbarSize)) {
      auto newShape =
          SmallVector {outShape[0], static_cast(crossbarSize)};
      auto newType = RankedTensorType::get(newShape,
          outTensorOperand.getType().getElementType());
      if (outTensorOperand == vmmOp.getInput()) {
        // In-place VMM (input doubles as output): give it a fresh, wider
        // buffer instead of retyping the shared value.
        rewriter.setInsertionPoint(vmmOp);
        auto newOutputBuffer = tensor::EmptyOp::create(rewriter,
            vmmOp.getLoc(), newShape,
            outTensorOperand.getType().getElementType());
        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
      } else {
        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
        outTensorOperand.setType(newType);
      }
      resultTensor.setType(newType);
      // Slice the original [outShape[0], outShape[1]] window back out of the
      // widened result for all downstream users.
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector offsets = {zeroAttr, zeroAttr};
      SmallVector sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector strides = {oneAttr, oneAttr};
      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(),
          resultTensor, offsets, sizes, strides);
      SmallPtrSet exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}

/// Populates outputTensors with one host tensor per func.return operand:
/// reuses the returned value itself when its producer carries the expected
/// trait (NOTE(review): trait name lost in extraction), otherwise creates an
/// empty tensor of the same shape at the top of the block.
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp,
    IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  rewriter.setInsertionPointToStart(returnOp->getBlock());
  for (auto returnValue : returnOp->getOperands()) {
    Operation* returnValueDefiningOp = returnValue.getDefiningOp();
    if (returnValueDefiningOp->hasTrait()) {
      assert(!hasWeightAlways(returnValueDefiningOp));
      outputTensors.push_back(returnValue);
    } else {
      auto newOutputTensor = createEmptyTensorFromShaped(rewriter,
          returnValue.getLoc(), cast(returnValue.getType()));
      outputTensors.push_back(newOutputTensor);
    }
  }
}

/// Rewrites each tensor function argument into a memref argument bridged by
/// bufferization.to_tensor, then, for every compute input coming from host
/// memory, replaces the corresponding compute block argument with a
/// host->device copy (folding a preceding extract_slice into a byte offset).
/// Returns failure if argument insertion/erasure on the function fails.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
    func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();
  // Replaces all uses of `valueToReplace` (a compute block argument) with the
  // result of a host->device copy of `hostTensor`, starting at
  // `elementsOffset` elements into the host buffer.
  auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor,
                                    int64_t elementsOffset) {
    auto tensorType = cast(valueToReplace.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    // Copy as late as possible: right before the first use.
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc,
        tensorType.getShape(), elementType);
    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter, loc, tensorType, deviceTensor, hostTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast(elementsOffset *
            elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() *
            elementByteSize)));
    rewriter.replaceAllUsesWith(valueToReplace,
        memCopyHostToDevOp.getResult());
  };
  // Replace input tensors with memRefs
  SmallVector inputTensors;
  for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
    BlockArgument tensorArg = funcOp.getArgument(i);
    DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
    ShapedType tensorArgType = cast(tensorArg.getType());
    MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(),
        tensorArgType.getElementType());
    // Insert the memref twin right after the tensor arg, rewire all uses
    // through to_tensor, then erase the original tensor arg (so index i is
    // occupied by the memref on the next iteration).
    if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs,
            loc)))
      return funcOp.emitError(
          "failed to insert memref argument during Spatial-to-Pim lowering");
    BlockArgument memRefArg = funcOp.getArgument(i + 1);
    Block& block = funcOp.getBody().front();
    rewriter.setInsertionPoint(&block.front());
    auto toTensorOp = bufferization::ToTensorOp::create(rewriter, loc,
        tensorArgType, memRefArg, rewriter.getUnitAttr());
    inputTensors.push_back(toTensorOp);
    tensorArg.replaceAllUsesWith(toTensorOp);
    if (failed(funcOp.eraseArgument(i)))
      return funcOp.emitError(
          "failed to erase tensor argument during Spatial-to-Pim lowering");
  }
  llvm::SmallSet sliceOpsToRemove;
  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast(op)) {
      // Weights precede inputs in the operand list; offset accordingly.
      unsigned numComputeWeights = computeOp.getWeights().size();
      for (auto [computeInputIdx, computeOpInput] :
          llvm::enumerate(computeOp.getInputs())) {
        TypedValue tensorSource;
        int64_t elementsOffset = 0;
        if (auto sliceOp = dyn_cast(computeOpInput.getDefiningOp())) {
          // Fold the slice into a flat element offset into its source and
          // feed the source tensor to the compute directly.
          tensorSource = cast>(sliceOp.getSource());
          if (isa(tensorSource.getDefiningOp()))
            continue;
          ArrayRef sourceShape = tensorSource.getType().getShape();
          ArrayRef sliceOffsets = sliceOp.getStaticOffsets();
          ArrayRef sliceSizes = sliceOp.getStaticSizes();
          ArrayRef sliceStrides = sliceOp.getStaticStrides();
          assert("Extracting slice non-contiguous in memory" &&
              isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes,
                  sliceStrides));
          // Row-major linearization of the slice's start coordinate.
          for (size_t i = 0; i < sliceOffsets.size(); i++) {
            int64_t partialOffset = sliceOffsets[i];
            if (partialOffset != 0)
              for (size_t j = i + 1; j < sourceShape.size(); j++)
                partialOffset *= sourceShape[j];
            elementsOffset += partialOffset;
          }
          computeOp.setOperand(numComputeWeights + computeInputIdx,
              tensorSource);
          sliceOpsToRemove.insert(sliceOp);
        } else
          tensorSource = cast>(computeOpInput);
        // Compute results must be transferred through channels via send/receive
        if (isa(tensorSource.getDefiningOp()))
          continue;
        BlockArgument computeBlockArgToReplace =
            computeOp.getBody().front().getArgument(computeInputIdx);
        insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource,
            elementsOffset);
      }
    }
  // Only erase slices that actually became dead (a slice may feed several
  // computes or other users).
  for (auto sliceOp : sliceOpsToRemove)
    if (sliceOp->getUses().empty())
      rewriter.eraseOp(sliceOp);
  return success();
}

/// Replaces block argument `argIndex - numWeights` of `computeOp` with a
/// channel (broadcast-)receive of `channelSourceOp`'s type. If the consumer
/// actually used a view-chained derivative of the channel value
/// (`consumerValue` != `channelSourceOp`), the chain is replayed inside the
/// consumer region on top of the received value.
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(
    spatial::SpatCompute& computeOp, unsigned int argIndex,
    Value channelSourceOp, Value consumerValue,
    spatial::SpatChannelNewOp& channel, bool useBroadcastOp,
    IRRewriter& rewriter) {
  auto& computeBlock = computeOp.getRegion().front();
  // WeightedCompute ops carry weights as their leading operands, but weights
  // are not mirrored as block arguments -- subtract the weight count when
  // converting an operand index to a block-argument index.
  auto computeWeightsCount = computeOp.getWeights().size();
  auto blockArg = computeBlock.getArgument(argIndex - computeWeightsCount);
  // Receive the tensor just before the first use of the value
  rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
  Value receivedValue;
  if (useBroadcastOp)
    receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter,
        computeOp.getLoc(), channelSourceOp.getType(), channel);
  else
    receivedValue = spatial::SpatChannelReceiveOp::create(rewriter,
        computeOp.getLoc(), channelSourceOp.getType(), channel);
  Value replacementValue = receivedValue;
  if (consumerValue != channelSourceOp) {
    // Walk backwards from the consumed value to the channel source,
    // collecting the view-like chain...
    SmallVector clonedChain;
    Value currentValue = consumerValue;
    while (currentValue != channelSourceOp) {
      Operation* definingOp = currentValue.getDefiningOp();
      if (!definingOp || !isChannelUseChainOp(definingOp))
        llvm_unreachable(
            "Unsupported channel use chain while replaying value into consumer compute");
      clonedChain.push_back(definingOp);
      currentValue = definingOp->getOperand(0);
    }
    // ...then replay it forwards, rooted at the received value.
    IRMapping mapping;
    mapping.map(channelSourceOp, receivedValue);
    for (Operation* op : llvm::reverse(clonedChain)) {
      cloneMappedHelperOperands(op, mapping, rewriter);
      Operation* clonedOp = rewriter.clone(*op, mapping);
      for (auto [originalResult, newResult] :
          llvm::zip(op->getResults(), clonedOp->getResults()))
        mapping.map(originalResult, newResult);
      markOpToRemove(op);
    }
    replacementValue = cast(mapping.lookup(consumerValue));
  }
  assert(replacementValue.getType() == blockArg.getType() &&
      "Replayed channel use chain must match block argument type");
  blockArg.replaceAllUsesWith(replacementValue);
}

/// Walks every use of `channelSourceOp` (through view-like chains) and, for
/// each compute consumer found, installs a receive inside that consumer via
/// replaceBlockArgumentWithRecvOp. Traversed chain ops are queued for
/// removal.
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
    spatial::SpatChannelNewOp& channel, bool useBroadcastOp,
    IRRewriter& rewriter) {
  auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
    for (OpOperand& use : currentValue.getUses()) {
      Operation* owner = use.getOwner();
      if (auto computeUser = dyn_cast(owner)) {
        replaceBlockArgumentWithRecvOp(
            computeUser, use.getOperandNumber(), channelSourceOp,
            currentValue, channel, useBroadcastOp, rewriter);
        continue;
      }
      if (!isChannelUseChainOp(owner))
        llvm_unreachable(
            "User of channel-carried value is not a compute nor a supported view-like op");
      markOpToRemove(owner);
      assert(owner->getNumResults() == 1 &&
          "Channel use chain op must have a single result");
      self(owner->getResult(0), self);
    }
  };
  replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
}

/// Queues `op` for deferred erasure; duplicates are ignored.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  if (!llvm::is_contained(operationsToRemove, op))
    operationsToRemove.push_back(op);
}

/// Annotates every live spat.channel_new with the core ids of its endpoints:
/// broadcast channels get only a source core id (targets are resolved per
/// receiver in lowerBroadcastChannelOps); point-to-point channels must have
/// exactly one send and one receive and get both attributes.
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
  funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
    markOpToRemove(channelNewOp);
    if (channelNewOp->use_empty())
      return;
    spatial::SpatChannelSendOp sendOp;
    spatial::SpatChannelReceiveOp receiveOp;
    spatial::SpatChannelBroadcastSendOp broadcastSendOp;
    for (Operation* user : channelNewOp->getUsers()) {
      if (auto op = dyn_cast(user)) {
        sendOp = op;
        continue;
      }
      if (auto op = dyn_cast(user)) {
        receiveOp = op;
        continue;
      }
      if (auto op = dyn_cast(user)) {
        broadcastSendOp = op;
        continue;
      }
      // NOTE(review): the fourth tolerated user kind was lost in extraction
      // (likely the broadcast receive).
      if (auto op = dyn_cast(user))
        continue;
      llvm_unreachable(
          "Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
    }
    if (broadcastSendOp) {
      // Core id comes from the enclosing pim.core of the sender.
      auto sourceCoreIdAttr =
          cast(broadcastSendOp->getParentOp()).getCoreIdAttr();
      channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
      return;
    }
    if (!sendOp || !receiveOp)
      llvm_unreachable(
          "spat.channel_new must connect exactly one send and one receive");
    auto sourceCoreIdAttr = cast(sendOp->getParentOp()).getCoreIdAttr();
    auto targetCoreIdAttr = cast(receiveOp->getParentOp()).getCoreIdAttr();
    channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
    channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
  });
}

/// Lowers broadcast channel ops: each broadcast send becomes one pim.send per
/// matching broadcast receive (addressed by the receiver's core id), and each
/// broadcast receive becomes a pim.receive into a freshly created buffer,
/// addressed by the channel's source core id.
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp,
    IRRewriter& rewriter) {
  // Collect first: we erase ops, so don't rewrite during the walk.
  SmallVector broadcastSendOps;
  funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) {
    broadcastSendOps.push_back(op);
  });
  for (auto sendOp : broadcastSendOps) {
    auto channelNewOp = cast(sendOp.getChannel().getDefiningOp());
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
    rewriter.setInsertionPoint(sendOp);
    bool foundReceiver = false;
    for (Operation* user : channelNewOp->getUsers()) {
      auto receiveOp = dyn_cast(user);
      if (!receiveOp)
        continue;
      foundReceiver = true;
      auto targetCoreIdAttr =
          cast(receiveOp->getParentOp()).getCoreIdAttr();
      // One point-to-point pim.send per receiver core.
      PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(),
          sizeAttr, targetCoreIdAttr);
    }
    if (!foundReceiver)
      llvm_unreachable(
          "spat.channel_broadcast_send has no matching broadcast receive");
    rewriter.eraseOp(sendOp);
  }
  SmallVector broadcastReceiveOps;
  funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) {
    broadcastReceiveOps.push_back(op);
  });
  for (auto receiveOp : broadcastReceiveOps) {
    rewriter.setInsertionPoint(receiveOp);
    auto outputType = cast(receiveOp.getResult().getType());
    Value outputBuffer = createEmptyTensorFromShaped(rewriter,
        receiveOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
    auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter,
        receiveOp.getChannel());
    Value receivedValue = PimReceiveOp::create(
        rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer,
        sizeAttr, sourceCoreIdAttr)
        .getOutput();
    rewriter.replaceOp(receiveOp, receivedValue);
  }
}

/// Rewires func.return to return the host output tensors, then walks each old
/// operand's producer chain backwards and queues for removal every op that is
/// used exclusively by the return chain (view-like chain ops and terminal
/// pim.core-adjacent producers -- NOTE(review): exact isa<...> filters lost
/// in extraction).
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp,
    IRRewriter& rewriter) {
  // Snapshot the operands: setOperand below mutates the return op.
  SmallVector originalOperands(returnOp.getOperands().begin(),
      returnOp.getOperands().end());
  for (auto it : llvm::enumerate(originalOperands)) {
    size_t orderWithinReturn = it.index();
    Operation* returnOperand = it.value().getDefiningOp();
    rewriter.modifyOpInPlace(returnOp, [&] {
      returnOp.setOperand(orderWithinReturn,
          outputTensors[orderWithinReturn]);
    });
    Operation* opToErase = returnOperand;
    while (opToErase) {
      // An op is removable if it is dead, or its single remaining user is
      // either the return op itself or another chain op on this path.
      bool isExclusivelyOwnedByReturnChain = opToErase->use_empty();
      if (!isExclusivelyOwnedByReturnChain && opToErase->hasOneUse()) {
        Operation* onlyUser = *opToErase->getUsers().begin();
        isExclusivelyOwnedByReturnChain =
            isa(onlyUser) || isChannelUseChainOp(onlyUser);
      }
      if (!isExclusivelyOwnedByReturnChain)
        break;
      if (isChannelUseChainOp(opToErase)) {
        Value source = opToErase->getOperand(0);
        markOpToRemove(opToErase);
        opToErase = source.getDefiningOp();
        continue;
      }
      if (isa(opToErase))
        markOpToRemove(opToErase);
      break;
    }
  }
}

/// Lowers one pre-existing spat.channel_receive: re-installs receives in all
/// compute consumers of its result and, if the result fans out to more than
/// one compute, upgrades the matching send to a broadcast send.
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp,
    IRRewriter& rewriter) {
  auto channel = cast(receiveOp.getChannel().getDefiningOp());
  auto sendOpOpt = getOtherEndOfChannel(receiveOp, true, rewriter);
  if (failed(sendOpOpt))
    llvm_unreachable("ChannelReceiveOp has no matching SendOp");
  auto sendOp = cast(*sendOpOpt);
  Value receiveRes = receiveOp.getResult();
  bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
  addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
  if (useBroadcastOp) {
    // Only now do we know the value has more than one compute consumer, so
    // the original point-to-point SendOp must be replaced with a
    // BroadcastSendOp.
    rewriter.setInsertionPoint(sendOp);
    rewriter.replaceOpWithNewOp(sendOp, sendOp.getChannel(),
        sendOp.getInput());
  }
}

/// Factory for pass registration.
std::unique_ptr createSpatialToPimPass() {
  return std::make_unique();
}

} // namespace onnx_mlir