This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
+61
View File
@@ -56,6 +56,22 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (result) {
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) {
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
}
return yieldedValue;
}
}
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
@@ -512,6 +528,24 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
continue;
}
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto condition = resolveIndexValueImpl(ifOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
mlir::Region& selectedRegion = *condition != 0 ? ifOp.getThenRegion() : ifOp.getElseRegion();
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(selectedRegion.front().getTerminator());
if (!yieldOp || result.getResultNumber() >= yieldOp.getNumOperands())
return mlir::failure();
value = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
@@ -622,6 +656,33 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
continue;
}
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto thenYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
auto elseYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
if (!thenYield || !elseYield || result.getResultNumber() >= thenYield.getNumOperands()
|| result.getResultNumber() >= elseYield.getNumOperands()) {
return mlir::failure();
}
auto thenAddress = compileContiguousAddressExprImpl(thenYield.getOperand(result.getResultNumber()));
auto elseAddress = compileContiguousAddressExprImpl(elseYield.getOperand(result.getResultNumber()));
if (failed(thenAddress) || failed(elseAddress) || thenAddress->base != elseAddress->base)
return mlir::failure();
auto condition = compileIndexValueImpl(ifOp.getCondition());
if (failed(condition))
return mlir::failure();
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {*condition, thenAddress->byteOffset, elseAddress->byteOffset};
return CompiledAddressExpr {thenAddress->base, makeCompiledIndexExpr(std::move(selectExpr))};
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());