// Rewrites PIM subview copy operations into contiguous chunks and folds
// constant core subviews into standalone globals.
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
|
|
#include "../Common.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
|
|
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
|
return false;
|
|
|
|
auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
|
|
llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
|
|
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
auto [size, dimension] = sizeAndShape;
|
|
return size != dimension;
|
|
});
|
|
if (firstDifferentSize == sizesAndShape.end())
|
|
return true;
|
|
|
|
++firstDifferentSize;
|
|
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
|
|
auto [size, _dimension] = sizeAndShape;
|
|
return size == 1;
|
|
});
|
|
}
|
|
|
|
// Adds the static `extraOffset` to `baseOffset`, which may be either a
// constant IntegerAttr (folded arithmetic) or an SSA Value (emits arith ops).
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
  // Adding zero is a no-op; reuse the base as-is.
  if (extraOffset == 0)
    return baseOffset;

  if (auto baseValue = dyn_cast<Value>(baseOffset)) {
    // Dynamic base offset: materialize the constant and add in IR.
    Location loc = baseValue.getLoc();
    auto cst = arith::ConstantIndexOp::create(rewriter, loc, extraOffset);
    return arith::AddIOp::create(rewriter, loc, baseValue, cst).getResult();
  }

  // Static base offset: fold the addition into a fresh index attribute.
  auto integerAttr = dyn_cast<IntegerAttr>(cast<Attribute>(baseOffset));
  assert(integerAttr && "expected integer offset attribute");
  return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
}
|
|
|
|
// Builds a rank-preserving subview selecting a single innermost slice of
// `info` at the static position given by `outerIndices`: every outer
// dimension gets size 1 at offset (base + index * stride); the innermost
// dimension keeps its full subview extent.
static Value buildSubviewChunk(const StaticSubviewInfo& info,
    ArrayRef<int64_t> outerIndices,
    Location loc,
    PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isInnermost = dim + 1 == rank;
    // Outer dims advance by the slice index scaled by the source stride.
    const int64_t extra = isInnermost ? 0 : outerIndices[dim] * info.strides[dim];
    offsets.push_back(addConstantOffset(info.offsets[dim], extra, rewriter));
    sizes.push_back(rewriter.getIndexAttr(isInnermost ? info.sizes.back() : 1));
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
|
|
|
|
// Decomposes the dynamic `linearIndex` into per-dimension indices using the
// given row-major `strides`, emitting one divui/remui pair per dimension.
// `shape` is only used to reserve the result capacity.
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
    ArrayRef<int64_t> shape,
    ArrayRef<int64_t> strides,
    PatternRewriter& rewriter) {
  Location loc = linearIndex.getLoc();
  SmallVector<Value> indices;
  indices.reserve(shape.size());

  Value rest = linearIndex;
  for (int64_t stride : strides) {
    auto strideCst = arith::ConstantIndexOp::create(rewriter, loc, stride);
    // index_d = rest / stride_d; carry the remainder into the next dim.
    indices.push_back(arith::DivUIOp::create(rewriter, loc, rest, strideCst));
    rest = arith::RemUIOp::create(rewriter, loc, rest, strideCst);
  }

  return indices;
}
|
|
|
|
// Adds the dynamic SSA `extraOffset` to `baseOffset`. A static base of zero
// collapses to `extraOffset` itself; otherwise an arith.addi is emitted.
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
  if (auto baseValue = dyn_cast<Value>(baseOffset)) {
    // Both operands dynamic: plain addition.
    return arith::AddIOp::create(rewriter, baseValue.getLoc(), baseValue, extraOffset).getResult();
  }

  const int64_t base = cast<IntegerAttr>(cast<Attribute>(baseOffset)).getInt();
  if (base == 0)
    return extraOffset; // 0 + x == x; no IR needed.

  Location loc = extraOffset.getLoc();
  auto baseCst = arith::ConstantIndexOp::create(rewriter, loc, base);
  return arith::AddIOp::create(rewriter, loc, baseCst, extraOffset).getResult();
}
|
|
|
|
// Builds a rank-preserving subview selecting one innermost slice of `info`
// at the dynamic position `outerIndices` (typically derived from a loop
// induction variable). Outer dimensions must have unit stride so the index
// can be added to the base offset directly.
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
    ArrayRef<Value> outerIndices,
    Location loc,
    PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isInnermost = dim + 1 == rank;
    if (isInnermost) {
      // Innermost dim: keep the full subview extent at its original offset.
      offsets.push_back(info.offsets[dim]);
      sizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
    } else {
      assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
      offsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
|
|
|
|
// Builds a unit-stride subview of a plain (non-subview) `source` selecting
// one innermost slice of `copyShape` at the dynamic position `outerIndices`.
static Value buildContiguousChunk(
    Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
  const size_t rank = copyShape.size();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isInnermost = dim + 1 == rank;
    // Outer dims are indexed directly; the innermost starts at offset 0 and
    // spans its full extent.
    offsets.push_back(isInnermost ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(outerIndices[dim]));
    sizes.push_back(rewriter.getIndexAttr(isInnermost ? copyShape.back() : 1));
    strides.push_back(rewriter.getIndexAttr(1));
  }

  return memref::SubViewOp::create(rewriter, loc, source, offsets, sizes, strides);
}
|
|
|
|
// Rewrites a byte-granular copy op whose destination and/or source is a
// non-contiguous static subview into one contiguous copy per innermost slice.
//
// Two strategies are used:
//  - an scf.for loop over slices with dynamic subview chunks, when
//    `allowLoopRewrite` is set and the preconditions below hold;
//  - otherwise a fully unrolled sequence of `numSlices` copies with static
//    subview chunks and static byte offsets.
//
// `createCopyOp(resultType, dst, src, dstByteOffset, srcByteOffset, bytes)`
// is a callback that materializes one chunk copy op at the current insertion
// point. Returns failure (leaving the IR untouched) when the copy does not
// match the supported pattern.
//
// NOTE(review): on the loop path the rewriter's insertion point is left
// inside the loop body when this returns — callers appear to rely only on
// replaceOp afterwards; confirm if reusing the rewriter for more insertions.
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
    Value dst,
    Value src,
    int64_t dstOffset,
    int64_t srcOffset,
    int64_t size,
    bool allowLoopRewrite,
    PatternRewriter& rewriter,
    CreateCopyOp createCopyOp) {
  // Only operands that are statically-analyzable, NON-contiguous subviews
  // need splitting; if neither side needs it there is nothing to do.
  auto srcSubview = getStaticSubviewInfo(src);
  auto dstSubview = getStaticSubviewInfo(dst);
  const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
  const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
  if (!splitSrc && !splitDst)
    return failure();

  // Both ends must be statically-shaped memrefs of the same element type.
  auto sourceType = dyn_cast<MemRefType>(src.getType());
  auto dstType = dyn_cast<MemRefType>(dst.getType());
  if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();
  if (sourceType.getElementType() != dstType.getElementType())
    return failure();

  // A split side must start at byte offset 0 and have all-unit strides, so
  // slice positions can be derived purely from the subview offsets.
  if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
    return failure();
  if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
    return failure();

  // The logical shape being copied; when both sides split they must agree.
  ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
  if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
    return failure();

  // Sub-byte element types (bitwidth < 8) are rejected.
  const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
  if (elementByteWidth <= 0)
    return failure();

  // The op's byte size must cover exactly the whole copy shape.
  const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
  if (size != totalBytes)
    return failure();

  // Each chunk copies one full innermost row.
  const int64_t sliceBytes = copyShape.back() * elementByteWidth;
  if (sliceBytes <= 0)
    return failure();

  // Outer (all-but-innermost) iteration space, linearized row-major.
  SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
  auto outerStrides = computeRowMajorStrides(outerShape);
  const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);

  // Loop strategy: emit a single scf.for that delinearizes the induction
  // variable into outer indices and copies one dynamic chunk per iteration.
  // Requires zero base offsets and operand ranks matching the copy shape
  // (needed by buildContiguousChunk on a non-split side).
  if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
      && sourceType.getRank() == static_cast<int64_t>(copyShape.size())
      && dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
    auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
    auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
    auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1);

    auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
    rewriter.setInsertionPointToStart(loop.getBody());

    SmallVector<Value> outerIndices =
        outerShape.empty() ? SmallVector<Value> {}
                           : delinearizeIndexValue(loop.getInductionVar(), outerStrides.empty() ? outerShape : outerShape, outerStrides, rewriter);
    Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
                              : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
    Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
                              : buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
    // Chunk offsets are baked into the subviews, so byte offsets are 0.
    createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
    return success();
  }

  // Unroll strategy: one statically-addressed copy per slice.
  rewriter.setInsertionPoint(copyOp);
  for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
    SmallVector<int64_t> outerIndices =
        outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
    // A split side gets a per-slice subview; a contiguous side is reused
    // whole and advanced via its byte offset instead.
    Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
    Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
    const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
    const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
    createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
  }

  return success();
}
|
|
|
|
// Splits core copies through subviews into contiguous copy chunks for codegen.
|
|
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
|
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
|
return failure();
|
|
|
|
auto status = rewriteSubviewCopyLikeOp(
|
|
copyOp,
|
|
copyOp.getTarget(),
|
|
copyOp.getSource(),
|
|
copyOp.getTargetOffset(),
|
|
copyOp.getSourceOffset(),
|
|
copyOp.getSize(),
|
|
/*allowLoopRewrite=*/true,
|
|
rewriter,
|
|
[&](
|
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
|
pim::PimMemCopyOp::create(rewriter,
|
|
copyOp.getLoc(),
|
|
resultType,
|
|
dst,
|
|
src,
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
});
|
|
if (failed(status))
|
|
return failure();
|
|
|
|
rewriter.replaceOp(copyOp, copyOp.getTarget());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Splits host-to-device subview loads into contiguous copy chunks.
|
|
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
|
auto status = rewriteSubviewCopyLikeOp(
|
|
copyOp,
|
|
copyOp.getDeviceTarget(),
|
|
copyOp.getHostSource(),
|
|
copyOp.getDeviceTargetOffset(),
|
|
copyOp.getHostSourceOffset(),
|
|
copyOp.getSize(),
|
|
/*allowLoopRewrite=*/true,
|
|
rewriter,
|
|
[&](
|
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
|
pim::PimMemCopyHostToDevOp::create(rewriter,
|
|
copyOp.getLoc(),
|
|
resultType,
|
|
dst,
|
|
src,
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
});
|
|
if (failed(status))
|
|
return failure();
|
|
|
|
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Splits device-to-host subview stores into contiguous copy chunks.
|
|
struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
|
auto status = rewriteSubviewCopyLikeOp(
|
|
copyOp,
|
|
copyOp.getHostTarget(),
|
|
copyOp.getDeviceSource(),
|
|
copyOp.getHostTargetOffset(),
|
|
copyOp.getDeviceSourceOffset(),
|
|
copyOp.getSize(),
|
|
/*allowLoopRewrite=*/false,
|
|
rewriter,
|
|
[&](
|
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
|
copyOp.getLoc(),
|
|
resultType,
|
|
dst,
|
|
src,
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
});
|
|
if (failed(status))
|
|
return failure();
|
|
|
|
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Folds constant subviews used as core weights into standalone globals.
|
|
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
|
if (subviewOp.use_empty())
|
|
return failure();
|
|
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
|
return failure();
|
|
|
|
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
|
if (!moduleOp)
|
|
return failure();
|
|
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
|
if (failed(denseAttr))
|
|
return failure();
|
|
|
|
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
|
if (failed(subviewInfo))
|
|
return failure();
|
|
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
|
return failure();
|
|
auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
|
|
if (failed(staticOffsets))
|
|
return failure();
|
|
|
|
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
|
auto resultMemRefType =
|
|
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
|
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
|
|
if (failed(foldedAttr))
|
|
return failure();
|
|
|
|
auto newGlobal =
|
|
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, *foldedAttr, "pim_folded_subview");
|
|
markWeightAlways(newGlobal);
|
|
|
|
rewriter.setInsertionPoint(subviewOp);
|
|
auto newGetGlobal =
|
|
memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
|
markWeightAlways(newGetGlobal);
|
|
|
|
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// Registers the subview copy-splitting and constant-folding patterns.
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<RewriteCoreSubviewCopyPattern>(context);
  patterns.add<RewriteHostSubviewLoadPattern>(context);
  patterns.add<RewriteHostSubviewStorePattern>(context);
  patterns.add<FoldConstantCoreSubviewPattern>(context);
}
|
|
|
|
} // namespace onnx_mlir
|