#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Splits a byte-range copy whose source and/or destination is a
/// non-contiguous static subview into one copy per innermost row.
///
/// A side is "split" when getStaticSubviewInfo succeeds on it and
/// isMemoryContiguous reports that the subview is not contiguous within its
/// source buffer. For a split side, every emitted copy addresses the
/// subview's source buffer directly, at the offset returned by
/// getSubviewChunkOffsetBytes. A non-split side is walked linearly, one
/// slice after another.
///
/// @param copyOp       Op being rewritten; new copies are inserted before it.
/// @param dst          Destination value of the original copy.
/// @param src          Source value of the original copy.
/// @param dstOffset    Byte offset added to every emitted destination offset.
/// @param srcOffset    Byte offset added to every emitted source offset.
/// @param size         Byte size of the original copy. It must equal the
///                     whole copy shape's footprint, or the rewrite is
///                     rejected.
/// @param rewriter     Pattern rewriter; its insertion point is moved to
///                     `copyOp`.
/// @param createCopyOp Callback invoked once per slice with
///                     (resultType, dst, src, dstByteOffset, srcByteOffset,
///                     sliceBytes).
/// @return failure(), without modifying the IR, in any of these cases:
///         - neither side needs splitting;
///         - a type fails the cast or is not statically shaped;
///         - element types differ;
///         - a split side has a non-unit stride;
///         - two split sides have different shapes;
///         - the element type is sub-byte;
///         - `size` does not match the copy footprint.
///         Otherwise returns success() once all slice copies are emitted.
///         `copyOp` itself is left in place for the caller to replace.
template static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, Value dst, Value src, int64_t dstOffset,
                                                       int64_t srcOffset, int64_t size, PatternRewriter& rewriter,
                                                       CreateCopyOp createCopyOp) {
  auto srcSubview = getStaticSubviewInfo(src);
  auto dstSubview = getStaticSubviewInfo(dst);
  // A side needs splitting only when it is a static subview whose elements
  // are not contiguous in its source buffer.
  const bool splitSrc = succeeded(srcSubview) && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets,
                                                                      srcSubview->sizes, srcSubview->strides);
  const bool splitDst = succeeded(dstSubview) && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets,
                                                                      dstSubview->sizes, dstSubview->strides);
  // Neither side needs splitting: leave the copy alone.
  if (!splitSrc && !splitDst)
    return failure();

  auto sourceType = dyn_cast(src.getType());
  auto dstType = dyn_cast(dst.getType());
  if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();
  if (sourceType.getElementType() != dstType.getElementType())
    return failure();
  // Only unit-stride subviews are split, so that each innermost row of the
  // copy shape can be moved as a single run of bytes.
  if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
    return failure();
  if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
    return failure();

  // Iterate over the split side's subview sizes. If both sides are split,
  // their shapes must match exactly.
  ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes);
  if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes))
    return failure();

  // Sub-byte element types give 0 here and are rejected.
  // NOTE(review): bit widths that are not a multiple of 8 are truncated by
  // this integer division; confirm such element types cannot reach here.
  const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
  if (elementByteWidth <= 0)
    return failure();
  // Only copies that cover the whole copy shape are rewritten.
  const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
  if (size != totalBytes)
    return failure();
  // Each emitted copy moves one innermost row.
  // NOTE(review): copyShape.back() assumes a non-empty shape. Presumably a
  // rank-0 subview is always reported contiguous and never split; confirm.
  const int64_t sliceBytes = copyShape.back() * elementByteWidth;
  if (sliceBytes <= 0)
    return failure();

  // Every dimension except the innermost; each outer index selects one row.
  SmallVector outerShape(copyShape.begin(), copyShape.end() - 1);
  auto outerStrides = computeRowMajorStrides(outerShape);
  // A rank-1 copy has no outer dimensions and yields exactly one slice.
  const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);

  // All match checks are done; nothing below returns failure(), so the IR is
  // only modified once the rewrite is certain to apply.
  rewriter.setInsertionPoint(copyOp);
  for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
    SmallVector outerIndices =
        outerShape.empty() ? SmallVector{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
    // A split side is addressed inside its subview's source buffer. A
    // contiguous side advances by one slice per iteration.
    const int64_t srcByteOffset =
        srcOffset +
        (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
    const int64_t dstByteOffset =
        dstOffset +
        (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
    // A split destination is written through its subview's source buffer,
    // so the emitted op's result type is that buffer's type.
    createCopyOp(splitDst ? cast(dstSubview->source.getType()) : dstType,
                 splitDst ? dstSubview->source : dst,
                 splitSrc ? srcSubview->source : src,
                 dstByteOffset,
                 srcByteOffset,
                 sliceBytes);
  }
  return success();
}

/// Splits a pim::PimMemCopyOp whose src and/or dst is a non-contiguous,
/// unit-stride static subview into one pim::PimMemCopyOp per innermost row,
/// using rewriteSubviewCopyLikeOp. On success, the original op is replaced
/// by its dst operand.
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only copies nested inside a particular enclosing op are rewritten.
    // NOTE(review): confirm which op type this getParentOfType check targets
    // (presumably the PIM core op, given the pattern name).
    if (!copyOp->getParentOfType())
      return failure();

    auto status = rewriteSubviewCopyLikeOp(
        copyOp, copyOp.getDst(), copyOp.getSrc(), copyOp.getDstOffset(), copyOp.getSrcOffset(), copyOp.getSize(),
        rewriter,
        [&](MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset,
            int64_t sliceBytes) {
          // NOTE(review): offsets and sizes are stored as 32-bit integer
          // attributes; values above INT32_MAX would be silently narrowed.
          pim::PimMemCopyOp::create(rewriter, copyOp.getLoc(), resultType, dst, src,
                                    rewriter.getI32IntegerAttr(static_cast(dstByteOffset)),
                                    rewriter.getI32IntegerAttr(static_cast(srcByteOffset)),
                                    rewriter.getI32IntegerAttr(static_cast(sliceBytes)));
        });
    if (failed(status))
      return failure();
    // Uses of the copy's result now refer to the destination buffer itself.
    rewriter.replaceOp(copyOp, copyOp.getDst());
    return success();
  }
};

