// NOTE(review): in this copy of the file, the two system #include targets
// below and many template argument lists are empty, so the file does not
// compile as shown. Examples are `PassWrapper>`, `SmallVector outputTensors`,
// `cast(...)`, `dyn_cast(...)` and `funcOp.getOps()`. They are reproduced
// verbatim here because their original contents cannot be recovered from this
// text; restore them from version control.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"

#include
#include

#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"

using namespace mlir;
using namespace onnx_mlir;
using namespace pim;

namespace onnx_mlir {

namespace {

#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"

/// Module pass ("convert-spatial-to-pim") that lowers the Spatial-dialect IR of
/// the PIM entry function to PIM ops.
///
/// The members below are per-run scratch state. runOnOperation() resets them
/// at the start of every run, because MLIR may invoke the same pass instance
/// more than once.
struct SpatialToPimPass : PassWrapper> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;
  // Chain to the base copy constructor (the conventional MLIR form). The
  // per-run members are intentionally not copied; they start empty.
  SpatialToPimPass(const SpatialToPimPass& pass) : PassWrapper(pass) {}

  void runOnOperation() final;

private:
  // Filled by addReturnOutputBuffers(). It is then shared with core lowering
  // (CoreLoweringState) and the return-path rewrite (ReturnPathState).
  SmallVector outputTensors;
  // Core id handed to CoreLoweringState; reset to 1 at the start of each run.
  size_t coreId = 0;
  // Ops scheduled for deferred erasure via erasePendingOps(); kept duplicate-free
  // by markOpToRemove().
  SmallVector operationsToRemove;

  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
  void markOpToRemove(Operation* op);
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
};

} // namespace

/// Maps a Spatial core id to a PIM core id.
///
/// Currently this is the identity mapping, narrowed to int32_t. No range
/// check is performed, so ids above INT32_MAX would wrap.
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); }

/// Replaces a spatial channel send with a pim send.
///
/// The pim send carries the input's size in bytes and the translated target
/// core id. The original op is erased.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
  auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId()));

  rewriter.setInsertionPoint(sendOp);
  PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
  rewriter.eraseOp(sendOp);
}

/// Replaces a spatial channel receive with a pim receive.
///
/// The pim receive writes into a freshly created empty tensor of the same
/// type. A receive with no uses is simply erased.
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  if (receiveOp->use_empty()) {
    rewriter.eraseOp(receiveOp);
    return;
  }

  auto outputType = cast(receiveOp.getResult().getType());
  rewriter.setInsertionPoint(receiveOp);
  auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
  auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
  auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
  Value received = PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
                       .getOutput();
  rewriter.replaceOp(receiveOp, received);
}

/// Replaces a multi-target spatial tensor send with a pim send_tensor.
///
/// The target core ids are translated one by one and passed as a dense i32
/// array attribute.
static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp, IRRewriter& rewriter) {
  SmallVector targetCoreIds;
  targetCoreIds.reserve(sendTensorOp.getTargetCoreIds().size());
  for (int32_t targetCoreId : sendTensorOp.getTargetCoreIds())
    targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));

  rewriter.setInsertionPoint(sendTensorOp);
  PimSendTensorOp::create(
      rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
  rewriter.eraseOp(sendTensorOp);
}

/// Replaces a multi-source spatial tensor receive with a pim receive_tensor.
///
/// The pim op writes into a fresh empty tensor of the result type. The
/// translated source core ids are passed as a dense i32 array attribute.
static void lowerChannelReceiveTensor(spatial::SpatChannelReceiveTensorOp receiveTensorOp, IRRewriter& rewriter) {
  SmallVector sourceCoreIds;
  sourceCoreIds.reserve(receiveTensorOp.getSourceCoreIds().size());
  for (int32_t sourceCoreId : receiveTensorOp.getSourceCoreIds())
    sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));

  rewriter.setInsertionPoint(receiveTensorOp);
  auto outputType = cast(receiveTensorOp.getOutput().getType());
  Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType).getResult();
  Value received = PimReceiveTensorOp::create(rewriter,
                                              receiveTensorOp.getLoc(),
                                              receiveTensorOp.getOutput().getType(),
                                              outputBuffer,
                                              rewriter.getDenseI32ArrayAttr(sourceCoreIds))
                       .getOutput();
  rewriter.replaceOp(receiveTensorOp, received);
}

