faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-25 15:24:12 +02:00
parent 4855a2e105
commit e8a08f6dd0
18 changed files with 1610 additions and 573 deletions
+431
View File
@@ -32,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
return value;
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
template <typename... Args>
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -128,6 +136,225 @@ static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t l
llvm_unreachable("unknown cmpi predicate");
}
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) {
if (!expr.node)
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant:
return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end())
return iter->second;
return mlir::failure();
}
case CompiledIndexExprNode::Kind::Add:
case CompiledIndexExprNode::Kind::Sub:
case CompiledIndexExprNode::Kind::Mul:
case CompiledIndexExprNode::Kind::DivUI:
case CompiledIndexExprNode::Kind::DivSI:
case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add:
return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub:
return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul:
return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::DivSI:
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
return mlir::failure();
return *lhs / *rhs;
case CompiledIndexExprNode::Kind::RemUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::RemSI:
if (*rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI:
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default:
llvm_unreachable("unexpected binary compiled index kind");
}
}
case CompiledIndexExprNode::Kind::Select: {
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
if (failed(condition))
return mlir::failure();
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
}
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
if (!denseAttr || !globalType)
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(expr.node->operands.size());
for (const CompiledIndexExpr& operand : expr.node->operands) {
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
}
llvm_unreachable("unknown compiled index kind");
}
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
expr.globalOp = globalOp;
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
expr.operands.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto compiledIndex = compileIndexValueImpl(index);
if (failed(compiledIndex))
return mlir::failure();
expr.operands.push_back(*compiledIndex);
}
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = integerAttr.getInt();
return makeCompiledIndexExpr(std::move(expr));
}
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
auto lhs = compileIndexValueImpl(lhsValue);
auto rhs = compileIndexValueImpl(rhsValue);
if (failed(lhs) || failed(rhs))
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {*lhs, *rhs};
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
};
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return compileIndexValueImpl(indexCastOp.getIn());
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
if (failed(expr))
return mlir::failure();
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
exprNode->predicate = cmpOp.getPredicate();
return CompiledIndexExpr(exprNode);
}
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
auto lhs = compileIndexValueImpl(maxOp.getLhs());
auto rhs = compileIndexValueImpl(maxOp.getRhs());
if (failed(lhs) || failed(rhs))
return mlir::failure();
CompiledIndexExprNode cmpExpr;
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
cmpExpr.operands = {*lhs, *rhs};
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
return makeCompiledIndexExpr(std::move(selectExpr));
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = compileIndexValueImpl(selectOp.getCondition());
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
if (failed(condition) || failed(trueValue) || failed(falseValue))
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Select;
expr.operands = {*condition, *trueValue, *falseValue};
return makeCompiledIndexExpr(std::move(expr));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return compileConstantGlobalLoad(loadOp);
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -353,6 +580,191 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
}
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
int64_t constantByteOffset = 0;
CompiledIndexExpr byteOffsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return CompiledAddressExpr {value, byteOffsetExpr};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = tiedOperand->get();
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size());
bool allStatic = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
if (allStatic) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
constantByteOffset +=
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
* getElementTypeSizeInBytes(subviewType.getElementType());
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
else {
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
if (failed(compiledOffset))
return mlir::failure();
CompiledIndexExpr scaleExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
scaleExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Mul;
expr.operands = {*compiledOffset, scaleExpr};
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {offsetExpr, operandExpr};
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, offsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
constantByteOffset = 0;
}
value = subviewOp.getSource();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
if (constantByteOffset != 0) {
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
byteOffsetExpr = constantExpr;
else {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, byteOffsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
}
return CompiledAddressExpr {value, byteOffsetExpr};
}
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
@@ -361,6 +773,8 @@ llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueK
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
@@ -374,4 +788,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
return compileContiguousAddressExprImpl(value);
}
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
return evaluateCompiledIndexExpr(*this, knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress>
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const {
(void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
return ResolvedContiguousAddress {base, *resolvedOffset};
}
} // namespace onnx_mlir