// NOTE(review): this file was reconstructed from a copy in which every `<...>`
// span had been stripped and all line breaks collapsed. Template arguments and
// the system headers below are best-effort reconstructions; each spot that
// could not be derived from surrounding code is flagged with NOTE(review) and
// must be confirmed against the build.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_os_ostream.h"

// NOTE(review): the original listed four system headers whose names were lost;
// these are the ones this file visibly needs (assert, size_t, fixed-width
// integers, unique_ptr, std::move).
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"

using namespace mlir;
using namespace onnx_mlir;
using namespace pim;

namespace onnx_mlir {

namespace {

#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"

/// Module pass registered as "convert-spatial-to-pim".
///
/// It first applies the generated rewrite patterns, then post-processes the
/// PIM entry function: output buffers are created for the returned values,
/// compute-op inputs are materialised with host-to-device copies, values that
/// flow between compute ops are routed through channels, and every
/// `spatial::SpatWeightedCompute` is replaced by a `PimCoreOp`.
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)

  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;
  // Deliberately does not copy per-run state (output tensors, core counter,
  // pending removals); it is rebuilt by every call to runOnOperation().
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  void runOnOperation() final;

private:
  /// One entry per operand of the entry function's return op: the value that
  /// will be returned in its place (filled by addResultBuffer()).
  SmallVector<Value> outputTensors;
  /// Id handed to the next PimCoreOp that gets created.
  size_t coreId = 0;
  /// Ops scheduled for erasure at the end of runOnOperation(); erased in
  /// reverse insertion order.
  SmallVector<Operation*> operationsToRemove;

  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);

  void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);

  void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
  void addReceiveOps(Value& channelSourceOp,
      spatial::SpatChannelNewOp& channel,
      Type& channelTensorType,
      bool& useBroadcastOp,
      IRRewriter& rewriter);
  void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
      unsigned int argIndex,
      spatial::SpatChannelNewOp& channel,
      Type& tensorType,
      bool useBroadcastOp,
      IRRewriter& rewriter);

  void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);

  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);

  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};

} // namespace

/// Pass entry point: runs the generated patterns, then rewrites the PIM entry
/// function in place. Signals pass failure if the pattern application fails,
/// if no entry function is found, or if an op scheduled for removal still has
/// users.
void SpatialToPimPass::runOnOperation() {
  // Reset all per-run state so that a second run never sees stale values or
  // dangling operation pointers from a previous module.
  coreId = 1;
  outputTensors.clear();
  operationsToRemove.clear();

  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();

  ConversionTarget target(*ctx);
  // NOTE(review): the dialect list was stripped from the source; this one is
  // inferred from the dialect headers included above — confirm.
  target.addLegalDialect<PimDialect, tensor::TensorDialect, tosa::TosaDialect, func::FuncDialect, BuiltinDialect>();

  RewritePatternSet patterns(ctx);
  populateWithGenerated(patterns);

  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;

  IRRewriter rewriter(&getContext());
  auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());

  addResultBuffer(returnOp, rewriter);

  allocateAndInitializeCoreLocalVariables(funcOp, rewriter);

  // Each visited op is queued for removal *before* it is processed, so that
  // ops queued while processing it (e.g. reshapes) are erased first.
  for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
    operationsToRemove.push_back(receiveOp);
    runOnReceiveOp(receiveOp, rewriter);
  }
  for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
    operationsToRemove.push_back(computeOp);
    runOnComputeOp(computeOp, rewriter);
  }

  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  replaceReturnOpOperands(returnOp, rewriter);

  // Remove all ComputeOps (and everything else that was queued).
  for (Operation* opToRemove : llvm::reverse(operationsToRemove)) {
    if (!opToRemove->use_empty()) {
      // Erasing an op that still has users would leave dangling operands. The
      // original guarded this with assert(false), which is compiled out in
      // release builds; report it as a regular pass failure instead.
      InFlightDiagnostic diag = opToRemove->emitError("opToRemove should be unused at this point");
      for (Operation* user : opToRemove->getUsers())
        diag.attachNote(user->getLoc()) << "still used by: " << *user;
      signalPassFailure();
      return;
    }
    rewriter.eraseOp(opToRemove);
  }

  // Dump to file for debug
  dumpModule(moduleOp, "pim");
}

/// Rewrites one compute op into a `PimCoreOp`.
///
/// For every used result, the yielded value is either copied straight into the
/// matching output tensor (when it only feeds the function return, directly or
/// through a single concat) or sent through a freshly created channel to its
/// consumers. The yield terminator is replaced by a `PimHaltOp` and the body
/// is moved into a new `PimCoreOp` created right after `computeOp`.
/// `computeOp` itself is left in place with a placeholder body; the caller is
/// responsible for erasing it.
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();

  auto& block = computeOp.getRegion().front();
  // NOTE(review): terminator op class name reconstructed — confirm.
  auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());

  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");

  for (auto resultAndYield : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    Value result = std::get<0>(resultAndYield);
    Value yieldValue = std::get<1>(resultAndYield);

    // If this result has no uses, then just skip it
    if (result.use_empty())
      continue;

    auto yieldType = cast<ShapedType>(yieldValue.getType());

    // Emits the copy of `yieldValue` into the output tensor associated with
    // return operand `returnOperandIndex`, starting `byteOffset` bytes into
    // that tensor. The copy is inserted right after `yieldValue`.
    auto storeToOutputTensor = [&](size_t returnOperandIndex, size_t byteOffset) {
      size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
      size_t byteSize = yieldType.getNumElements() * elementSize;

      // Store to global memory
      Value outputTensor = outputTensors[returnOperandIndex];
      rewriter.setInsertionPointAfterValue(yieldValue);
      PimMemCopyDevToHostOp::create(rewriter,
          loc,
          outputTensor.getType(),
          outputTensor,
          yieldValue,
          rewriter.getI32IntegerAttr(static_cast<int32_t>(byteOffset)),
          rewriter.getI32IntegerAttr(0),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(byteSize)));
    };

    /*
     * Here we assume that ReturnOp are only reachable by the following patterns:
     *
     * 1)
     * %0 = spat.compute([...])
     * [%0 has one user, which is a ConcatOp]
     * %1 = tensor.concat(%0)
     * [%1 has one user, which is a ReturnOp]
     * return %1
     *
     * 2)
     * %0 = spat.compute([...])
     * [%0 has one user, which is a ReturnOp]
     * return %0
     *
     * If the IR is like 2), then we can store the tensor to the output global memory location.
     * If it is like 1), the tensor is stored at its offset inside the concatenated output.
     */
    auto resultUses = result.getUses();
    auto numResultUses = rangeLength(resultUses);
    if (numResultUses == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();

      if (isa<func::ReturnOp>(resultUser)) {
        storeToOutputTensor(resultUse.getOperandNumber(), /*byteOffset=*/0);
        continue;
      }

      // NOTE(review): the original tested two concat op classes whose names
      // were lost; `isAConcatOp` (already used in replaceReturnOpOperands) is
      // assumed to cover the same ops — confirm.
      if (isAConcatOp(resultUser)) {
        Operation* concatOp = resultUser;
        Value concatValue = concatOp->getResult(0);
        auto concatUses = concatValue.getUses();
        if (rangeLength(concatUses) == 1) {
          OpOperand& concatUse = *concatUses.begin();
          if (isa<func::ReturnOp>(concatUse.getOwner())) {
            size_t concatIndexInReturn = concatUse.getOperandNumber();
            size_t resultIndexInConcat = resultUse.getOperandNumber();

            // Byte offset of this operand = total byte size of the concat
            // operands that precede it.
            size_t offset = 0;
            for (Value operand : concatOp->getOperands().take_front(resultIndexInConcat)) {
              auto operandType = cast<ShapedType>(operand.getType());
              offset += operandType.getNumElements() * operandType.getElementTypeBitWidth() / 8;
            }

            storeToOutputTensor(concatIndexInReturn, offset);
            continue;
          }
        }
      }
    }

    // If this pattern was not found, then create a channel and send the value

    // 1. Create a new ChannelOp
    rewriter.setInsertionPoint(computeOp);
    auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
    auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);

    // 2. Receive value through the channel
    // If this result is used by more than one user, then use a "Broadcast"
    // channel operation. However, there is a special case: we have a single
    // user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
    // case, we need to use a "Broadcast" channel operation. `addReceiveOps`
    // will detect this case and update `useBroadcastOp` accordingly.
    bool useBroadcastOp = (numResultUses > 1);
    // `addReceiveOps` may overwrite the type it is given, so hand it a plain
    // `Type` copy rather than a reference into `yieldType`.
    Type channelTensorType = yieldType;
    addReceiveOps(result, channelOp, channelTensorType, useBroadcastOp, rewriter);

    // 3. Send the value through the channel
    rewriter.setInsertionPointAfterValue(yieldValue);
    if (useBroadcastOp)
      spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
    else
      spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
  }

  // Use `HaltOp` instead of `YieldOp`
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);

  // Replace `spat.compute` with `pim.core`
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp =
      PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(static_cast<int32_t>(coreId++)));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());

  // The compute op is erased later by the caller; until then give it a
  // placeholder body holding only a terminator.
  Block& tempComputeBlock = computeOp.getBody().emplaceBlock();
  rewriter.setInsertionPointToEnd(&tempComputeBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}