/// Replaces spat.extract_rows with one tensor.extract_slice per result.
///
/// Result i is the slice at offsets [i * R, 0] with sizes [R, inputDim1] and
/// unit strides, where R is that result's dim 0. The op is thus treated as
/// rank 2.
/// NOTE(review): the row offsets only tile the input contiguously if every
/// result has the same dim 0. This is presumably guaranteed by the op verifier;
/// confirm.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(extractRowsOp);
  auto inputType = cast(extractRowsOp.getInput().getType());
  SmallVector replacements;
  replacements.reserve(extractRowsOp.getNumResults());
  for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
    auto outputType = cast(output.getType());
    SmallVector offsets = {
        rewriter.getIndexAttr(static_cast(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
    SmallVector sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
                         rewriter.getIndexAttr(inputType.getDimSize(1))};
    SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    replacements.push_back(
        tensor::ExtractSliceOp::create(
            rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides)
            .getResult());
  }
  rewriter.replaceOp(extractRowsOp, replacements);
}

/// Packs runs of adjacent inputs of every axis-0 spat.concat in `funcOp`.
///
/// For each concat, two kinds of run are tried, left to right:
///   1. Consecutive inputs that have a defining op are handed to
///      createPackedExtractSliceTensor(). (The getDefiningOp template argument
///      is missing in this copy.)
///   2. Consecutive results of the same op with consecutive result numbers,
///      when the owner matches the `extractRowsOp` dyn_cast, are handed to
///      createPackedExtractRowsSlice().
/// Runs that cannot be packed are copied through unchanged. If anything was
/// packed, the concat is replaced by a pim concat into a fresh empty tensor.
/// Finally, tensor.concat, tensor.extract_slice and spat.extract_rows ops with
/// no remaining uses are erased.
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
  SmallVector concatOps;
  funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });

  for (auto concatOp : concatOps) {
    if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
      continue;

    SmallVector packedInputs;
    bool changed = false;
    rewriter.setInsertionPoint(concatOp);

    // `index` is advanced manually: each iteration consumes a whole run.
    for (unsigned index = 0; index < concatOp.getInputs().size();) {
      Value input = concatOp.getInputs()[index];
      if (input.getDefiningOp()) {
        unsigned endIndex = index + 1;
        while (endIndex < concatOp.getInputs().size() && concatOp.getInputs()[endIndex].getDefiningOp())
          ++endIndex;

        Value packedInput = createPackedExtractSliceTensor(
            concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc());
        if (packedInput) {
          packedInputs.push_back(packedInput);
          changed = true;
          index = endIndex;
          continue;
        }
      }

      auto result = dyn_cast(input);
      if (!result) {
        packedInputs.push_back(input);
        ++index;
        continue;
      }

      // Extend the run while inputs are consecutive results of the same owner.
      Operation* owner = result.getOwner();
      unsigned startIndex = result.getResultNumber();
      unsigned endIndex = index + 1;
      while (endIndex < concatOp.getInputs().size()) {
        auto nextResult = dyn_cast(concatOp.getInputs()[endIndex]);
        if (!nextResult || nextResult.getOwner() != owner
            || nextResult.getResultNumber() != startIndex + (endIndex - index))
          break;
        ++endIndex;
      }

      unsigned count = endIndex - index;
      Value packedInput;
      if (auto extractRowsOp = dyn_cast(owner))
        packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());

      if (packedInput) {
        packedInputs.push_back(packedInput);
        changed = true;
      }
      else {
        for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
          packedInputs.push_back(concatOp.getInputs()[oldIndex]);
      }
      index = endIndex;
    }

    if (!changed)
      continue;

    auto newConcat = pim::PimConcatOp::create(
        rewriter,
        concatOp.getLoc(),
        concatOp.getOutput().getType(),
        concatOp.getAxisAttr(),
        ValueRange(packedInputs),
        createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast(concatOp.getOutput().getType()))
            .getResult());
    rewriter.replaceOp(concatOp, newConcat.getOutput());
  }

  // Erase now-dead ops of the given type. The walk order is reversed so that
  // later users are visited before earlier producers.
  auto eraseUnusedOps = [&](auto tag) {
    using OpTy = decltype(tag);
    SmallVector ops;
    funcOp.walk([&](OpTy op) { ops.push_back(op); });
    for (auto op : llvm::reverse(ops))
      if (op->use_empty())
        rewriter.eraseOp(op);
  };
  eraseUnusedOps(tensor::ConcatOp {});
  eraseUnusedOps(tensor::ExtractSliceOp {});
  eraseUnusedOps(spatial::SpatExtractRowsOp {});
}

