#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include <algorithm>
#include <cassert>
#include <cstdint>

#include "../Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

// Decides whether a subview can be copied as a single chunk.
//
// Returns false as soon as any stride differs from 1. Otherwise the sizes are
// compared against the source shape starting from the innermost dimension:
// if they never differ the subview is contiguous; if they do, every dimension
// further out than the first differing one must have size 1.
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
  if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
    return false;

  // Walk sizes and source dimensions in lockstep, innermost first.
  auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
                                       llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
  auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
    auto [size, dimension] = sizeAndShape;
    return size != dimension;
  });
  if (firstDifferentSize == sizesAndShape.end())
    return true;

  // The differing dimension itself may be partial; everything outside it
  // must be a single element.
  ++firstDifferentSize;
  return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
    auto [size, _dimension] = sizeAndShape;
    return size == 1;
  });
}

// Returns `baseOffset + extraOffset`.
//
// A zero `extraOffset` returns `baseOffset` untouched. An attribute offset is
// folded into a new index attribute (it must be an IntegerAttr); a Value
// offset gets an arith.constant + arith.addi emitted at the value's location.
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
  if (extraOffset == 0)
    return baseOffset;

  if (auto attr = dyn_cast<Attribute>(baseOffset)) {
    auto integerAttr = dyn_cast<IntegerAttr>(attr);
    assert(integerAttr && "expected integer offset attribute");
    return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
  }

  auto value = cast<Value>(baseOffset);
  auto cst = arith::ConstantIndexOp::create(rewriter, value.getLoc(), extraOffset);
  return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
}

// Builds a memref.subview of `info.source` selecting one innermost slice.
//
// For every dimension except the last, the offset is the original offset plus
// `outerIndices[dim] * info.strides[dim]` and the size is 1. The last
// dimension keeps its original offset and spans `info.sizes.back()` elements.
// Strides are taken from `info` unchanged.
static Value buildSubviewChunk(const StaticSubviewInfo& info,
                               ArrayRef<int64_t> outerIndices,
                               Location loc,
                               PatternRewriter& rewriter) {
  SmallVector<OpFoldResult> chunkOffsets;
  SmallVector<OpFoldResult> chunkSizes;
  SmallVector<OpFoldResult> chunkStrides;
  chunkOffsets.reserve(info.offsets.size());
  chunkSizes.reserve(info.sizes.size());
  chunkStrides.reserve(info.strides.size());

  for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
    const bool isOuterDim = dim + 1 < info.sizes.size();
    int64_t extraOffset = isOuterDim ? outerIndices[dim] * info.strides[dim] : 0;
    chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
    chunkSizes.push_back(rewriter.getIndexAttr(isOuterDim ? 1 : info.sizes.back()));
    chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
}

// Emits IR that splits `linearIndex` into one index per entry of `strides`.
//
// For each stride, in order, the current remainder is divided by the stride
// (arith.divui) to produce the index and reduced modulo the stride
// (arith.remui) for the next step. `shape` is only used to size the result.
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
                                                ArrayRef<int64_t> shape,
                                                ArrayRef<int64_t> strides,
                                                PatternRewriter& rewriter) {
  SmallVector<Value> indices;
  indices.reserve(shape.size());

  Value remaining = linearIndex;
  for (int64_t stride : strides) {
    auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride);
    Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
    indices.push_back(index);
    remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
  }
  return indices;
}

// Returns `baseOffset + extraOffset` where `extraOffset` is an SSA value.
//
// A zero attribute offset yields `extraOffset` directly; any other attribute
// offset is materialized as a constant and added. A Value offset is added
// with arith.addi at that value's location.
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
  if (auto attr = dyn_cast<Attribute>(baseOffset)) {
    auto integerAttr = cast<IntegerAttr>(attr);
    if (integerAttr.getInt() == 0)
      return extraOffset;
    auto cst = arith::ConstantIndexOp::create(rewriter, extraOffset.getLoc(), integerAttr.getInt());
    return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
  }

  auto value = cast<Value>(baseOffset);
  return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
}