/// For every `PimVMMOp` whose output buffer's second dimension differs from
/// `crossbarSize`, retypes the output buffer, the chain of tied
/// destination-style operands that produces it, and the op result to
/// `[rows, crossbarSize]`, then inserts a `tensor.extract_slice` that recovers
/// the original shape for all other users of the result.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Walks backwards through tied destination-style operands starting at
  // `value`, assigning `newType` to each tied operand along the way.
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutBuf();
    auto resultTensor = vmmOp.getOutRes();
    auto outShape = getTensorShape(outTensorOperand);
    assert(isHVectorShape(outShape));

    const int64_t targetColumns = static_cast<int64_t>(crossbarSize);
    if (outShape[1] == targetColumns)
      return;

    SmallVector<int64_t> newShape {outShape[0], targetColumns};
    auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
    enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
    outTensorOperand.setType(newType);
    resultTensor.setType(newType);

    // Slice [0:rows, 0:oldColumns] with unit strides out of the enlarged result.
    IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
    IntegerAttr oneAttr = rewriter.getIndexAttr(1);
    IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
    IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
    SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
    SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
    SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
    rewriter.setInsertionPointAfter(vmmOp);
    auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);

    // Everyone but the VMM op and the new slice now consumes the slice.
    SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
    resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
  });
}

/// Fills `outputTensors` with one value per operand of `returnOp`: the operand
/// itself when its defining op carries the checked trait, otherwise a new
/// empty tensor of the operand's shaped type created at the start of the
/// return op's block.
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  rewriter.setInsertionPointToStart(returnOp->getBlock());
  for (Value returnValue : returnOp->getOperands()) {
    // A returned block argument has no defining op; the original dereferenced
    // the null pointer here. Such values now take the "new buffer" path.
    Operation* returnValueDefiningOp = returnValue.getDefiningOp();
    // NOTE(review): the trait name was stripped from the source;
    // `OpTrait::ConstantLike` is assumed — confirm.
    if (returnValueDefiningOp && returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
      assert(!hasWeightAlways(returnValueDefiningOp));
      outputTensors.push_back(returnValue);
    }
    else {
      auto newOutputTensor =
          createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
      outputTensors.push_back(newOutputTensor);
    }
  }
}