/// Drives the whole Spatial -> PIM lowering of the PIM entry function.
///
/// Every failing step calls signalPassFailure() and returns immediately.
void SpatialToPimPass::runOnOperation() {
  // Reset all per-run state. MLIR may run the same pass instance more than
  // once. Stale entries in operationsToRemove would be dangling Operation*s
  // from a previous module, and outputTensors would keep accumulating.
  coreId = 1;
  outputTensors.clear();
  operationsToRemove.clear();

  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;

  IRRewriter rewriter(&getContext());

  ConversionTarget target(*ctx);
  target.addLegalDialect();
  target.addLegalOp();

  // Step 1: partial conversion of the module with the TableGen-generated patterns.
  {
    RewritePatternSet patterns(ctx);
    populateWithGenerated(patterns);
    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }

  // Step 2: materialize global tensors.
  {
    RewritePatternSet patterns(ctx);
    populateGlobalTensorMaterializationPatterns(patterns);
    walkAndApplyPatterns(moduleOp, std::move(patterns));
  }

  auto returnOp = cast(funcOp.front().getTerminator());

  addReturnOutputBuffers(returnOp, rewriter, outputTensors);
  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    signalPassFailure();
    return;
  }

  // Step 3: lower compute ops. Each original op is marked for deferred erasure
  // instead of being erased while its lowering may still reference it.
  CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
  for (auto computeOp : funcOp.getOps()) {
    markOpToRemove(computeOp);
    if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
      signalPassFailure();
      return;
    }
  }

  for (auto computeBatchOp : funcOp.getOps()) {
    markOpToRemove(computeBatchOp);
    if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
      signalPassFailure();
      return;
    }
  }

  compactSpatialTensorGroups(funcOp, rewriter);

  // Step 4: lower the top-level channel and extract_rows ops. Each group is
  // collected first so the lowering can erase/replace ops safely.
  SmallVector receiveOps;
  for (auto op : funcOp.getOps())
    receiveOps.push_back(op);
  for (auto receiveOp : receiveOps) {
    // A receive consumed only by ops that are going away dies with them.
    bool onlyPendingRemovalUsers = llvm::all_of(
        receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
    if (onlyPendingRemovalUsers) {
      markOpToRemove(receiveOp);
      continue;
    }
    // NOTE(review): unreachable as written. llvm::all_of over an empty user
    // range is true, so use-less receives were already marked above.
    if (receiveOp->use_empty()) {
      rewriter.eraseOp(receiveOp);
      continue;
    }
    lowerChannelReceive(receiveOp, rewriter);
  }

  SmallVector receiveTensorOps;
  for (auto op : funcOp.getOps())
    receiveTensorOps.push_back(op);
  for (auto receiveTensorOp : receiveTensorOps)
    lowerChannelReceiveTensor(receiveTensorOp, rewriter);

  SmallVector sendOps;
  for (auto op : funcOp.getOps())
    sendOps.push_back(op);
  for (auto sendOp : sendOps)
    lowerChannelSend(sendOp, rewriter);

  SmallVector sendTensorOps;
  for (auto op : funcOp.getOps())
    sendTensorOps.push_back(op);
  for (auto sendTensorOp : sendTensorOps)
    lowerChannelSendTensor(sendTensorOp, rewriter);

  SmallVector extractRowsOps;
  for (auto op : funcOp.getOps())
    extractRowsOps.push_back(op);
  for (auto extractRowsOp : extractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);

  // Step 5: fully convert the body of every pim core / core batch. The
  // generated patterns are frozen once and reused for every core.
  {
    RewritePatternSet coreBodyPatterns(ctx);
    populateWithGenerated(coreBodyPatterns);
    FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));

    SmallVector coreOps;
    funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
    for (auto coreOp : coreOps) {
      if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }

    SmallVector coreBatchOps;
    funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
    for (auto coreBatchOp : coreBatchOps) {
      if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }
  }

  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);

  // Step 6: rewrite the return path, then erase everything scheduled for removal.
  ReturnPathState returnPathState {outputTensors, operationsToRemove};
  replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);

  SmallVector pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
  if (failed(erasePendingOps(pendingRemovals, rewriter))) {
    signalPassFailure();
    return;
  }

  // Second compaction, presumably to catch groups exposed by the erasures above.
  compactSpatialTensorGroups(funcOp, rewriter);

  // Step 7: lower the remaining channel ops with a dedicated full conversion.
  {
    ConversionTarget communicationTarget(*ctx);
    communicationTarget.addLegalDialect();
    communicationTarget.addLegalOp();
    communicationTarget.addIllegalOp();

    RewritePatternSet communicationPatterns(ctx);
    populateChannelLoweringPatterns(communicationPatterns);
    if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
      signalPassFailure();
      return;
    }
  }

  if (failed(verifySpatialToPimBoundary(moduleOp))) {
    signalPassFailure();
    return;
  }

  // Dump to file for debug
  dumpModule(moduleOp, "pim0");
}

