//===----------------------------------------------------------------------===//
// PimConstantFoldingPass
//
// Folds host-side constant expressions into private constant memref.globals
// before PIM verification: constant linalg.map fills, pim.transpose of
// constant globals, subviews/copies of constant globals into allocs, and
// splitting of non-contiguous subview copies into per-slice contiguous
// pim memcopy ops.
//
// NOTE(review): the original text of this file had every angle-bracketed
// template argument stripped (markup-filter damage).  Standard MLIR/LLVM
// template arguments below are reconstructed with high confidence from the
// surrounding API usage; project-specific op names that cannot be recovered
// from this file alone are marked TODO(confirm).
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

// Reconstructed: the original had a bare `#include` here with the header name
// stripped; <memory> is required for std::shared_ptr/std::make_shared/
// std::make_unique used below.  TODO(confirm) against version control.
#include <memory>
#include <string>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Walks through any chain of memref.cast ops and returns the underlying
/// source value.
static Value stripMemRefCasts(Value value) {
  while (auto castOp = value.getDefiningOp<memref::CastOp>())
    value = castOp.getSource();
  return value;
}

/// Walks through any chain of memref.cast / collapse_shape / expand_shape
/// ops and returns the underlying source value.
static Value stripMemRefViewOps(Value value) {
  while (true) {
    if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
      value = castOp.getSource();
      continue;
    }
    if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
      value = collapseOp.getSrc();
      continue;
    }
    if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
      value = expandOp.getSrc();
      continue;
    }
    return value;
  }
}

/// Creates a new private constant memref.global at the start of the module
/// holding `denseAttr`.  The symbol name is `nameStem`, uniqued with a
/// numeric suffix if that symbol already exists.
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, Location loc,
    MemRefType globalType, DenseElementsAttr denseAttr, StringRef nameStem,
    IntegerAttr alignment = {}) {
  auto globalName = nameStem.str();
  unsigned suffix = 0;
  // Uniquify against existing symbols in the module.
  while (moduleOp.lookupSymbol(globalName))
    globalName = (nameStem + "_" + std::to_string(++suffix)).str();
  auto visibility = StringAttr::get(moduleOp.getContext(), "private");
  OpBuilder moduleBuilder(moduleOp.getBodyRegion());
  moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
  return memref::GlobalOp::create(moduleBuilder, loc, globalName, visibility,
      globalType, denseAttr, /*constant=*/true, alignment);
}

/// If `value` (through memref.cast chains) is produced by a
/// memref.get_global of a constant, initialized global, returns its dense
/// initializer; otherwise failure.
static FailureOr<DenseElementsAttr> getDenseGlobalValue(
    ModuleOp moduleOp, Value value) {
  value = stripMemRefCasts(value);
  auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return failure();
  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
  if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
    return failure();
  auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
  if (!denseAttr)
    return failure();
  return denseAttr;
}

/// Returns `denseAttr` transposed by `perms` (a permutation of the tensor's
/// dimensions), or failure if the attribute is not a ranked tensor or
/// `perms` is not a valid permutation of its rank.
static FailureOr<DenseElementsAttr> transposeDenseElements(
    DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!tensorType)
    return failure();
  int64_t rank = tensorType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();

  // Validate the permutation while computing the transposed shape.
  llvm::SmallBitVector seen(rank);
  SmallVector<int64_t> transposedShape;
  transposedShape.reserve(rank);
  for (int64_t perm : perms) {
    if (perm < 0 || perm >= rank || seen.test(perm))
      return failure();
    seen.set(perm);
    transposedShape.push_back(tensorType.getShape()[perm]);
  }
  auto transposedType =
      RankedTensorType::get(transposedShape, tensorType.getElementType());

  // A splat is invariant under transposition.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(
        transposedType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> transposedValues(originalValues.size());

  // Row-major strides for source and destination layouts.
  SmallVector<int64_t> originalStrides(rank, 1);
  SmallVector<int64_t> transposedStrides(rank, 1);
  for (int64_t dim = rank - 2; dim >= 0; --dim) {
    originalStrides[dim] =
        originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
    transposedStrides[dim] =
        transposedStrides[dim + 1] * transposedShape[dim + 1];
  }

  // Scatter each source element to its transposed position.
  SmallVector<int64_t> originalIndices(rank);
  SmallVector<int64_t> transposedIndices(rank);
  for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
    int64_t remaining = static_cast<int64_t>(linearIndex);
    for (int64_t dim = 0; dim < rank; ++dim) {
      originalIndices[dim] = remaining / originalStrides[dim];
      remaining %= originalStrides[dim];
    }
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedIndices[dim] = originalIndices[perms[dim]];
    int64_t transposedLinearIndex = 0;
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
    transposedValues[transposedLinearIndex] = value;
  }
  return DenseElementsAttr::get(transposedType, transposedValues);
}

