//===----------------------------------------------------------------------===//
// Constant-folding rewrite patterns for the PIM accelerator backend.
//
// The patterns here replace runtime initialization of constant buffers
// (linalg.map fills, pim transposes of constant globals, memref.copy of
// constant globals into allocs) with precomputed memref.global constants.
//
// NOTE(review): this file was recovered from a copy whose template argument
// lists ("<...>") had been stripped. All template arguments below were
// reconstructed from context. The MLIR/LLVM ones are unambiguous; the
// project-specific PIM op types inside isa<>/getParentOfType<> are best-effort
// guesses and are marked TODO(review) — confirm each against the PIM dialect.
//===----------------------------------------------------------------------===//

#include "../Common.hpp"
#include "../Patterns.hpp"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// One `memref.copy` of a constant global into a static subview of an alloc.
/// Collected by foldConstantAlloc() so the copies can be replayed onto a
/// dense attribute at compile time.
struct ConstantSubviewCopy {
  DenseElementsAttr source;      // constant data being copied in
  SmallVector<int64_t> offsets;  // static subview offsets (per dimension)
  SmallVector<int64_t> strides;  // static subview strides (per dimension)
  Operation *copyOp = nullptr;   // the copy op, used for program-order sorting
};

/// Materializes the transpose of a constant tensor attribute.
///
/// \p perms must be a valid permutation of [0, rank): each dimension of the
/// result takes its extent from `shape[perms[dim]]` of the source, and element
/// `result[i0..in] = source[i(perms[0])..i(perms[n])]` in row-major order.
/// Returns failure() for unranked types, rank mismatches, or invalid perms.
static FailureOr<DenseElementsAttr> transposeDenseElements(
    DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!tensorType)
    return failure();
  int64_t rank = tensorType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();

  // Validate that perms is a permutation while computing the result shape.
  llvm::SmallBitVector seen(rank);
  SmallVector<int64_t> transposedShape;
  transposedShape.reserve(rank);
  for (int64_t perm : perms) {
    if (perm < 0 || perm >= rank || seen.test(perm))
      return failure();
    seen.set(perm);
    transposedShape.push_back(tensorType.getShape()[perm]);
  }
  auto transposedType =
      RankedTensorType::get(transposedShape, tensorType.getElementType());

  // A splat is invariant under any permutation of dimensions.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(
        transposedType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> transposedValues(originalValues.size());

  // Row-major strides for source and result layouts.
  SmallVector<int64_t> originalStrides(rank, 1);
  SmallVector<int64_t> transposedStrides(rank, 1);
  for (int64_t dim = rank - 2; dim >= 0; --dim) {
    originalStrides[dim] =
        originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
    transposedStrides[dim] =
        transposedStrides[dim + 1] * transposedShape[dim + 1];
  }

  // Scatter each source element into its transposed linear position.
  SmallVector<int64_t> originalIndices(rank);
  SmallVector<int64_t> transposedIndices(rank);
  for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
    int64_t remaining = static_cast<int64_t>(linearIndex);
    for (int64_t dim = 0; dim < rank; ++dim) {
      originalIndices[dim] = remaining / originalStrides[dim];
      remaining %= originalStrides[dim];
    }
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedIndices[dim] = originalIndices[perms[dim]];
    int64_t transposedLinearIndex = 0;
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
    transposedValues[transposedLinearIndex] = value;
  }
  return DenseElementsAttr::get(transposedType, transposedValues);
}

/// Matches a linalg.map with no inputs whose body yields a single constant,
/// i.e. a constant fill. Returns the yielded constant attribute on success.
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
  if (!mapOp.getInputs().empty())
    return failure();
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return failure();
  Attribute attr;
  if (!matchPattern(yieldOp.getValues().front(), m_Constant(&attr)))
    return failure();
  return attr;
}

