#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <cstdint>
#include <string>
#include <system_error>

#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace llvm;
using namespace mlir;

namespace onnx_mlir {

namespace {

/// A strided window over the flat element list of a dense constant.
struct DenseWeightView {
  /// Initial value of the memref.global the view ultimately reads from.
  DenseElementsAttr denseAttr;
  /// Extent of each dimension of the view.
  SmallVector<int64_t> shape;
  /// Distance, in elements of `denseAttr`, between consecutive indices of each
  /// dimension.
  SmallVector<int64_t> strides;
  /// Linear element index inside `denseAttr` of the view's origin.
  int64_t offset = 0;
};

/// Returns true when both the source and the result of a collapse/expand
/// reshape are memrefs with fully static shapes.
template <typename ReshapeOpT>
bool hasStaticReshapeTypes(ReshapeOpT reshapeOp) {
  auto srcType = dyn_cast<MemRefType>(reshapeOp.getSrc().getType());
  auto resultType = dyn_cast<MemRefType>(reshapeOp.getResult().getType());
  return srcType && resultType && srcType.hasStaticShape() && resultType.hasStaticShape();
}

/// Follows the defining ops of `weight` back to a memref.get_global and
/// describes `weight` as a strided view of that global's dense initial value.
///
/// Only static subviews, memref casts and static collapse/expand reshapes are
/// looked through. Returns failure when any other producer is met, when the
/// chain reaches a value without a defining op, when the global has no dense
/// initial value, or when a reshape is applied to a view that is not
/// contiguous row-major.
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
  // Shape-changing ops between `weight` and the global, collected from
  // `weight` backwards (innermost op last).
  SmallVector<Operation*> viewOps;
  mlir::Value current = weight;
  memref::GetGlobalOp getGlobalOp;

  while (true) {
    Operation* defOp = current.getDefiningOp();
    if (!defOp)
      return failure();

    if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
      break;

    if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
      if (!hasAllStaticSubviewParts(subview))
        return failure();
      viewOps.push_back(subview);
      current = subview.getSource();
      continue;
    }

    // Casts are looked through without being recorded.
    if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
      current = cast.getSource();
      continue;
    }

    if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
      if (!hasStaticReshapeTypes(collapse))
        return failure();
      viewOps.push_back(collapse);
      current = collapse.getSrc();
      continue;
    }

    if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
      if (!hasStaticReshapeTypes(expand))
        return failure();
      viewOps.push_back(expand);
      current = expand.getSrc();
      continue;
    }

    return failure();
  }

  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
  if (!globalOp || !globalOp.getInitialValue())
    return failure();

  auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
  if (!denseAttr)
    return failure();

  // Start from the whole global, laid out row-major.
  DenseWeightView view;
  view.denseAttr = denseAttr;
  view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
  view.strides = computeRowMajorStrides(view.shape);

  // Replay the recorded ops from the global towards `weight`.
  for (Operation* viewOp : llvm::reverse(viewOps)) {
    if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
      // NOTE(review): the shape is taken from the subview's static sizes; if
      // rank-reducing subviews can reach here the dropped unit dimensions would
      // presumably stay in `view.shape` - confirm against
      // hasAllStaticSubviewParts.
      SmallVector<int64_t> nextStrides;
      nextStrides.reserve(subview.getStaticStrides().size());
      for (auto [offset, stride, sourceStride] :
          llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
        view.offset += offset * sourceStride;
        nextStrides.push_back(stride * sourceStride);
      }
      view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
      view.strides = std::move(nextStrides);
      continue;
    }

    // Everything else recorded above is a collapse or an expand. They are
    // accepted only as contiguous static reshapes of a dense global view, so a
    // row-major stride recomputation preserves layout.
    assert((isa<memref::CollapseShapeOp, memref::ExpandShapeOp>(viewOp)) && "unexpected recorded view op");
    if (view.strides != computeRowMajorStrides(view.shape))
      return failure();
    auto resultType = cast<MemRefType>(viewOp->getResult(0).getType());
    view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
    view.strides = computeRowMajorStrides(view.shape);
  }

  return view;
}

/// Returns, sorted ascending and without duplicates, the indices of the core
/// weights that are used as the weight of a pim::PimVMMOp nested in `block`.
/// The result is empty when the parent op of `block` is not a pim::PimCoreOp.
SmallVector<unsigned> getUsedWeightIndices(Block& block) {
  SmallVector<unsigned> indices;
  auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
  if (!coreOp)
    return indices;

  block.walk([&](pim::PimVMMOp vmmOp) {
    mlir::Value weight = vmmOp.getWeight();
    const unsigned numWeights = coreOp.getWeights().size();
    for (unsigned weightIndex = 0; weightIndex < numWeights; ++weightIndex) {
      if (coreOp.getWeightArgument(weightIndex) != weight)
        continue;
      if (!llvm::is_contained(indices, weightIndex))
        indices.push_back(weightIndex);
      // Only the first matching weight argument is recorded.
      return;
    }
  });

  llvm::sort(indices);
  return indices;
}

/// Convenience overload working on the first block of the core body.
SmallVector<unsigned> getUsedWeightIndices(pim::PimCoreOp coreOp) {
  return getUsedWeightIndices(coreOp.getBody().front());
}

/// Collects, in program order, the core and core-batch ops that sit directly
/// in the first block of `funcOp`.
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
  SmallVector<Operation*> coreLikeOps;
  for (Operation& op : funcOp.getBody().front())
    if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(&op))
      coreLikeOps.push_back(&op);
  return coreLikeOps;
}

/// Writes the 2-D `view` to `filePath` as an `xbarSize` x `xbarSize` row-major
/// matrix; every cell outside the view is written as zero bytes.
///
/// Precondition: `view` has exactly two dimensions, neither larger than
/// `xbarSize`. Returns failure, after printing to errs(), when the element
/// width is not a whole number of bytes between 1 and 8 or when the file
/// cannot be opened.
LogicalResult writeCrossbarFile(const DenseWeightView& view, int64_t xbarSize, StringRef filePath) {
  assert(view.shape.size() == 2 && view.strides.size() == 2 && "crossbar view must be 2-dimensional");

  const unsigned elementBitWidth = view.denseAttr.getElementType().getIntOrFloatBitWidth();
  // Each element is emitted from a 64-bit word, one whole byte at a time.
  if (elementBitWidth == 0 || elementBitWidth % 8 != 0 || elementBitWidth > 64) {
    errs() << "Unsupported weight element width of " << elementBitWidth << " bits for `" << filePath << "`\n";
    return failure();
  }
  const size_t elementByteWidth = elementBitWidth / 8;

  std::error_code errorCode;
  raw_fd_ostream weightFileStream(filePath, errorCode, sys::fs::OF_None);
  if (errorCode) {
    errs() << "Error while opening weight file `" << filePath << "`: " << errorCode.message() << '\n';
    // Do not touch the stream: it has no valid file descriptor.
    return failure();
  }

  const int64_t numRows = view.shape[0];
  const int64_t numCols = view.shape[1];
  // Built once; it was previously rebuilt for every cell.
  auto values = view.denseAttr.getValues<APFloat>();

  for (int64_t row = 0; row < xbarSize; ++row) {
    for (int64_t col = 0; col < xbarSize; ++col) {
      uint64_t word = 0;
      if (row < numRows && col < numCols) {
        const int64_t elementIndex = view.offset + row * view.strides[0] + col * view.strides[1];
        word = values[elementIndex].bitcastToAPInt().getZExtValue();
      }
      // The first `elementByteWidth` bytes of `word` are written in host byte
      // order, so the element bits land in the file only on a little-endian
      // host.
      weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
    }
  }

  weightFileStream.close();
  return success();
}

} // namespace

