#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

// NOTE(review): the two standard-library includes at this spot had lost their header names in
// the reviewed copy; these cover what this file visibly uses (assert, std::unique_ptr /
// std::make_unique, std::string).
#include <cassert>
#include <memory>
#include <string>

#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "Pass/PIMPasses.h"
#include "SpatialToPimPass.hpp"

using namespace mlir;
using namespace onnx_mlir;
using namespace pim;

// NOTE(review): in the reviewed copy of this file several template argument lists (`f<...>()`)
// were missing. The ones determined by the surrounding code have been restored. The remaining
// ones name project op/dialect types that cannot be inferred from this file; they are left as
// found and marked "template args missing" -- restore them from version control.

namespace onnx_mlir {

namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor

/// Returns a private constant `memref.global` holding all zeros, shaped like `tensorType`.
///
/// An existing constant global with the same memref type and an identical zero initializer is
/// reused; otherwise a new one named `__pim_zero_<rank>d_<numElements>[_<n>]` is inserted at the
/// start of the module that encloses the rewriter's current insertion block.
///
/// The initializer is built on the encoding-free tensor type. memref.global requires its initial
/// value to have exactly the tensor type derived from its memref type, which has no encoding, so
/// forwarding `tensorType`'s encoding (callers may pass one) would produce an invalid global.
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
  // getParentOfType only inspects strict ancestors, so handle an insertion point that sits
  // directly in the module body as well.
  Operation* insertionParent = rewriter.getBlock()->getParentOp();
  auto moduleOp = dyn_cast<ModuleOp>(insertionParent);
  if (!moduleOp)
    moduleOp = insertionParent->getParentOfType<ModuleOp>();
  assert(moduleOp && "expected the insertion point to be nested in a module");

  auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  auto storageType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType());
  auto zeroAttr = DenseElementsAttr::get(storageType, rewriter.getZeroAttr(tensorType.getElementType()));

  for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
    if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue())
      continue;
    if (dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()) == zeroAttr)
      return globalOp;
  }

  std::string nameStem;
  llvm::raw_string_ostream nameStream(nameStem);
  nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements();
  nameStream.flush();

  // Different shapes can share rank and element count; disambiguate with a numeric suffix.
  std::string symbolName = nameStem;
  unsigned suffix = 0;
  while (SymbolTable::lookupSymbolIn(moduleOp, symbolName))
    symbolName = (nameStem + "_" + Twine(suffix++)).str();

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  return memref::GlobalOp::create(rewriter,
                                  loc,
                                  rewriter.getStringAttr(symbolName),
                                  rewriter.getStringAttr("private"),
                                  TypeAttr::get(memRefType),
                                  zeroAttr,
                                  rewriter.getUnitAttr(),
                                  IntegerAttr {});
}

/// Creates a tensor of `tensorType` whose contents are copied from a zero-filled global.
///
/// Inside a pim.core_batch the batch host-to-device copy is used (its offsets are i32
/// attributes); elsewhere the regular host-to-device copy with index-constant offsets is used.
/// Returns the copy op's output value.
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
                                       Location loc,
                                       RankedTensorType tensorType,
                                       OperationFolder& constantFolder) {
  auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
  auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
  auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
  auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));

  if (outputBuffer->getParentOfType<pim::PimCoreBatchOp>())
    return PimMemCopyHostToDevBatchOp::create(rewriter,
                                              loc,
                                              tensorType,
                                              outputBuffer,
                                              zeroValue,
                                              rewriter.getI32IntegerAttr(0),
                                              rewriter.getI32IntegerAttr(0),
                                              sizeAttr)
        .getOutput();

  // Only this copy variant takes SSA index offsets. Materializing the constant here (rather than
  // up front) keeps the batch path from leaving an unused constant behind.
  auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
  return PimMemCopyHostToDevOp::create(
             rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
      .getOutput();
}

/// Widens a horizontal vector to `crossbarSize` columns, filling the extra columns with zeros.
///
/// Returns `vector` unchanged when it already spans `crossbarSize` columns. Otherwise a
/// zero-filled tensor of the padded type (same element type and encoding) is created and the
/// original vector's bytes are copied into it at offset 0.
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter,
                                           Location loc,
                                           Value vector,
                                           OperationFolder& constantFolder) {
  auto vectorType = cast<RankedTensorType>(vector.getType());
  ArrayRef<int64_t> shape = vectorType.getShape();
  const int64_t crossbarWidth = static_cast<int64_t>(crossbarSize);
  assert(isHVectorShape(shape) && "expected a horizontal vector");
  assert(shape[1] <= crossbarWidth && "vector width must fit in one crossbar");
  if (shape[1] == crossbarWidth)
    return vector;

  auto paddedType =
      RankedTensorType::get({shape[0], crossbarWidth}, vectorType.getElementType(), vectorType.getEncoding());
  Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
  auto zeroAttr = rewriter.getI32IntegerAttr(0);
  auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
  // NOTE(review): operand order presumably (dst, src, dst offset, src offset, size) -- confirm
  // against the PimMemCopyOp definition.
  return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
}