/// Splits a pim::PimMemCopyHostToDevOp whose host source and/or device
/// destination is a non-contiguous, unit-stride static subview into one
/// pim::PimMemCopyHostToDevOp per innermost row, using
/// rewriteSubviewCopyLikeOp. Unlike RewriteCoreSubviewCopyPattern, no
/// enclosing-op check is applied. On success, the original op is replaced by
/// its device destination operand.
struct RewriteHostSubviewLoadPattern final : OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
    auto status = rewriteSubviewCopyLikeOp(
        copyOp, copyOp.getDeviceDst(), copyOp.getHostSrc(), copyOp.getDeviceDstOffset(), copyOp.getHostSrcOffset(),
        copyOp.getSize(), rewriter,
        [&](MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset,
            int64_t sliceBytes) {
          // NOTE(review): offsets and sizes are stored as 32-bit integer
          // attributes; values above INT32_MAX would be silently narrowed.
          pim::PimMemCopyHostToDevOp::create(rewriter, copyOp.getLoc(), resultType, dst, src,
                                             rewriter.getI32IntegerAttr(static_cast(dstByteOffset)),
                                             rewriter.getI32IntegerAttr(static_cast(srcByteOffset)),
                                             rewriter.getI32IntegerAttr(static_cast(sliceBytes)));
        });
    if (failed(status))
      return failure();
    // Uses of the copy's result now refer to the device destination buffer.
    rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
    return success();
  }
};

/// Folds a static, unit-stride memref.subview of a constant global into a
/// new global that holds only the selected elements. The subview is then
/// replaced with a memref.get_global of that new global.
///
/// The pattern matches only when all of the following hold:
///   - the subview has at least one use;
///   - every user passes the `isa` filter below;
///   - its source, after stripMemRefCasts, resolves through
///     getDenseGlobalValue to a statically shaped dense attribute.
/// Both the new global and the get_global op are tagged with
/// markWeightAlways.
struct FoldConstantCoreSubviewPattern final : OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
    if (subviewOp.use_empty())
      return failure();
    // NOTE(review): confirm which user op types this isa filter is meant to
    // accept.
    if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa(user); }))
      return failure();

    auto moduleOp = subviewOp->getParentOfType();
    if (!moduleOp)
      return failure();
    auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
    if (failed(denseAttr))
      return failure();

    auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
    if (failed(subviewInfo))
      return failure();
    // With unit strides, each result element sits at source index
    // (offset + index) in every dimension.
    if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
      return failure();

    auto sourceType = dyn_cast(denseAttr->getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();

    // The folded global has the default (identity) layout and uses the
    // subview sizes as its shape.
    // NOTE(review): if getStaticSubviewInfo reports one size per source
    // dimension, a rank-reducing subview would be replaced by a value of
    // higher rank than its result; confirm such subviews cannot reach here.
    auto elementType = cast(subviewOp.getType()).getElementType();
    auto resultMemRefType =
        MemRefType::get(SmallVector(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
    auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
    const int64_t numResultElements = resultTensorType.getNumElements();

    // NOTE(review): the source strides come from the dense attribute's
    // shape, while the subview offsets are relative to the subview's source
    // value. This assumes stripMemRefCasts only looks through
    // shape-preserving casts; confirm.
    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
    // NOTE(review): this copies every element of the source global just to
    // gather the slice, which may be costly for large weights.
    SmallVector sourceValues(denseAttr->getValues());
    SmallVector resultValues(numResultElements);
    // Gather: source multi-index = result multi-index + subview offsets,
    // then linearize with the source's row-major strides.
    for (int64_t i = 0; i < numResultElements; ++i) {
      auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
      SmallVector sourceIndices;
      sourceIndices.reserve(resultIndices.size());
      for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
        sourceIndices.push_back(off + idx);
      resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
    }

    auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
    auto newGlobal =
        createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
    markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(subviewOp);
    auto newGetGlobal =
        memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
    markWeightAlways(newGetGlobal);
    rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
    return success();
  }
};

} // namespace

/// Adds this file's subview rewrite patterns to `patterns`.
/// NOTE(review): presumably the three patterns defined above; confirm
/// against the patterns.add template argument list.
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
  patterns.add(patterns.getContext());
}

} // namespace onnx_mlir