/// A constant copied into a subview of an alloc: the dense source values,
/// the subview's static offsets/strides, and the copy op (used for ordering
/// overlapping copies by block position).
struct ConstantSubviewCopy {
  DenseElementsAttr source;
  SmallVector<int64_t> offsets;
  SmallVector<int64_t> strides;
  Operation *copyOp = nullptr;
};

/// If `mapOp` has no inputs and its body yields a single compile-time
/// constant (i.e. it is a constant fill), returns that constant attribute.
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
  if (!mapOp.getInputs().empty())
    return failure();
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return failure();
  Attribute attr;
  if (!matchPattern(yieldOp.getValues().front(), m_Constant(&attr)))
    return failure();
  return attr;
}

/// Replaces a constant-fill linalg.map inside a pim.core with a
/// host-to-device copy from a new splat constant global.
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      linalg::MapOp mapOp, PatternRewriter &rewriter) const override {
    // Only fire inside a pim core region.
    // TODO(confirm): parent op name reconstructed as pim::PimCoreOp.
    auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
    if (!coreOp)
      return failure();
    auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
    if (!initType || !initType.hasStaticShape())
      return failure();
    auto fillValue = getConstantMapYield(mapOp);
    if (failed(fillValue))
      return failure();

    // Materialize the fill as a splat constant global.
    auto tensorType =
        RankedTensorType::get(initType.getShape(), initType.getElementType());
    DenseElementsAttr splatAttr =
        DenseElementsAttr::get(tensorType, *fillValue);
    auto moduleOp = mapOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto globalOp = createFoldedGlobal(
        moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");

    // get_global must live outside the core region.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(coreOp);
    auto getGlobalOp = memref::GetGlobalOp::create(
        rewriter, mapOp.getLoc(), initType, globalOp.getName());

    // Bail out on sub-byte element types rather than emitting a zero-size
    // copy.
    size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
    if (elementByteWidth == 0)
      return failure();
    size_t totalBytes = initType.getNumElements() * elementByteWidth;

    rewriter.setInsertionPoint(mapOp);
    pim::PimMemCopyHostToDevOp::create(rewriter, mapOp.getLoc(), initType,
        mapOp.getInit(), getGlobalOp.getResult(),
        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
    rewriter.eraseOp(mapOp);
    return success();
  }
};

/// Fully static description of a memref.subview: its (cast-stripped)
/// source, source shape, and constant offsets/sizes/strides.
struct StaticSubviewInfo {
  Value source;
  SmallVector<int64_t> sourceShape;
  SmallVector<int64_t> offsets;
  SmallVector<int64_t> sizes;
  SmallVector<int64_t> strides;
};

/// Resolves `value` (through view/cast ops) to a memref.subview with fully
/// static source shape, result shape, offsets, sizes, and strides.
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  value = stripMemRefViewOps(value);
  auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
  if (!subviewOp)
    return failure();
  auto source = stripMemRefCasts(subviewOp.getSource());
  auto sourceType = dyn_cast<MemRefType>(source.getType());
  auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
  if (!sourceType || !subviewType || !sourceType.hasStaticShape() ||
      !subviewType.hasStaticShape())
    return failure();

  StaticSubviewInfo info;
  info.source = source;
  info.sourceShape.assign(
      sourceType.getShape().begin(), sourceType.getShape().end());
  for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
    auto staticOffset = getConstantIntValue(offset);
    if (!staticOffset)
      return failure();
    info.offsets.push_back(*staticOffset);
  }
  for (OpFoldResult size : subviewOp.getMixedSizes()) {
    auto staticSize = getConstantIntValue(size);
    if (!staticSize)
      return failure();
    info.sizes.push_back(*staticSize);
  }
  for (OpFoldResult stride : subviewOp.getMixedStrides()) {
    auto staticStride = getConstantIntValue(stride);
    if (!staticStride)
      return failure();
    info.strides.push_back(*staticStride);
  }
  return info;
}