/// Replaces the tensor arguments of `funcOp` with memref arguments (wrapped in
/// `bufferization.to_tensor`), then, for every compute-op input that is not
/// produced by another compute op, replaces the matching block argument inside
/// the compute body with the result of a `PimMemCopyHostToDevOp` into a new
/// `tensor.empty`. Inputs that are `tensor.extract_slice` ops are folded into
/// the copy as an element offset into the slice source.
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  // Creates the empty tensor and the host-to-device copy right before the
  // earliest user of `valueToReplace` in its block, then redirects all of its
  // uses to the copy result. `elementsOffset` is the offset into `hostTensor`
  // in elements; it is converted to bytes here.
  auto insertMemCopyHostToDev = [&](BlockArgument valueToReplace, Value hostTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(valueToReplace.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));

    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);

    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(rewriter,
        loc,
        tensorType,
        deviceTensor,
        hostTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));

    rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
  };

  // Replace input tensors with memRefs. The memref argument is inserted right
  // after the tensor argument, which is then erased, so index `i` ends up
  // holding the new argument.
  for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
    BlockArgument tensorArg = funcOp.getArgument(i);
    DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
    ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
    MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());

    funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
    BlockArgument memRefArg = funcOp.getArgument(i + 1);

    Block& block = funcOp.getBody().front();
    rewriter.setInsertionPoint(&block.front());
    auto toTensorOp =
        bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());

    tensorArg.replaceAllUsesWith(toTensorOp);
    funcOp.eraseArgument(i);
  }

  // NOTE(review): element type / inline size of this set were stripped from
  // the source; `Operation*` with 8 inline slots is assumed.
  llvm::SmallSet<Operation*, 8> sliceOpsToRemove;
  for (auto& op : funcOp.getBody().getOps()) {
    auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op);
    if (!computeOp)
      continue;

    // Inputs follow the weights in the operand list.
    unsigned numComputeWeights = computeOp.getWeights().size();
    for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
      TypedValue<RankedTensorType> tensorSource;
      int64_t elementsOffset = 0;

      // getDefiningOp<OpTy>() returns null both for block arguments and for
      // other op kinds, so no separate null check is needed.
      if (auto sliceOp = computeOpInput.template getDefiningOp<tensor::ExtractSliceOp>()) {
        tensorSource = cast<TypedValue<RankedTensorType>>(sliceOp.getSource());

        ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
        ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
        ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
        ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
        assert("Extracting slice non-contiguous in memory"
               && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));

        // Linearise the per-dimension offsets: each offset is scaled by the
        // product of the source extents of all inner dimensions.
        for (size_t dim = 0; dim < sliceOffsets.size(); dim++) {
          int64_t partialOffset = sliceOffsets[dim];
          if (partialOffset != 0)
            for (size_t inner = dim + 1; inner < sourceShape.size(); inner++)
              partialOffset *= sourceShape[inner];
          elementsOffset += partialOffset;
        }

        // The compute op now consumes the slice source directly.
        computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
        sliceOpsToRemove.insert(sliceOp);
      }
      else
        tensorSource = cast<TypedValue<RankedTensorType>>(computeOpInput);

      // Compute results must be transferred through channels via send/receive
      if (tensorSource.template getDefiningOp<spatial::SpatWeightedCompute>())
        continue;

      BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
      insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
    }
  }

  for (Operation* sliceOp : sliceOpsToRemove)
    if (sliceOp->use_empty())
      rewriter.eraseOp(sliceOp);
}

/// Replaces all uses of the block argument of `computeOp` that corresponds to
/// operand `argIndex` with the result of a (broadcast) receive on `channel`,
/// created just before the argument's earliest user in the block.
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
    unsigned int argIndex,
    spatial::SpatChannelNewOp& channel,
    Type& tensorType,
    bool useBroadcastOp,
    IRRewriter& rewriter) {
  auto& computeBlock = computeOp.getRegion().front();
  //(remember that WeightedCompute have weights as first operands, however these
  // weights are not included in the block arguments. Thus, when indexing the
  // block argument we need to remove the weights count)
  auto computeWeightsCount = computeOp.getWeights().size();
  auto blockArg = computeBlock.getArgument(argIndex - computeWeightsCount);

  // Receive the tensor just before the first use of the value
  rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
  Value receivedValue;
  if (useBroadcastOp)
    receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
  else
    receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);

  blockArg.replaceAllUsesWith(receivedValue);
}