/// Lowers the Spatial program in the module to the PIM dialect.
///
/// Steps, in order (any failure emits a diagnostic, signals pass failure and stops):
///  1. reset per-run state and locate the PIM entry function;
///  2. apply the generated Spatial-to-PIM patterns as a partial conversion;
///  3. materialize global tensors and add output buffers for the entry function's return;
///  4. allocate/initialize core-local tensors, then lower every spat.compute to pim.core and
///     every spat.compute_batch to pim.core_batch;
///  5. run tensor-packing cleanup and schedule receive ops whose users are all being removed;
///  6. fully convert the Spatial ops nested in each pim.core / pim.core_batch;
///  7. pad VMM tensors to the crossbar width, rewrite the return, erase scheduled ops and run
///     tensor-packing cleanup again;
///  8. lower the remaining communication ops, verify the Spatial/PIM boundary, dump the module.
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
  // Per-run state lives on the pass object; reset it so every run starts clean.
  coreId = 0;
  outputTensors.clear();
  operationsToRemove.clear();

  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;

  IRRewriter rewriter(&getContext());
  OperationFolder constantFolder(&getContext());

  ConversionTarget target(*ctx);
  // NOTE(review): template args missing (legal dialect list).
  target.addLegalDialect();
  // NOTE(review): template args missing (legal op list).
  target.addLegalOp();

  RewritePatternSet initialPatterns(ctx);
  populateWithGenerated(initialPatterns);
  if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
    moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
    signalPassFailure();
    return;
  }

  RewritePatternSet globalTensorPatterns(ctx);
  populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
  walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));

  auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
  addReturnOutputBuffers(returnOp, rewriter);

  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
    signalPassFailure();
    return;
  }

  // Compute ops are only scheduled for removal here; they are erased later by eraseOpsToRemove,
  // after the remaining rewrites have run.
  // NOTE(review): template args missing (presumably the spat.compute op type).
  for (auto computeOp : funcOp.getOps()) {
    markOpToRemove(computeOp);
    if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
      computeOp.emitOpError("failed to lower spat.compute to pim.core");
      signalPassFailure();
      return;
    }
  }

  // NOTE(review): template args missing (presumably the spat.compute_batch op type).
  for (auto computeBatchOp : funcOp.getOps()) {
    markOpToRemove(computeBatchOp);
    if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
      computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
      signalPassFailure();
      return;
    }
  }

  RewritePatternSet initialTensorPackingPatterns(ctx);
  populateTensorPackingPatterns(initialTensorPackingPatterns);
  walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
  eraseUnusedTensorPackingOps(funcOp, rewriter);

  // A receive op whose users are all scheduled for removal becomes dead once they are erased.
  // NOTE(review): template args missing (SmallVector element type and getOps op type; presumably
  // the Spatial channel-receive op).
  SmallVector receiveOps;
  for (auto op : funcOp.getOps())
    receiveOps.push_back(op);
  for (auto receiveOp : receiveOps) {
    bool onlyPendingRemovalUsers = llvm::all_of(receiveOp->getUsers(), [&](Operation* user) {
      return llvm::is_contained(operationsToRemove, user);
    });
    if (onlyPendingRemovalUsers)
      markOpToRemove(receiveOp);
  }

  RewritePatternSet coreBodyPatterns(ctx);
  populateWithGenerated(coreBodyPatterns);
  FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));

  // Collect the core ops before converting so the walk is not interleaved with IR rewrites.
  SmallVector<pim::PimCoreOp> coreOps;
  funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
  for (auto coreOp : coreOps) {
    if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
      coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
      signalPassFailure();
      return;
    }
  }

  SmallVector<pim::PimCoreBatchOp> coreBatchOps;
  funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
  for (auto coreBatchOp : coreBatchOps) {
    if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
      coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
      signalPassFailure();
      return;
    }
  }

  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  replaceReturnWithOutputBuffers(returnOp, rewriter);
  eraseOpsToRemove();

  RewritePatternSet finalTensorPackingPatterns(ctx);
  populateTensorPackingPatterns(finalTensorPackingPatterns);
  walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
  eraseUnusedTensorPackingOps(funcOp, rewriter);

  ConversionTarget communicationTarget(*ctx);
  // NOTE(review): template args missing on the next three calls (legal dialects, legal ops and the
  // Spatial communication ops that must be lowered).
  communicationTarget.addLegalDialect();
  communicationTarget.addLegalOp();
  communicationTarget.addIllegalOp();

  RewritePatternSet communicationPatterns(ctx);
  populateChannelLoweringPatterns(communicationPatterns);
  if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
    funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
    signalPassFailure();
    return;
  }

  if (failed(verifySpatialToPimBoundary(moduleOp))) {
    moduleOp.emitError("Spatial-to-PIM boundary verification failed");
    signalPassFailure();
    return;
  }

  // Dump to file for debugging.
  dumpModule(moduleOp, "pim0");
}