/// Byte offset (within the subview's source buffer, row-major) of the
/// contiguous innermost-dimension slice selected by `outerIndices`.
static int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo &info,
    ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
  SmallVector<int64_t> sourceIndices;
  sourceIndices.reserve(info.sourceShape.size());
  for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
    sourceIndices.push_back(
        info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
  // Innermost dimension starts at its subview offset (stride is 1 for the
  // callers of this helper).
  sourceIndices.push_back(info.offsets.back());
  return linearizeIndex(sourceIndices,
             computeRowMajorStrides(info.sourceShape)) *
         elementByteWidth;
}

/// Splits a pim.memcopy inside a core whose src and/or dst is a
/// non-contiguous static subview into one contiguous copy per innermost
/// slice.
struct RewriteCoreSubviewCopyPattern final
    : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      pim::PimMemCopyOp copyOp, PatternRewriter &rewriter) const override {
    // TODO(confirm): parent op name reconstructed as pim::PimCoreOp.
    if (!copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
    auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
    const bool splitSrc = succeeded(srcSubview) &&
        !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets,
            srcSubview->sizes, srcSubview->strides);
    const bool splitDst = succeeded(dstSubview) &&
        !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets,
            dstSubview->sizes, dstSubview->strides);
    // Nothing to do if both sides are already contiguous.
    if (!splitSrc && !splitDst)
      return failure();

    auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
    auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
    if (!sourceType || !dstType || !sourceType.hasStaticShape() ||
        !dstType.hasStaticShape())
      return failure();
    if (sourceType.getElementType() != dstType.getElementType())
      return failure();
    // Only unit strides are supported for the split sides.
    if (splitSrc && llvm::any_of(srcSubview->strides,
                        [](int64_t stride) { return stride != 1; }))
      return failure();
    if (splitDst && llvm::any_of(dstSubview->strides,
                        [](int64_t stride) { return stride != 1; }))
      return failure();

    ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes)
                                           : ArrayRef<int64_t>(dstSubview->sizes);
    if (splitSrc && splitDst &&
        copyShape != ArrayRef<int64_t>(dstSubview->sizes))
      return failure();

    const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
    if (elementByteWidth <= 0)
      return failure();
    // The original copy must cover exactly the whole subview.
    const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
    if (copyOp.getSize() != totalBytes)
      return failure();
    const int64_t sliceBytes = copyShape.back() * elementByteWidth;
    if (sliceBytes <= 0)
      return failure();

    // Iterate over all outer (non-innermost) index tuples.
    SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
    auto outerStrides = computeRowMajorStrides(outerShape);
    const int64_t numSlices =
        outerShape.empty() ? 1 : getNumElements(outerShape);
    rewriter.setInsertionPoint(copyOp);
    for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
      SmallVector<int64_t> outerIndices = outerShape.empty()
          ? SmallVector<int64_t>{}
          : delinearizeIndex(linearIndex, outerShape, outerStrides);
      // The non-split side advances linearly; the split side follows the
      // subview's layout in its source buffer.
      const int64_t srcByteOffset = copyOp.getSrcOffset() +
          (splitSrc ? getSubviewChunkOffsetBytes(
                          *srcSubview, outerIndices, elementByteWidth)
                    : linearIndex * sliceBytes);
      const int64_t dstByteOffset = copyOp.getDstOffset() +
          (splitDst ? getSubviewChunkOffsetBytes(
                          *dstSubview, outerIndices, elementByteWidth)
                    : linearIndex * sliceBytes);
      pim::PimMemCopyOp::create(rewriter, copyOp.getLoc(),
          splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
          splitDst ? dstSubview->source : copyOp.getDst(),
          splitSrc ? srcSubview->source : copyOp.getSrc(),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
    }
    rewriter.replaceOp(copyOp, copyOp.getDst());
    return success();
  }
};