/// Inserts a receive on `channel` in every compute op that consumes
/// `channelSourceOp`, either directly or through a reshape.
///
/// `useBroadcastOp` is an in/out flag: it is switched to true when the single
/// user is a reshape with more than one use. `channelTensorType` is also
/// in/out: it is overwritten with the reshape result type when a reshape user
/// is encountered. Reshape users are queued in `operationsToRemove`.
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
    spatial::SpatChannelNewOp& channel,
    Type& channelTensorType,
    bool& useBroadcastOp,
    IRRewriter& rewriter) {
  auto sourceOpUses = channelSourceOp.getUses();

  // Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
  if (!useBroadcastOp) {
    // if useBroadcastOp is false, then sourceOp must have only one user
    assert(rangeLength(sourceOpUses) == 1);

    if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
      auto reshapeOpUses = reshapeOp.getOutput().getUses();
      if (rangeLength(reshapeOpUses) > 1)
        useBroadcastOp = true;
    }
  }

  for (auto& resultUse : sourceOpUses) {
    // The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
    auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());

    if (computeUser) {
      replaceBlockArgumentWithRecvOp(
          computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
      continue;
    }

    auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
    if (!reshapeOp) {
      channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
      resultUse.getOwner()->dump();
      llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
    }

    // The tensorType now becomes the one of the reshapeOp
    channelTensorType = reshapeOp.getResult().getType();

    for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
      computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());

      if (!computeUser)
        llvm_unreachable("ReshapeOp users must be ComputeOps");

      replaceBlockArgumentWithRecvOp(
          computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
    }

    // Remove the reshapeOp, so that the sourceOp has no users
    operationsToRemove.push_back(reshapeOp);
  }
}

/// Points every operand of `returnOp` at the matching entry of
/// `outputTensors`, and erases concat ops that are left without uses by that
/// replacement.
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  for (auto it : llvm::enumerate(returnOp.getOperands())) {
    Value previousOperand = it.value();
    // Null when the returned value is a block argument.
    Operation* returnOperand = previousOperand.getDefiningOp();

    size_t orderWithinReturn = it.index();

    rewriter.modifyOpInPlace(
        returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });

    // If the operand is a concatenation operation and the returnOp was the only
    // user of the returnOperand, we can safely remove it
    if (returnOperand && isAConcatOp(returnOperand) && previousOperand.use_empty())
      rewriter.eraseOp(returnOperand);
  }
}

/// Re-routes a receive op found in the entry function: its consumers get
/// their own receive ops on the same channel (see addReceiveOps), and the
/// matching send is turned into a broadcast send when more than one receiver
/// ends up being needed. The receive op itself is erased by the caller.
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {

  auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());

  auto sendOpOpt = getOtherEndOfChannel(receiveOp, true, rewriter);
  if (failed(sendOpOpt))
    llvm_unreachable("ChannelReceiveOp has no matching SendOp");

  auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);

  Type tensorType = receiveOp.getType();

  Value receiveRes = receiveOp.getResult();

  // Check if the receiveOp value has more than one user
  auto receiveUses = receiveRes.getUses();
  auto receiveUsesCount = rangeLength(receiveUses);
  assert(receiveUsesCount > 0);
  bool useBroadcastOp = receiveUsesCount > 1;
  addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);

  if (useBroadcastOp) {
    // When receiving, we noticed that the value has more than one user. This
    // means that we need to replace the original SendOp with a
    // BroadcastSendOp.
    rewriter.setInsertionPoint(sendOp);
    rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
  }
}

/// Factory for the "convert-spatial-to-pim" pass.
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }

} // namespace onnx_mlir