/// Pads every pim.vmm in `funcOp` so that its input and output span `crossbarSize` columns.
///
/// The input is padded with padHVectorInputToCrossbarSize. When the output is narrower than
/// `crossbarSize`, the op gets a fresh full-width output buffer, its result type is widened, and
/// all former users are redirected to a tensor.extract_slice that recovers the original
/// [0, width) columns.
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  OperationFolder constantFolder(funcOp.getContext());
  const int64_t crossbarWidth = static_cast<int64_t>(crossbarSize);

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
    ArrayRef<int64_t> outputShape = outputType.getShape();
    assert(isHVectorShape(outputShape) && "expected a horizontal vector output");
    assert(outputShape[1] <= crossbarWidth && "output width must fit in one crossbar");
    const bool outputAlreadyFullWidth = outputShape[1] == crossbarWidth;

    rewriter.setInsertionPoint(vmmOp);
    Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);

    auto paddedOutputType =
        RankedTensorType::get({outputShape[0], crossbarWidth}, outputType.getElementType(), outputType.getEncoding());
    Value paddedOutputBuffer =
        outputAlreadyFullWidth ? vmmOp.getOutputBuffer()
                               : createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult();

    vmmOp.getInputMutable().assign(paddedInput);
    vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer);
    vmmOp.getOutput().setType(paddedOutputType);

    if (outputAlreadyFullWidth)
      return;

    // Recover the original columns for the existing users. `outputShape` still refers to the
    // original (uniqued) type, so it remains valid after the setType above.
    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])};
    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};

    rewriter.setInsertionPointAfter(vmmOp);
    auto sliceOp = tensor::ExtractSliceOp::create(
        rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides);
    SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
    vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions);
  });
}

/// Inserts host-to-device copies for the globals read by input-less compute ops.
///
/// Only top-level compute ops in `funcOp` with no inputs and no entry-block arguments are
/// considered. Inside each, every memref.get_global whose symbol starts with "arg" or "const_"
/// must have exactly one user. That user's first result is copied into a fresh tensor.empty via
/// pim.memcp_host_to_dev, and all other uses of it are redirected to the copy. Tensors whose
/// element type is not byte-sized are left untouched.
///
/// Returns failure (with a diagnostic on the offending op) if a matching global does not have
/// exactly one user, or if that user produces no result.
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
                                                                                IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();
  OperationFolder constantFolder(funcOp.getContext());

  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast<RankedTensorType>(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    if (!hasByteSizedElementType(elementType))
      return;
    size_t elementByteSize = getElementTypeSizeInBytes(elementType);

    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
        getOrCreateHostIndexConstant(
            deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
        deviceTensor,
        inputTensor,
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
    // The copy itself still reads the host tensor; every other user now sees the device copy.
    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };

  for (auto& op : funcOp.getBody().getOps()) {
    // NOTE(review): template args missing (presumably the spat.compute op type).
    if (auto computeOp = dyn_cast(op)) {
      if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
        continue;

      for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
        StringRef globalName = getGlobal.getName();
        if (!globalName.starts_with("arg") && !globalName.starts_with("const_"))
          continue;

        // Checked at runtime rather than asserted: with assertions disabled a use-less global
        // would otherwise dereference an empty user range.
        if (!getGlobal->hasOneUse())
          return getGlobal.emitOpError("global must have a single entry point in the compute");
        Operation* entryUser = *getGlobal->getUsers().begin();
        if (entryUser->getNumResults() == 0)
          return entryUser->emitOpError("expected the single user of a global to produce a value");
        insertMemCopyHostToDev(entryUser->getResult(0), 0);
      }
    }
  }
  return success();
}

/// Schedules `op` for erasure by eraseOpsToRemove; scheduling the same op twice is a no-op.
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
  if (!llvm::is_contained(operationsToRemove, op))
    operationsToRemove.push_back(op);
}

/// Erases every scheduled op in scheduling order after dropping all uses of its results, then
/// empties the schedule.
void raptor::SpatialToPimPass::eraseOpsToRemove() {
  for (Operation* op : operationsToRemove) {
    op->dropAllUses();
    op->erase();
  }
  // Every stored pointer now dangles. Clear the list so a later markOpToRemove cannot mistake a
  // new op allocated at a reused address for an already-scheduled one.
  operationsToRemove.clear();
}

// NOTE(review): template args missing on the std::unique_ptr return type; restore it to match the
// declaration in Pass/PIMPasses.h.
std::unique_ptr createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }

} // namespace onnx_mlir