/// Replaces a constant-fill linalg.map inside a PIM core region with a
/// precomputed splat global plus a host-to-device copy of the bytes,
/// hoisting the global read out of the core region.
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      linalg::MapOp mapOp, PatternRewriter &rewriter) const override {
    // Only maps nested inside a PIM core region are handled here.
    // TODO(review): parent op type reconstructed — confirm against dialect.
    auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
    if (!coreOp)
      return failure();
    auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
    if (!initType || !initType.hasStaticShape())
      return failure();
    auto fillValue = getConstantMapYield(mapOp);
    if (failed(fillValue))
      return failure();

    // Build a splat constant matching the init buffer's shape.
    auto tensorType =
        RankedTensorType::get(initType.getShape(), initType.getElementType());
    DenseElementsAttr splatAttr =
        DenseElementsAttr::get(tensorType, *fillValue);
    auto moduleOp = mapOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto globalOp = createFoldedGlobal(
        moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");

    // The get_global must live outside the core region; the copy replaces
    // the map in place.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(coreOp);
    auto getGlobalOp = memref::GetGlobalOp::create(
        rewriter, mapOp.getLoc(), initType, globalOp.getName());

    // Byte size of the transfer; sub-byte element types are rejected.
    size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
    if (elementByteWidth == 0)
      return failure();
    size_t totalBytes = initType.getNumElements() * elementByteWidth;

    rewriter.setInsertionPoint(mapOp);
    pim::PimMemCopyHostToDevOp::create(rewriter, mapOp.getLoc(), initType,
        mapOp.getInit(), getGlobalOp.getResult(),
        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
    rewriter.eraseOp(mapOp);
    return success();
  }
};

/// Attempts to compute, at compile time, the final contents of an alloc that
/// is initialized by (a) exactly one distinct constant linalg.map fill value
/// and (b) any number of memref.copy ops from constant globals into static
/// subviews. Walks through memref.cast aliases. Returns the folded dense
/// attribute, or failure() if any user/initializer is not analyzable.
static FailureOr<DenseElementsAttr> foldConstantAlloc(
    memref::AllocOp allocOp, ModuleOp moduleOp) {
  auto allocType = dyn_cast<MemRefType>(allocOp.getType());
  if (!allocType || !allocType.hasStaticShape())
    return failure();
  auto resultTensorType =
      RankedTensorType::get(allocType.getShape(), allocType.getElementType());
  const int64_t numElements = resultTensorType.getNumElements();
  if (numElements < 0)
    return failure();

  Attribute fillValue;
  SmallVector<ConstantSubviewCopy> copies;

  // Worklist walk over the alloc and its cast aliases.
  llvm::SmallPtrSet<Operation *, 8> visitedAliases;
  SmallVector<Value> pendingAliases;
  pendingAliases.push_back(allocOp.getResult());
  while (!pendingAliases.empty()) {
    Value alias = pendingAliases.pop_back_val();
    for (Operation *user : alias.getUsers()) {
      if (!visitedAliases.insert(user).second)
        continue;

      // Constant fill: all fills must agree on a single value.
      if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
        if (mapOp.getInit() != alias)
          return failure();
        auto maybeFillValue = getConstantMapYield(mapOp);
        if (failed(maybeFillValue))
          return failure();
        if (fillValue && fillValue != *maybeFillValue)
          return failure();
        fillValue = *maybeFillValue;
        continue;
      }

      // Static subview written by a copy from a constant global.
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
        SmallVector<int64_t> offsets;
        SmallVector<int64_t> strides;
        offsets.reserve(subviewOp.getMixedOffsets().size());
        strides.reserve(subviewOp.getMixedStrides().size());
        for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
          auto staticOffset = getConstantIntValue(offset);
          if (!staticOffset)
            return failure();
          offsets.push_back(*staticOffset);
        }
        for (OpFoldResult stride : subviewOp.getMixedStrides()) {
          auto staticStride = getConstantIntValue(stride);
          if (!staticStride)
            return failure();
          strides.push_back(*staticStride);
        }
        for (Operation *subviewUser : subviewOp->getUsers()) {
          if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
            if (copyOp.getTarget() != subviewOp.getResult())
              return failure();
            auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
            if (failed(denseAttr))
              return failure();
            copies.push_back({*denseAttr, offsets, strides, copyOp});
            continue;
          }
          return failure();
        }
        continue;
      }

      // TODO(review): the original ignored one user kind here (isa<> list was
      // lost) — dealloc is the natural candidate; confirm.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        pendingAliases.push_back(castOp.getResult());
        continue;
      }
      return failure();
    }
  }
  if (!fillValue)
    return failure();

  // Start from the fill value, then replay the copies in program order so
  // later writes win, matching runtime semantics.
  SmallVector<Attribute> resultValues(numElements, fillValue);
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
  llvm::sort(copies,
      [](const ConstantSubviewCopy &lhs, const ConstantSubviewCopy &rhs) {
        return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
      });
  for (const ConstantSubviewCopy &copy : copies) {
    auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size()) ||
        sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
      return failure();
    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
    for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
      SmallVector<int64_t> sourceIndices =
          delinearizeIndex(static_cast<int64_t>(linearIndex),
              sourceType.getShape(), sourceStrides);
      // Map source index -> destination index via offset + index * stride.
      SmallVector<int64_t> resultIndices;
      resultIndices.reserve(sourceIndices.size());
      for (auto [offset, sourceIndex, stride] :
          llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
        resultIndices.push_back(offset + sourceIndex * stride);
      int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
      resultValues[resultLinearIndex] = value;
    }
  }
  return DenseElementsAttr::get(resultTensorType, resultValues);
}