/// Same slice-splitting as RewriteCoreSubviewCopyPattern, but for
/// host-to-device copies (pim.memcopy_host_to_dev) anywhere in the module.
struct RewriteHostSubviewLoadPattern final
    : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp,
      PatternRewriter &rewriter) const override {
    auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
    auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
    const bool splitSrc = succeeded(srcSubview) &&
        !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets,
            srcSubview->sizes, srcSubview->strides);
    const bool splitDst = succeeded(dstSubview) &&
        !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets,
            dstSubview->sizes, dstSubview->strides);
    if (!splitSrc && !splitDst)
      return failure();

    auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
    auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
    if (!sourceType || !dstType || !sourceType.hasStaticShape() ||
        !dstType.hasStaticShape())
      return failure();
    if (sourceType.getElementType() != dstType.getElementType())
      return failure();
    if (splitSrc && llvm::any_of(srcSubview->strides,
                        [](int64_t stride) { return stride != 1; }))
      return failure();
    if (splitDst && llvm::any_of(dstSubview->strides,
                        [](int64_t stride) { return stride != 1; }))
      return failure();

    ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes)
                                           : ArrayRef<int64_t>(dstSubview->sizes);
    if (splitSrc && splitDst &&
        copyShape != ArrayRef<int64_t>(dstSubview->sizes))
      return failure();

    const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
    if (elementByteWidth <= 0)
      return failure();
    const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
    if (copyOp.getSize() != totalBytes)
      return failure();
    const int64_t sliceBytes = copyShape.back() * elementByteWidth;
    if (sliceBytes <= 0)
      return failure();

    SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
    auto outerStrides = computeRowMajorStrides(outerShape);
    const int64_t numSlices =
        outerShape.empty() ? 1 : getNumElements(outerShape);
    rewriter.setInsertionPoint(copyOp);
    for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
      SmallVector<int64_t> outerIndices = outerShape.empty()
          ? SmallVector<int64_t>{}
          : delinearizeIndex(linearIndex, outerShape, outerStrides);
      const int64_t srcByteOffset = copyOp.getHostSrcOffset() +
          (splitSrc ? getSubviewChunkOffsetBytes(
                          *srcSubview, outerIndices, elementByteWidth)
                    : linearIndex * sliceBytes);
      const int64_t dstByteOffset = copyOp.getDeviceDstOffset() +
          (splitDst ? getSubviewChunkOffsetBytes(
                          *dstSubview, outerIndices, elementByteWidth)
                    : linearIndex * sliceBytes);
      pim::PimMemCopyHostToDevOp::create(rewriter, copyOp.getLoc(),
          splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
          splitDst ? dstSubview->source : copyOp.getDeviceDst(),
          splitSrc ? srcSubview->source : copyOp.getHostSrc(),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
    }
    rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
    return success();
  }
};