/// Widens every pim VMM output whose second dimension differs from crossbarSize.
///
/// The output buffer and result are retyped to [dim0, crossbarSize]. A
/// tensor.extract_slice of the original shape is inserted after the VMM, and
/// all other uses of the result are redirected to it, so downstream users still
/// see the old shape.
///
/// If the output buffer is also the VMM's input, a fresh tensor.empty of the
/// new shape is used as the buffer. Otherwise the buffer and its chain of tied
/// DPS operands are retyped in place.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Recursively retypes the DPS-tied init operand feeding `value`, and the
  // operand feeding that one, and so on, until the chain ends.
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast(crossbarSize)) {
      auto newShape = SmallVector {outShape[0], static_cast(crossbarSize)};
      auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
      if (outTensorOperand == vmmOp.getInput()) {
        // Do not retype the input in place; give the VMM its own output buffer.
        rewriter.setInsertionPoint(vmmOp);
        auto newOutputBuffer =
            tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
      }
      else {
        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
        outTensorOperand.setType(newType);
      }
      resultTensor.setType(newType);

      // Slice the original [outShape[0], outShape[1]] window back out for users.
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector offsets = {zeroAttr, zeroAttr};
      SmallVector sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector strides = {oneAttr, oneAttr};

      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
      SmallPtrSet exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}

/// Inserts host-to-device copies for globals used inside input-less compute ops.
///
/// For each top-level compute op with no inputs and no block arguments, every
/// global accessor whose name starts with "arg" or "const_" must have exactly
/// one user. That user's first result is copied into a fresh tensor.empty via
/// PimMemCopyHostToDevOp, and all other uses of that result are redirected to
/// the copy.
///
/// Always returns success().
/// NOTE(review): element size is bit width / 8, so sub-byte element types
/// would yield a zero byte size. This is presumably not used with such types;
/// confirm.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());

    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);

    // Byte offset and byte size are passed as i32 attributes.
    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        deviceTensor,
        inputTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize)));

    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };

  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast(op)) {
      if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
        continue;

      for (auto getGlobal : computeOp.getOps()) {
        if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
          assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
          auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
          insertMemCopyHostToDev(toTensorOpValue, 0);
        }
      }
    }

  return success();
}

/// Schedules `op` for deferred erasure, ignoring duplicates (linear scan).
void SpatialToPimPass::markOpToRemove(Operation* op) {
  if (!llvm::is_contained(operationsToRemove, op))
    operationsToRemove.push_back(op);
}

/// Creates a new instance of the Spatial -> PIM lowering pass.
std::unique_ptr createSpatialToPimPass() { return std::make_unique(); }

} // namespace onnx_mlir