// Same slicing as buildSubviewChunk, but the outer indices are SSA values.
//
// Outer dimensions must have stride 1 (asserted), so the index is added to
// the offset without scaling. The last dimension keeps its original offset
// and covers `info.sizes.back()` elements.
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
                                      ArrayRef<Value> outerIndices,
                                      Location loc,
                                      PatternRewriter& rewriter) {
  SmallVector<OpFoldResult> chunkOffsets;
  SmallVector<OpFoldResult> chunkSizes;
  SmallVector<OpFoldResult> chunkStrides;
  chunkOffsets.reserve(info.offsets.size());
  chunkSizes.reserve(info.sizes.size());
  chunkStrides.reserve(info.strides.size());

  for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
    if (dim + 1 < info.sizes.size()) {
      assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
      chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
      chunkSizes.push_back(rewriter.getIndexAttr(1));
    }
    else {
      chunkOffsets.push_back(info.offsets[dim]);
      chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
    }
    chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
}

// Builds a unit-stride memref.subview of `source` selecting one innermost
// slice: outer dimensions are indexed by `outerIndices` with size 1, the last
// dimension starts at 0 and spans `copyShape.back()` elements.
//
// The caller must ensure `source` actually has shape `copyShape`; nothing
// here checks it.
static Value buildContiguousChunk(
    Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
  SmallVector<OpFoldResult> chunkOffsets;
  SmallVector<OpFoldResult> chunkSizes;
  SmallVector<OpFoldResult> chunkStrides;
  chunkOffsets.reserve(copyShape.size());
  chunkSizes.reserve(copyShape.size());
  chunkStrides.reserve(copyShape.size());

  for (size_t dim = 0; dim < copyShape.size(); ++dim) {
    const bool isOuterDim = dim + 1 < copyShape.size();
    chunkOffsets.push_back(isOuterDim ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
    chunkSizes.push_back(rewriter.getIndexAttr(isOuterDim ? 1 : copyShape.back()));
    chunkStrides.push_back(rewriter.getIndexAttr(1));
  }

  return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
}

// Splits a copy whose source and/or destination is a non-contiguous static
// subview into one copy per innermost slice.
//
// Parameters:
//   copyOp           - the op being rewritten; only its location is used here.
//   dst, src         - destination and source memrefs of the copy.
//   dstOffset,
//   srcOffset        - byte offsets of the original copy.
//   size             - byte size of the original copy.
//   allowLoopRewrite - when true, and the extra conditions below hold, emit a
//                      single scf.for instead of unrolling.
//   createCopyOp     - callback invoked as
//                      (resultType, dst, src, dstByteOffset, srcByteOffset,
//                       sliceBytes) to create each chunk copy.
//
// Returns failure() without emitting IR when neither operand needs
// splitting or when any precondition is not met. On success the chunk copies
// have been created; the original op is NOT erased or replaced here, that is
// left to the caller.
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
                                              Value dst,
                                              Value src,
                                              int64_t dstOffset,
                                              int64_t srcOffset,
                                              int64_t size,
                                              bool allowLoopRewrite,
                                              PatternRewriter& rewriter,
                                              CreateCopyOp createCopyOp) {
  auto srcSubview = getStaticSubviewInfo(src);
  auto dstSubview = getStaticSubviewInfo(dst);
  const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
  const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
  if (!splitSrc && !splitDst)
    return failure();

  auto sourceType = dyn_cast<MemRefType>(src.getType());
  auto dstType = dyn_cast<MemRefType>(dst.getType());
  if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();
  if (sourceType.getElementType() != dstType.getElementType())
    return failure();

  // A split operand must start at byte offset 0 and use unit strides.
  if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
    return failure();
  if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
    return failure();

  ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
  if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
    return failure();

  // Only integer and float element types have a bit width to query; bail out
  // on anything else instead of tripping an assertion inside the type.
  if (!sourceType.getElementType().isIntOrFloat())
    return failure();
  const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
  if (elementByteWidth <= 0)
    return failure();

  // The original copy must cover exactly the subview's elements.
  const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
  if (size != totalBytes)
    return failure();

  const int64_t sliceBytes = copyShape.back() * elementByteWidth;
  if (sliceBytes <= 0)
    return failure();

  SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
  auto outerStrides = computeRowMajorStrides(outerShape);
  const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);

  // The loop form addresses a non-split operand by indexing it with indices
  // delinearized from `copyShape`. That is only the same memory as the byte
  // offsets used by the unrolled form when the operand has exactly the copied
  // shape, so require a full shape match (which also implies equal rank) and
  // fall through to the unrolled form otherwise.
  const bool shapesMatchCopy = sourceType.getShape() == copyShape && dstType.getShape() == copyShape;

  if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0 && shapesMatchCopy) {
    auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
    auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
    auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1);
    auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});

    // Everything created from here on goes at the start of the loop body.
    rewriter.setInsertionPointToStart(loop.getBody());
    SmallVector<Value> outerIndices =
        outerShape.empty() ? SmallVector<Value> {}
                           : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
    Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
                              : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
    Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
                              : buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
    createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
    return success();
  }

  // Unrolled form: one copy per slice, emitted right before the original op.
  rewriter.setInsertionPoint(copyOp);
  for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
    SmallVector<int64_t> outerIndices =
        outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
    // A split operand is narrowed to the slice via a subview (byte offset 0);
    // a non-split operand is used as-is and addressed by byte offset.
    Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
    Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
    const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
    const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
    createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
  }
  return success();
}