/// Computes the full constant contents of `allocOp` if the alloc is written
/// only by (a) constant-fill linalg.map ops and (b) memref.copy ops from
/// constant globals into static subviews.  Copies are applied in block
/// order over the fill value.
static FailureOr<DenseElementsAttr> foldConstantAlloc(
    memref::AllocOp allocOp, ModuleOp moduleOp) {
  auto allocType = dyn_cast<MemRefType>(allocOp.getType());
  if (!allocType || !allocType.hasStaticShape())
    return failure();
  auto resultTensorType =
      RankedTensorType::get(allocType.getShape(), allocType.getElementType());
  const int64_t numElements = resultTensorType.getNumElements();
  if (numElements < 0)
    return failure();

  Attribute fillValue;
  SmallVector<ConstantSubviewCopy> copies;
  // Traverse the alloc and all cast-aliases, classifying every user.
  llvm::SmallPtrSet<Operation *, 8> visitedAliases;
  SmallVector<Value> pendingAliases;
  pendingAliases.push_back(allocOp.getResult());
  while (!pendingAliases.empty()) {
    Value alias = pendingAliases.pop_back_val();
    for (Operation *user : alias.getUsers()) {
      if (!visitedAliases.insert(user).second)
        continue;
      if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
        // Must be a constant fill of the whole alloc; all fills must agree.
        if (mapOp.getInit() != alias)
          return failure();
        auto maybeFillValue = getConstantMapYield(mapOp);
        if (failed(maybeFillValue))
          return failure();
        if (fillValue && fillValue != *maybeFillValue)
          return failure();
        fillValue = *maybeFillValue;
        continue;
      }
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
        // A subview is only acceptable as the target of constant copies.
        SmallVector<int64_t> offsets;
        SmallVector<int64_t> strides;
        offsets.reserve(subviewOp.getMixedOffsets().size());
        strides.reserve(subviewOp.getMixedStrides().size());
        for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
          auto staticOffset = getConstantIntValue(offset);
          if (!staticOffset)
            return failure();
          offsets.push_back(*staticOffset);
        }
        for (OpFoldResult stride : subviewOp.getMixedStrides()) {
          auto staticStride = getConstantIntValue(stride);
          if (!staticStride)
            return failure();
          strides.push_back(*staticStride);
        }
        for (Operation *subviewUser : subviewOp->getUsers()) {
          if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
            if (copyOp.getTarget() != subviewOp.getResult())
              return failure();
            auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
            if (failed(denseAttr))
              return failure();
            copies.push_back({*denseAttr, offsets, strides, copyOp});
            continue;
          }
          return failure();
        }
        continue;
      }
      // TODO(confirm): ignorable user reconstructed as memref.dealloc.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        pendingAliases.push_back(castOp.getResult());
        continue;
      }
      // Any other user means the contents are not fully known.
      return failure();
    }
  }
  if (!fillValue)
    return failure();

  // Start from the fill value, then overlay the copies in block order so
  // later copies win on overlap.
  SmallVector<Attribute> resultValues(numElements, fillValue);
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
  llvm::sort(copies,
      [](const ConstantSubviewCopy &lhs, const ConstantSubviewCopy &rhs) {
        return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
      });
  for (const ConstantSubviewCopy &copy : copies) {
    auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size()) ||
        sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
      return failure();
    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
    for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
      SmallVector<int64_t> sourceIndices =
          delinearizeIndex(static_cast<int64_t>(linearIndex),
              sourceType.getShape(), sourceStrides);
      SmallVector<int64_t> resultIndices;
      resultIndices.reserve(sourceIndices.size());
      for (auto [offset, sourceIndex, stride] :
          llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
        resultIndices.push_back(offset + sourceIndex * stride);
      int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
      resultValues[resultLinearIndex] = value;
    }
  }
  return DenseElementsAttr::get(resultTensorType, resultValues);
}