/// Folds a pim.transpose whose input is a constant memref.global into a new
/// pre-transposed global, replacing the transpose with a get_global.
struct FoldConstantTransposePattern final
    : OpRewritePattern<pim::PimTransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp,
      PatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();
    auto sourceGetGlobal =
        transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
    if (!sourceGetGlobal)
      return failure();
    auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // Only constant globals with an initial value can be folded.
    auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
    if (!sourceGlobal || !sourceGlobal.getConstant() ||
        !sourceGlobal.getInitialValue())
      return failure();
    auto denseAttr =
        dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
    if (!denseAttr)
      return failure();

    SmallVector<int64_t> perms;
    perms.reserve(transposeOp.getPerms().size());
    for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
      perms.push_back(attr.getInt());
    FailureOr<DenseElementsAttr> transposedAttr =
        transposeDenseElements(denseAttr, perms);
    if (failed(transposedAttr))
      return failure();

    // Sanity check: folded shape must match the op's declared result shape.
    auto transposedShape =
        cast<RankedTensorType>(transposedAttr->getType()).getShape();
    if (!llvm::equal(transposedShape, resultType.getShape()))
      return failure();

    auto newGlobal = createFoldedGlobal(moduleOp, transposeOp.getLoc(),
        resultType, *transposedAttr,
        sourceGlobal.getName().str() + "__folded_transpose",
        sourceGlobal.getAlignmentAttr());
    rewriter.setInsertionPoint(transposeOp);
    auto newGetGlobal = memref::GetGlobalOp::create(
        rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());

    // Propagate the weight marker only when every user consumes the result
    // as a weight. TODO(review): user op type reconstructed — confirm.
    bool isAlwaysWeight = !transposeOp->getUsers().empty() &&
        llvm::all_of(transposeOp->getUsers(), [](Operation *user) {
          return isa<pim::PimMatMulOp>(user);
        });
    if (isAlwaysWeight) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }
    rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
    return success();
  }
};

/// Replaces an alloc whose contents foldConstantAlloc() proved constant with
/// a memref.global + get_global, erasing the now-dead initializer ops.
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      memref::AllocOp allocOp, PatternRewriter &rewriter) const override {
    auto moduleOp = allocOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
    if (failed(foldedAttr))
      return failure();

    auto allocType = cast<MemRefType>(allocOp.getType());
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType,
        *foldedAttr, "pim_folded_constant");
    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(
        rewriter, allocOp.getLoc(), allocType, newGlobal.getName());

    // Classify users: initializer ops to erase, casts to re-root on the new
    // get_global, and (expected) core consumers.
    // TODO(review): all isa<> op lists in this method were lost in the
    // mangled copy and are reconstructed by convention — confirm each.
    SmallVector<Operation *> opsToErase;
    SmallVector<memref::CastOp> castsToReplace;
    bool allLiveUsersAreCoreOps = true;
    for (Operation *user : llvm::make_early_inc_range(allocOp->getUsers())) {
      if (isa<linalg::MapOp, memref::SubViewOp>(user)) {
        opsToErase.push_back(user);
        continue;
      }
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        castsToReplace.push_back(castOp);
        continue;
      }
      if (!isa<pim::PimCoreOp>(user))
        return failure();
    }
    if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
          return llvm::all_of(castOp->getUsers(), [](Operation *user) {
            return isa<pim::PimCoreOp>(user);
          });
        })) {
      allLiveUsersAreCoreOps = false;
    }
    if (!llvm::all_of(allocOp->getUsers(), [](Operation *user) {
          return isa<linalg::MapOp, memref::SubViewOp, memref::CastOp,
              pim::PimCoreOp>(user);
        })) {
      return failure();
    }
    if (allLiveUsersAreCoreOps) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }

    // Keep initializer ops and casts on the old value; everything else is
    // redirected to the folded global.
    llvm::SmallPtrSet<Operation *, 8> preservedUsers(
        opsToErase.begin(), opsToErase.end());
    for (memref::CastOp castOp : castsToReplace)
      preservedUsers.insert(castOp);
    rewriter.replaceAllUsesExcept(
        allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);

    // Rebuild each cast on top of the new get_global.
    for (memref::CastOp castOp : castsToReplace) {
      rewriter.setInsertionPoint(castOp);
      Value replacementCast = memref::CastOp::create(
          rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
      rewriter.replaceOp(castOp, replacementCast);
      if (allLiveUsersAreCoreOps)
        markWeightAlways(replacementCast.getDefiningOp());
    }

    // Erase the folded initializer ops (subview users first, then the op).
    for (Operation *op : llvm::make_early_inc_range(opsToErase)) {
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
        for (Operation *subviewUser :
            llvm::make_early_inc_range(subviewOp->getUsers()))
          rewriter.eraseOp(subviewUser);
      if (op->use_empty())
        rewriter.eraseOp(op);
    }
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};

