fix failing validations after last commit
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-13 17:46:19 +02:00
parent 324178cba8
commit ea61540e08
2 changed files with 16 additions and 8 deletions
@@ -72,12 +72,7 @@ delinearizeIndexValue(Value linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64
indices.reserve(shape.size());
Value remaining = linearIndex;
for (auto [dim, stride] : llvm::enumerate(strides)) {
if (stride == 1) {
indices.push_back(remaining);
continue;
}
for (auto [_dim, stride] : llvm::enumerate(strides)) {
auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride);
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
indices.push_back(index);
@@ -155,6 +150,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
bool allowLoopRewrite,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
@@ -196,7 +192,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
if (numSlices > 1 && srcOffset == 0 && dstOffset == 0 && sourceType.getRank() == static_cast<int64_t>(copyShape.size())
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
@@ -244,6 +241,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
@@ -276,6 +274,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
@@ -308,6 +307,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
copyOp.getHostTargetOffset(),
copyOp.getDeviceSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/false,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
+9 -1
View File
@@ -67,6 +67,14 @@ static bool isCodegenAddressableValue(Value value) {
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
}
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
if (failed(resolvedAddress))
return false;
return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
}
static bool isConstantGlobalView(Value value) {
while (true) {
Operation* defOp = value.getDefiningOp();
@@ -260,7 +268,7 @@ private:
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
if (!isCodegenAddressableValue(operand, knowledge)) {
op.emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
hasFailure = true;