/// Folds pim.transpose of a constant global into a new pre-transposed
/// constant global.
struct FoldConstantTransposePattern final
    : OpRewritePattern<pim::PimTransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp,
      PatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();
    auto sourceGetGlobal =
        transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
    if (!sourceGetGlobal)
      return failure();
    auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
    if (!sourceGlobal || !sourceGlobal.getConstant() ||
        !sourceGlobal.getInitialValue())
      return failure();
    auto denseAttr =
        dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
    if (!denseAttr)
      return failure();

    SmallVector<int64_t> perms;
    perms.reserve(transposeOp.getPerms().size());
    for (IntegerAttr attr :
        transposeOp.getPerms().getAsRange<IntegerAttr>())
      perms.push_back(attr.getInt());
    FailureOr<DenseElementsAttr> transposedAttr =
        transposeDenseElements(denseAttr, perms);
    if (failed(transposedAttr))
      return failure();
    // Sanity: the folded shape must match the op's declared result shape.
    auto transposedShape =
        cast<RankedTensorType>(transposedAttr->getType()).getShape();
    if (!llvm::equal(transposedShape, resultType.getShape()))
      return failure();

    MemRefType globalType = resultType;
    auto newGlobal = createFoldedGlobal(moduleOp, transposeOp.getLoc(),
        globalType, *transposedAttr,
        sourceGlobal.getName().str() + "__folded_transpose",
        sourceGlobal.getAlignmentAttr());
    rewriter.setInsertionPoint(transposeOp);
    auto newGetGlobal = memref::GetGlobalOp::create(
        rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());

    // Tag as an always-resident weight when every consumer is a core op.
    // TODO(confirm): user op reconstructed as pim::PimCoreOp.
    bool isAlwaysWeight = !transposeOp->getUsers().empty() &&
        llvm::all_of(transposeOp->getUsers(),
            [](Operation *user) { return isa<pim::PimCoreOp>(user); });
    if (isAlwaysWeight) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }
    rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
    return success();
  }
};

/// Replaces an alloc whose contents are fully determined by constant fills
/// and constant subview copies (see foldConstantAlloc) with a
/// memref.get_global of a new constant global.
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      memref::AllocOp allocOp, PatternRewriter &rewriter) const override {
    auto moduleOp = allocOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
    if (failed(foldedAttr))
      return failure();
    auto allocType = cast<MemRefType>(allocOp.getType());
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType,
        *foldedAttr, "pim_folded_constant");
    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(
        rewriter, allocOp.getLoc(), allocType, newGlobal.getName());

    // Classify users: writers become dead and are erased; casts are rebuilt
    // on top of the new get_global; core/dealloc users are redirected.
    // TODO(confirm): the op lists below were stripped from the original and
    // are best-effort reconstructions.
    SmallVector<Operation *> opsToErase;
    SmallVector<memref::CastOp> castsToReplace;
    bool allLiveUsersAreCoreOps = true;
    for (Operation *user : llvm::make_early_inc_range(allocOp->getUsers())) {
      if (isa<linalg::MapOp, memref::SubViewOp>(user)) {
        opsToErase.push_back(user);
        continue;
      }
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        castsToReplace.push_back(castOp);
        continue;
      }
      if (!isa<pim::PimCoreOp, memref::DeallocOp>(user))
        return failure();
    }
    if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
          return llvm::all_of(castOp->getUsers(), [](Operation *user) {
            return isa<pim::PimCoreOp>(user);
          });
        })) {
      allLiveUsersAreCoreOps = false;
    }
    if (!llvm::all_of(allocOp->getUsers(), [](Operation *user) {
          return isa<linalg::MapOp, memref::SubViewOp, memref::CastOp,
              pim::PimCoreOp, memref::DeallocOp>(user);
        })) {
      return failure();
    }
    if (allLiveUsersAreCoreOps) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }

    // Redirect every use except the writers (erased below) and the casts
    // (replaced with casts of the new get_global).
    llvm::SmallPtrSet<Operation *, 8> preservedUsers(
        opsToErase.begin(), opsToErase.end());
    for (memref::CastOp castOp : castsToReplace)
      preservedUsers.insert(castOp);
    rewriter.replaceAllUsesExcept(
        allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);
    for (memref::CastOp castOp : castsToReplace) {
      rewriter.setInsertionPoint(castOp);
      Value replacementCast = memref::CastOp::create(
          rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
      rewriter.replaceOp(castOp, replacementCast);
      if (allLiveUsersAreCoreOps)
        markWeightAlways(replacementCast.getDefiningOp());
    }
    // Erase dead writers (and, for subviews, the copies feeding them first).
    for (Operation *op : llvm::make_early_inc_range(opsToErase)) {
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
        for (Operation *subviewUser :
            llvm::make_early_inc_range(subviewOp->getUsers()))
          rewriter.eraseOp(subviewUser);
      if (op->use_empty())
        rewriter.eraseOp(op);
    }
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};