/// Folds a pim.memcopy from a constant global (optionally through a static,
/// unit-stride subview) into an alloc: the alloc becomes a precomputed
/// memref.global and the copy is erased.
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      pim::PimMemCopyOp copyOp, PatternRewriter &rewriter) const override {
    // Skip copies nested inside a core region.
    // TODO(review): parent op type reconstructed — confirm against dialect.
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();
    auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
    if (!allocOp)
      return failure();
    auto allocType = dyn_cast<MemRefType>(allocOp.getType());
    if (!allocType || !allocType.hasStaticShape())
      return failure();
    // Only whole-buffer copies (zero offsets on both sides) are folded.
    if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
      return failure();

    // Resolve the source to a constant global, looking through casts and
    // through one static subview.
    auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
    Value globalSource = succeeded(srcSubview)
        ? srcSubview->source
        : stripMemRefCasts(copyOp.getSrc());
    auto moduleOp = copyOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
    if (failed(denseAttr))
      return failure();

    DenseElementsAttr foldedAttr;
    if (succeeded(srcSubview)) {
      // Extract the subview slice from the constant at compile time.
      auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
      if (!sourceType || !sourceType.hasStaticShape())
        return failure();
      // Non-unit strides are not handled by the slice extraction below.
      if (llvm::any_of(srcSubview->strides,
              [](int64_t stride) { return stride != 1; }))
        return failure();
      auto resultTensorType = RankedTensorType::get(
          allocType.getShape(), allocType.getElementType());
      const int64_t numResultElements = resultTensorType.getNumElements();
      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
      SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
      SmallVector<Attribute> resultValues(numResultElements);
      for (int64_t i = 0; i < numResultElements; ++i) {
        auto resultIndices =
            delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
        SmallVector<int64_t> sourceIndices;
        sourceIndices.reserve(resultIndices.size());
        for (auto [off, idx] :
            llvm::zip_equal(srcSubview->offsets, resultIndices))
          sourceIndices.push_back(off + idx);
        int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
        resultValues[i] = sourceValues[srcLinear];
      }
      foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
    } else {
      // Whole-global copy: types must match exactly.
      auto resultTensorType = RankedTensorType::get(
          allocType.getShape(), allocType.getElementType());
      if (resultTensorType != denseAttr->getType())
        return failure();
      foldedAttr = *denseAttr;
    }

    // Check remaining users of the alloc besides the copy being folded.
    // TODO(review): the three isa<> lists here were lost in the mangled copy;
    // reconstructed by convention — confirm which user kinds (a) keep the
    // value core-only, and (b) demote allLiveUsersAreCores.
    bool allLiveUsersAreCores = true;
    for (Operation *user : allocOp->getUsers()) {
      if (user == copyOp)
        continue;
      if (isa<pim::PimCoreOp>(user))
        continue;
      if (isa<memref::DeallocOp>(user))
        continue;
      if (isa<pim::PimMemCopyOp>(user)) {
        allLiveUsersAreCores = false;
        continue;
      }
      return failure();
    }

    auto newGlobal = createFoldedGlobal(
        moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
    if (allLiveUsersAreCores)
      markWeightAlways(newGlobal);
    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(
        rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCores)
      markWeightAlways(newGetGlobal);
    rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};

} // namespace

/// Registers all PIM constant-folding patterns defined in this file.
void populateConstantFoldingConstantPatterns(RewritePatternSet &patterns) {
  patterns.add<FoldConstantCoreMapPattern, FoldConstantTransposePattern,
      FoldConstantAllocPattern, FoldConstantMemCpPattern>(
      patterns.getContext());
}

} // namespace onnx_mlir