diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp index d6bddc9..592d7c5 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp @@ -1,7 +1,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "../Common.hpp" -#include "../Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -66,6 +66,88 @@ static Value buildSubviewChunk(const StaticSubviewInfo& info, return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); } +static SmallVector +delinearizeIndexValue(Value linearIndex, ArrayRef shape, ArrayRef strides, PatternRewriter& rewriter) { + SmallVector indices; + indices.reserve(shape.size()); + + Value remaining = linearIndex; + for (auto [dim, stride] : llvm::enumerate(strides)) { + if (stride == 1) { + indices.push_back(remaining); + continue; + } + + auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride); + Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); + indices.push_back(index); + remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); + } + + return indices; +} + +static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) { + if (auto attr = dyn_cast(baseOffset)) { + auto integerAttr = cast(attr); + if (integerAttr.getInt() == 0) + return extraOffset; + + auto cst = arith::ConstantIndexOp::create(rewriter, extraOffset.getLoc(), integerAttr.getInt()); + return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult(); + } + + auto value = cast(baseOffset); + return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult(); +} + 
+static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info, // supplies source, offsets, sizes and strides
+                                      ArrayRef outerIndices,         // read at [dim] for every dimension except the last
+                                      Location loc,
+                                      PatternRewriter& rewriter) {
+  SmallVector chunkOffsets;
+  SmallVector chunkSizes;
+  SmallVector chunkStrides;
+  chunkOffsets.reserve(info.offsets.size());
+  chunkSizes.reserve(info.sizes.size());
+  chunkStrides.reserve(info.strides.size());
+  // Leading dims select one element at (info.offsets[dim] + outerIndices[dim]); the last dim keeps its own offset and full size.
+  for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
+    if (dim + 1 < info.sizes.size()) {
+      assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides"); // NOTE(review): assert-only guard; confirm callers guarantee unit strides when asserts are compiled out
+      chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
+      chunkSizes.push_back(rewriter.getIndexAttr(1));
+    } else {
+      chunkOffsets.push_back(info.offsets[dim]);
+      chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
+    }
+    chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
+  }
+  // Emit the chunk as a memref.subview of info.source.
+  return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
+}
+// Builds a chunk of `source` directly: leading dims are indexed by outerIndices with size 1, the last dim is taken whole from offset 0, all strides are 1.
+static Value buildContiguousChunk(Value source,
+                                  ArrayRef copyShape,
+                                  ArrayRef outerIndices,
+                                  Location loc,
+                                  PatternRewriter& rewriter) {
+  SmallVector chunkOffsets;
+  SmallVector chunkSizes;
+  SmallVector chunkStrides;
+  chunkOffsets.reserve(copyShape.size());
+  chunkSizes.reserve(copyShape.size());
+  chunkStrides.reserve(copyShape.size());
+  // Only copyShape.size() and copyShape.back() are read; the leading extents are not used here.
+  for (size_t dim = 0; dim < copyShape.size(); ++dim) {
+    chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
+    chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
+    chunkStrides.push_back(rewriter.getIndexAttr(1));
+  }
+  // Emit the chunk as a memref.subview of `source`.
+  return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
+}
+
 template static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, Value dst, @@ -114,6 +196,25 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, auto outerStrides = computeRowMajorStrides(outerShape); const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); + if (numSlices > 1 && srcOffset == 0 && dstOffset == 0 && sourceType.getRank() == static_cast(copyShape.size()) + && dstType.getRank() == static_cast(copyShape.size())) { + auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0); + auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices); + auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1); + + auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {}); + rewriter.setInsertionPointToStart(loop.getBody()); + + SmallVector outerIndices = + outerShape.empty() ? SmallVector {} : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter); + Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) + : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter); + Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) + : buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter); + createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes); + return success(); + } + rewriter.setInsertionPoint(copyOp); for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { SmallVector outerIndices =