// Splits core copies through subviews into contiguous copy chunks for codegen.
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `copyOp` into per-slice pim::PimMemCopyOp ops (loop rewrite
  // allowed) and replaces it with its target operand. Fails when the op has
  // no enclosing core op or when rewriteSubviewCopyLikeOp declines.
  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // NOTE(review): the parent-op type is assumed to be pim::PimCoreOp based
    // on this pattern's stated purpose — confirm against the Pim dialect.
    if (!copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    auto status = rewriteSubviewCopyLikeOp(
        copyOp,
        copyOp.getTarget(),
        copyOp.getSource(),
        copyOp.getTargetOffset(),
        copyOp.getSourceOffset(),
        copyOp.getSize(),
        /*allowLoopRewrite=*/true,
        rewriter,
        [&](
            MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
          pim::PimMemCopyOp::create(rewriter,
                                    copyOp.getLoc(),
                                    resultType,
                                    dst,
                                    src,
                                    rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
                                    rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
                                    rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
        });
    if (failed(status))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getTarget());
    return success();
  }
};

// Splits host-to-device subview loads into contiguous copy chunks.
struct RewriteHostSubviewLoadPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { auto status = rewriteSubviewCopyLikeOp( copyOp, copyOp.getDeviceTarget(), copyOp.getHostSource(), copyOp.getDeviceTargetOffset(), copyOp.getHostSourceOffset(), copyOp.getSize(), /*allowLoopRewrite=*/true, rewriter, [&]( MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { pim::PimMemCopyHostToDevOp::create(rewriter, copyOp.getLoc(), resultType, dst, src, rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), rewriter.getI32IntegerAttr(static_cast(sliceBytes))); }); if (failed(status)) return failure(); rewriter.replaceOp(copyOp, copyOp.getDeviceTarget()); return success(); } }; // Splits device-to-host subview stores into contiguous copy chunks. struct RewriteHostSubviewStorePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override { auto status = rewriteSubviewCopyLikeOp( copyOp, copyOp.getHostTarget(), copyOp.getDeviceSource(), copyOp.getHostTargetOffset(), copyOp.getDeviceSourceOffset(), copyOp.getSize(), /*allowLoopRewrite=*/false, rewriter, [&]( MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { pim::PimMemCopyDevToHostOp::create(rewriter, copyOp.getLoc(), resultType, dst, src, rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), rewriter.getI32IntegerAttr(static_cast(sliceBytes))); }); if (failed(status)) return failure(); rewriter.replaceOp(copyOp, copyOp.getHostTarget()); return success(); } }; // Folds constant subviews used as core weights into standalone globals. 
struct FoldConstantCoreSubviewPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override { if (subviewOp.use_empty()) return failure(); if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa(user); })) return failure(); auto moduleOp = subviewOp->getParentOfType(); if (!moduleOp) return failure(); auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource())); if (failed(denseAttr)) return failure(); auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult()); if (failed(subviewInfo)) return failure(); if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; })) return failure(); auto staticOffsets = getStaticSubviewOffsets(*subviewInfo); if (failed(staticOffsets)) return failure(); auto elementType = cast(subviewOp.getType()).getElementType(); auto resultMemRefType = MemRefType::get(SmallVector(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType); auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape()); if (failed(foldedAttr)) return failure(); auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, *foldedAttr, "pim_folded_subview"); markWeightAlways(newGlobal); rewriter.setInsertionPoint(subviewOp); auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName()); markWeightAlways(newGetGlobal); rewriter.replaceOp(subviewOp, newGetGlobal.getResult()); return success(); } }; } // namespace void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) { patterns.add(patterns.getContext()); } } // namespace onnx_mlir