compact memory contiguity with for loops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -69,6 +69,16 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||
return mlir::failure();
|
||||
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
|
||||
return mlir::failure();
|
||||
return strides;
|
||||
}
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
@@ -539,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
@@ -651,12 +663,16 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
constantByteOffset +=
|
||||
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
}
|
||||
else {
|
||||
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
CompiledIndexExpr offsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
@@ -665,7 +681,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
|
||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
|
||||
CompiledIndexExpr operandExpr;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
|
||||
CompiledIndexExprNode expr;
|
||||
|
||||
Reference in New Issue
Block a user