fix missed failing tests for channels
moderate refactor
This commit is contained in:
@@ -0,0 +1,299 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
|
||||
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
|
||||
llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
if (firstDifferentSize == sizesAndShape.end())
|
||||
return true;
|
||||
|
||||
++firstDifferentSize;
|
||||
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
|
||||
auto [size, _dimension] = sizeAndShape;
|
||||
return size == 1;
|
||||
});
|
||||
}
|
||||
|
||||
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
|
||||
if (extraOffset == 0)
|
||||
return baseOffset;
|
||||
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(integerAttr && "expected integer offset attribute");
|
||||
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
|
||||
}
|
||||
|
||||
auto value = cast<Value>(baseOffset);
|
||||
auto cst = arith::ConstantIndexOp::create(rewriter, value.getLoc(), extraOffset);
|
||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
|
||||
}
|
||||
|
||||
static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
||||
ArrayRef<int64_t> outerIndices,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(info.offsets.size());
|
||||
chunkSizes.reserve(info.sizes.size());
|
||||
chunkStrides.reserve(info.strides.size());
|
||||
|
||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
|
||||
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstOffset,
|
||||
int64_t srcOffset,
|
||||
int64_t size,
|
||||
PatternRewriter& rewriter,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
|
||||
const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
||||
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (size != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
||||
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
||||
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
||||
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
copyOp.getTargetOffset(),
|
||||
copyOp.getSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getTarget());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getDeviceTarget(),
|
||||
copyOp.getHostSource(),
|
||||
copyOp.getDeviceTargetOffset(),
|
||||
copyOp.getHostSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getHostTarget(),
|
||||
copyOp.getDeviceSource(),
|
||||
copyOp.getHostTargetOffset(),
|
||||
copyOp.getDeviceSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||
if (subviewOp.use_empty())
|
||||
return failure();
|
||||
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||
return failure();
|
||||
|
||||
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||
if (failed(subviewInfo))
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||
auto resultMemRefType =
|
||||
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
|
||||
auto newGlobal =
|
||||
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
auto newGetGlobal =
|
||||
memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<RewriteCoreSubviewCopyPattern,
|
||||
RewriteHostSubviewLoadPattern,
|
||||
RewriteHostSubviewStorePattern,
|
||||
FoldConstantCoreSubviewPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user