fix instructions explosion in pim host constant folding pass
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -66,6 +66,88 @@ static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
delinearizeIndexValue(Value linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides, PatternRewriter& rewriter) {
|
||||
SmallVector<Value> indices;
|
||||
indices.reserve(shape.size());
|
||||
|
||||
Value remaining = linearIndex;
|
||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||
if (stride == 1) {
|
||||
indices.push_back(remaining);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride);
|
||||
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||
indices.push_back(index);
|
||||
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||
}
|
||||
|
||||
return indices;
|
||||
}
|
||||
|
||||
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
auto integerAttr = cast<IntegerAttr>(attr);
|
||||
if (integerAttr.getInt() == 0)
|
||||
return extraOffset;
|
||||
|
||||
auto cst = arith::ConstantIndexOp::create(rewriter, extraOffset.getLoc(), integerAttr.getInt());
|
||||
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
|
||||
}
|
||||
|
||||
auto value = cast<Value>(baseOffset);
|
||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
|
||||
}
|
||||
|
||||
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
|
||||
ArrayRef<Value> outerIndices,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(info.offsets.size());
|
||||
chunkSizes.reserve(info.sizes.size());
|
||||
chunkStrides.reserve(info.strides.size());
|
||||
|
||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||
if (dim + 1 < info.sizes.size()) {
|
||||
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
|
||||
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(1));
|
||||
} else {
|
||||
chunkOffsets.push_back(info.offsets[dim]);
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
|
||||
}
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
static Value buildContiguousChunk(Value source,
|
||||
ArrayRef<int64_t> copyShape,
|
||||
ArrayRef<Value> outerIndices,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(copyShape.size());
|
||||
chunkSizes.reserve(copyShape.size());
|
||||
chunkStrides.reserve(copyShape.size());
|
||||
|
||||
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
|
||||
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
@@ -114,6 +196,25 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
if (numSlices > 1 && srcOffset == 0 && dstOffset == 0 && sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
||||
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
|
||||
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
|
||||
auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1);
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
SmallVector<Value> outerIndices =
|
||||
outerShape.empty() ? SmallVector<Value> {} : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
||||
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
|
||||
Reference in New Issue
Block a user