/// Folds a top-level pim.memcopy from a constant global (optionally through
/// a static unit-stride subview) into an alloc, replacing the alloc with a
/// get_global of a new constant global.
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      pim::PimMemCopyOp copyOp, PatternRewriter &rewriter) const override {
    // Only match top-level memcp (not inside pim.core).
    // TODO(confirm): parent op name reconstructed as pim::PimCoreOp.
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();
    // dst must be an alloc with static shape.
    auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
    if (!allocOp)
      return failure();
    auto allocType = dyn_cast<MemRefType>(allocOp.getType());
    if (!allocType || !allocType.hasStaticShape())
      return failure();
    // The copy must cover the full destination (offsets both zero).
    if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
      return failure();

    // Resolve the source through an optional subview to a get_global.
    auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
    Value globalSource = succeeded(srcSubview)
        ? srcSubview->source
        : stripMemRefCasts(copyOp.getSrc());
    auto moduleOp = copyOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
    if (failed(denseAttr))
      return failure();

    // Build the folded dense attribute.
    DenseElementsAttr foldedAttr;
    if (succeeded(srcSubview)) {
      // Extract the sub-tensor from the source constant.
      auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
      if (!sourceType || !sourceType.hasStaticShape())
        return failure();
      if (llvm::any_of(
              srcSubview->strides, [](int64_t s) { return s != 1; }))
        return failure();
      auto resultTensorType = RankedTensorType::get(
          allocType.getShape(), allocType.getElementType());
      const int64_t numResultElements = resultTensorType.getNumElements();
      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      auto resultStrides =
          computeRowMajorStrides(resultTensorType.getShape());
      SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
      SmallVector<Attribute> resultValues(numResultElements);
      for (int64_t i = 0; i < numResultElements; ++i) {
        auto resultIndices =
            delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
        SmallVector<int64_t> sourceIndices;
        sourceIndices.reserve(resultIndices.size());
        for (auto [off, idx] :
            llvm::zip_equal(srcSubview->offsets, resultIndices))
          sourceIndices.push_back(off + idx);
        int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
        resultValues[i] = sourceValues[srcLinear];
      }
      foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
    } else {
      // Direct copy from a global — just reuse its dense attribute.
      auto resultTensorType = RankedTensorType::get(
          allocType.getShape(), allocType.getElementType());
      if (resultTensorType != denseAttr->getType())
        return failure();
      foldedAttr = *denseAttr;
    }

    // Verify that the alloc's remaining users are supported ops.
    // TODO(confirm): the op lists below were stripped from the original and
    // are best-effort reconstructions.
    bool allLiveUsersAreCores = true;
    for (Operation *user : allocOp->getUsers()) {
      if (user == copyOp)
        continue;
      if (isa<pim::PimCoreOp>(user))
        continue;
      if (isa<memref::DeallocOp>(user))
        continue;
      if (isa<memref::SubViewOp, pim::PimMemCopyOp>(user)) {
        allLiveUsersAreCores = false;
        continue;
      }
      return failure();
    }

    auto newGlobal = createFoldedGlobal(
        moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
    if (allLiveUsersAreCores)
      markWeightAlways(newGlobal);
    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(
        rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCores)
      markWeightAlways(newGetGlobal);
    rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};

/// Folds a unit-stride static subview of a constant global, used only by
/// pim core ops, into a new contiguous constant global.
struct FoldConstantCoreSubviewPattern final
    : OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(
      memref::SubViewOp subviewOp, PatternRewriter &rewriter) const override {
    // Only handle subviews whose users are all pim.core ops.
    // TODO(confirm): user op reconstructed as pim::PimCoreOp.
    if (subviewOp.use_empty())
      return failure();
    if (!llvm::all_of(subviewOp->getUsers(),
            [](Operation *user) { return isa<pim::PimCoreOp>(user); }))
      return failure();

    // Source must resolve to a constant get_global.
    auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto denseAttr = getDenseGlobalValue(
        moduleOp, stripMemRefCasts(subviewOp.getSource()));
    if (failed(denseAttr))
      return failure();

    // Static subview info, unit strides only.
    auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
    if (failed(subviewInfo))
      return failure();
    if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
      return failure();
    auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();

    // Build the contiguous result type.
    auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
    auto resultMemRefType = MemRefType::get(
        SmallVector<int64_t>(
            subviewInfo->sizes.begin(), subviewInfo->sizes.end()),
        elementType);
    auto resultTensorType =
        RankedTensorType::get(resultMemRefType.getShape(), elementType);
    const int64_t numResultElements = resultTensorType.getNumElements();

    // Extract the sub-tensor.
    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
    SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
    SmallVector<Attribute> resultValues(numResultElements);
    for (int64_t i = 0; i < numResultElements; ++i) {
      auto resultIndices =
          delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
      SmallVector<int64_t> sourceIndices;
      sourceIndices.reserve(resultIndices.size());
      for (auto [off, idx] :
          llvm::zip_equal(subviewInfo->offsets, resultIndices))
        sourceIndices.push_back(off + idx);
      resultValues[i] =
          sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
    }
    auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
    auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(),
        resultMemRefType, foldedAttr, "pim_folded_subview");
    markWeightAlways(newGlobal);
    rewriter.setInsertionPoint(subviewOp);
    auto newGetGlobal = memref::GetGlobalOp::create(
        rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
    markWeightAlways(newGetGlobal);
    rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
    return success();
  }
};

