fix failing validations after last commit
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-13 17:46:19 +02:00
parent 324178cba8
commit ea61540e08
2 changed files with 16 additions and 8 deletions
@@ -72,12 +72,7 @@ delinearizeIndexValue(Value linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64
indices.reserve(shape.size());
Value remaining = linearIndex;
for (auto [dim, stride] : llvm::enumerate(strides)) {
if (stride == 1) {
indices.push_back(remaining);
continue;
}
for (auto [_dim, stride] : llvm::enumerate(strides)) {
auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride);
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
indices.push_back(index);
@@ -155,6 +150,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
bool allowLoopRewrite,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
@@ -196,7 +192,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
if (numSlices > 1 && srcOffset == 0 && dstOffset == 0 && sourceType.getRank() == static_cast<int64_t>(copyShape.size())
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
@@ -244,6 +241,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
@@ -276,6 +274,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
@@ -308,6 +307,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
copyOp.getHostTargetOffset(),
copyOp.getDeviceSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/false,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {