compact memory contiguity with for loops
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-31 18:47:59 +02:00
parent ab63498f3f
commit b678e55d3c
14 changed files with 550 additions and 331 deletions
+22 -6
View File
@@ -69,6 +69,16 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
return mlir::failure();
return strides;
}
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
@@ -539,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
@@ -651,12 +663,16 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
constantByteOffset +=
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
@@ -665,7 +681,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;