diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp index c0b3b8f..0a8e08c 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp @@ -1,10 +1,45 @@ #include "Common.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" + using namespace mlir; namespace onnx_mlir { +namespace { + +struct DenseSubviewKey { + DenseElementsAttr source; + SmallVector offsets; + SmallVector resultShape; + + bool operator==(const DenseSubviewKey& other) const { + return source == other.source && offsets == other.offsets && resultShape == other.resultShape; + } +}; + +struct DenseSubviewKeyInfo { + static inline DenseSubviewKey getEmptyKey() { + return {DenseElementsAttr(), {DenseMapInfo::getEmptyKey()}, {}}; + } + + static inline DenseSubviewKey getTombstoneKey() { + return {DenseElementsAttr(), {DenseMapInfo::getTombstoneKey()}, {}}; + } + + static unsigned getHashValue(const DenseSubviewKey& key) { + return static_cast( + llvm::hash_combine(key.source, llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()), + llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end()))); + } + + static bool isEqual(const DenseSubviewKey& lhs, const DenseSubviewKey& rhs) { return lhs == rhs; } +}; + +} // namespace + Value stripMemRefCasts(Value value) { while (auto castOp = value.getDefiningOp()) value = castOp.getSource(); @@ -35,6 +70,16 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, DenseElementsAttr denseAttr, StringRef nameStem, IntegerAttr alignment) { + for (auto globalOp : moduleOp.getOps()) { + if (!globalOp.getConstant() || globalOp.getType() != globalType || globalOp.getAlignmentAttr() != alignment + || !globalOp.getInitialValue()) + continue; + + auto existingDenseAttr = dyn_cast(*globalOp.getInitialValue()); + if (existingDenseAttr == denseAttr) + return globalOp; + } + auto globalName = nameStem.str(); unsigned suffix = 0; while (moduleOp.lookupSymbol(globalName)) @@ -53,6 +98,43 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, alignment); } +FailureOr foldDenseSubview(DenseElementsAttr denseAttr, + ArrayRef staticOffsets, + ArrayRef resultShape) { + auto sourceType = dyn_cast(denseAttr.getType()); + if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() != static_cast(staticOffsets.size()) + || sourceType.getRank() != static_cast(resultShape.size())) + return failure(); + + static DenseMap cache; + DenseSubviewKey key {denseAttr, SmallVector(staticOffsets.begin(), staticOffsets.end()), + SmallVector(resultShape.begin(), resultShape.end())}; + if (auto cached = cache.find(key); cached != cache.end()) + return cached->second; + + auto resultTensorType = RankedTensorType::get(resultShape, sourceType.getElementType()); + const int64_t numResultElements = resultTensorType.getNumElements(); + if (numResultElements < 0) + return failure(); + + auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); + auto resultStrides = computeRowMajorStrides(resultTensorType.getShape()); + SmallVector sourceValues(denseAttr.getValues()); + SmallVector resultValues(numResultElements); + for (int64_t i = 0; i < numResultElements; ++i) { + auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); + SmallVector sourceIndices; + sourceIndices.reserve(resultIndices.size()); + for (auto [offset, index] : llvm::zip_equal(staticOffsets, resultIndices)) + sourceIndices.push_back(offset + index); + resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)]; + } + + auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); + cache.try_emplace(std::move(key), foldedAttr); + return foldedAttr; +} + FailureOr getDenseGlobalValue(ModuleOp moduleOp, Value value) { value = stripMemRefCasts(value); diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp index f41c4a0..c355de9 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp @@ -30,6 +30,10 @@ mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp, llvm::StringRef nameStem, mlir::IntegerAttr alignment = {}); +llvm::FailureOr foldDenseSubview(mlir::DenseElementsAttr denseAttr, + llvm::ArrayRef staticOffsets, + llvm::ArrayRef resultShape); + llvm::FailureOr getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value); llvm::FailureOr getStaticSubviewInfo(mlir::Value value); diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp index ef1adb2..4c39ac1 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp @@ -419,32 +419,16 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { DenseElementsAttr foldedAttr; if (succeeded(srcSubview)) { - auto sourceType = dyn_cast(denseAttr->getType()); - if (!sourceType || !sourceType.hasStaticShape()) - return failure(); if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) return failure(); auto staticOffsets = getStaticSubviewOffsets(*srcSubview); if (failed(staticOffsets)) return failure(); - auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); - const int64_t numResultElements = resultTensorType.getNumElements(); - auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); - auto resultStrides = computeRowMajorStrides(resultTensorType.getShape()); - SmallVector sourceValues(denseAttr->getValues()); - SmallVector resultValues(numResultElements); - - for (int64_t i = 0; i < numResultElements; ++i) { - auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); - SmallVector sourceIndices; - sourceIndices.reserve(resultIndices.size()); - for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices)) - sourceIndices.push_back(off + idx); - int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides); - resultValues[i] = sourceValues[srcLinear]; - } - foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); + auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape()); + if (failed(maybeFoldedAttr)) + return failure(); + foldedAttr = *maybeFoldedAttr; } else { auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp index acdbc58..586522d 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp @@ -249,32 +249,15 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern(denseAttr->getType()); - if (!sourceType || !sourceType.hasStaticShape()) - return failure(); - auto elementType = cast(subviewOp.getType()).getElementType(); auto resultMemRefType = MemRefType::get(SmallVector(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType); - auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType); - const int64_t numResultElements = resultTensorType.getNumElements(); - - auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); - auto resultStrides = computeRowMajorStrides(resultTensorType.getShape()); - SmallVector sourceValues(denseAttr->getValues()); - SmallVector resultValues(numResultElements); - for (int64_t i = 0; i < numResultElements; ++i) { - auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); - SmallVector sourceIndices; - sourceIndices.reserve(resultIndices.size()); - for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices)) - sourceIndices.push_back(off + idx); - resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)]; - } - auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); + auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape()); + if (failed(foldedAttr)) + return failure(); auto newGlobal = - createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview"); + createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, *foldedAttr, "pim_folded_subview"); markWeightAlways(newGlobal); rewriter.setInsertionPoint(subviewOp);