/// Module pass driving all folding patterns plus the loaded dialects'
/// canonicalization patterns through the greedy rewriter.
struct PimConstantFoldingPass
    : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)

  StringRef getArgument() const override {
    return "pim-constant-folding-pass";
  }
  StringRef getDescription() const override {
    return "Fold host-side constant expressions before PIM verification";
  }

  LogicalResult initialize(MLIRContext *context) override {
    RewritePatternSet owningPatterns(context);
    // Include canonicalization so folded IR is cleaned up in the same run.
    for (auto *dialect : context->getLoadedDialects())
      dialect->getCanonicalizationPatterns(owningPatterns);
    for (RegisteredOperationName op : context->getRegisteredOperations())
      op.getCanonicalizationPatterns(owningPatterns, context);
    owningPatterns.add<FoldConstantCoreMapPattern,
        RewriteCoreSubviewCopyPattern, RewriteHostSubviewLoadPattern,
        FoldConstantTransposePattern, FoldConstantAllocPattern,
        FoldConstantMemCpPattern, FoldConstantCoreSubviewPattern>(context);
    patterns = std::make_shared<FrozenRewritePatternSet>(
        std::move(owningPatterns));
    return success();
  }

  void runOnOperation() override {
    GreedyRewriteConfig config;
    config.enableFolding();
    if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
      signalPassFailure();
      return;
    }
    dumpModule(getOperation(), "pim2_folded");
  }

  // Frozen once in initialize(); shared so pass clones reuse the set.
  std::shared_ptr<FrozenRewritePatternSet> patterns;
};

} // namespace

std::unique_ptr<mlir::Pass> createPimConstantFoldingPass() {
  return std::make_unique<PimConstantFoldingPass>();
}

} // namespace onnx_mlir