/// Creates `<outputDirPath>/weights` and writes one `crossbar_<n>.bin` file per
/// distinct weight used by the top-level core and core-batch ops of `funcOp`.
///
/// Each file holds a crossbarSize x crossbarSize matrix with the weight in the
/// top-left corner and zeros elsewhere. A weight that is directly the result
/// of a memref.get_global shares its file with every other use of the same
/// global.
///
/// @param funcOp        Function whose first block holds the core ops.
/// @param outputDirPath Directory in which the `weights` folder is created.
/// @return Map from core id to a map from weight value to the name of the file
///         (relative to the `weights` folder) holding that weight. When the
///         folder cannot be created, or a core fails to process, the entries
///         gathered so far are returned; in builds with assertions enabled
///         those conditions abort instead, except for a file that cannot be
///         opened.
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
  llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;

  ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();

  // Materialise the path: a Twine must not be kept in a local variable.
  const std::string coreWeightsDirPath = (outputDirPath + "/weights").str();
  if (std::error_code error = sys::fs::create_directory(coreWeightsDirPath)) {
    errs() << "Error creating weights directory `" << coreWeightsDirPath << "`: " << error.message() << '\n';
    assert(false && "Error creating weights directory");
    return mapCoreWeightToFileName;
  }

  size_t indexFileName = 0;
  const int64_t xbarSize = crossbarSize.getValue();
  llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;

  // Emits the files for every weight used by one scalar core. Returns failure
  // as soon as a weight cannot be emitted.
  auto processCore = [&](pim::PimCoreOp coreOp) -> LogicalResult {
    const size_t coreId = static_cast<size_t>(coreOp.getCoreId());

    for (unsigned index : getUsedWeightIndices(coreOp)) {
      if (index >= coreOp.getWeights().size()) {
        coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
        assert(false && "Weight index is out of range");
        return failure();
      }

      mlir::Value weight = coreOp.getWeights()[index];
      auto weightView = resolveDenseWeightView(moduleOp, weight);
      if (failed(weightView)) {
        coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
        assert(false && "Weight is not from a dense memref.global view");
        return failure();
      }

      // No entry is added to the outer map below this point, so the reference
      // stays valid for the rest of the iteration.
      auto& coreFileNames = mapCoreWeightToFileName[coreId];
      if (coreFileNames.contains(weight))
        continue;

      // Files are shared per global only when the weight is the get_global
      // result itself, i.e. the view is the whole global.
      auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
      auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
      if (globalOp) {
        if (auto it = mapGlobalOpToFileName.find(globalOp); it != mapGlobalOpToFileName.end()) {
          coreFileNames.insert({weight, it->second});
          continue;
        }
      }

      ArrayRef<int64_t> shape = weightView->shape;
      if (!isMatrixShape(shape)) {
        coreOp.emitError("Weight matrix must be 2-dimensional");
        assert(false && "Weight matrix must be 2-dimensional");
        return failure();
      }
      if (shape[0] > xbarSize || shape[1] > xbarSize) {
        coreOp.emitError("Weight dimensions must not exceed crossbar size");
        assert(false && "Weight dimensions must not exceed crossbar size");
        return failure();
      }

      std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
      const std::string weightFilePath = coreWeightsDirPath + "/" + newFileName;
      if (failed(writeCrossbarFile(*weightView, xbarSize, weightFilePath)))
        return failure();

      if (globalOp)
        mapGlobalOpToFileName.insert({globalOp, newFileName});
      coreFileNames.insert({weight, newFileName});
    }
    return success();
  };

  for (Operation* op : collectTopLevelCoreLikeOps(funcOp)) {
    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
      // Stop at the first failing core, as the batch path below does.
      if (failed(processCore(coreOp)))
        return mapCoreWeightToFileName;
      continue;
    }

    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
    const unsigned laneCount = static_cast<unsigned>(coreBatchOp.getLaneCount());
    for (unsigned lane = 0; lane < laneCount; ++lane)
      if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
        return mapCoreWeightToFileName;
  }

  return mapCoreWeightToFileName;
}

} // namespace onnx_mlir