// Address and index analysis helpers for the PIM accelerator code in namespace onnx_mlir.
// Two families of helpers live here: resolve* functions fold an MLIR value to a concrete
// integer or base-plus-byte-offset using a StaticValueKnowledge, and compile* functions
// build a CompiledIndexExpr tree that can be evaluated later against such knowledge.
// NOTE(review): as seen in this review the file text is collapsed onto a few very long
// lines and template argument lists appear to be missing (for example the bare
// mlir::dyn_cast(definingOp) calls and the empty #include). Code tokens are left exactly
// as found and only comments were added. Restore the original formatting before relying
// on it, since a line comment on a collapsed line hides everything that follows it.
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" namespace onnx_mlir { /* Looks up the symbol named by getGlobalOp in moduleOp. Returns a default-constructed (null) GlobalOp when either argument is null. */ mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) { if (!moduleOp || !getGlobalOp) return {}; return moduleOp.lookupSymbol(getGlobalOp.getName()); } namespace { /* Follows knowledge->aliases starting at value until a value with no alias entry is reached, and returns that value. Returns value unchanged when knowledge is null. NOTE(review): there is no cycle guard, so a cyclic alias map would loop forever. */ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) { if (!knowledge) return value; auto iter = knowledge->aliases.find(value); while (iter != knowledge->aliases.end()) { value = iter->second; iter = knowledge->aliases.find(value); } return value; } /* Forward declarations for compile helpers that are defined further down and call each other. */ llvm::FailureOr compileIndexValueImpl(mlir::Value value); llvm::FailureOr compileContiguousAddressExprImpl(mlir::Value value); /* Builds a CompiledIndexExpr around a node allocated with std::make_shared from the forwarded arguments. */ template CompiledIndexExpr makeCompiledIndexExpr(Args&&... 
args) { return CompiledIndexExpr(std::make_shared(std::forward(args)...)); } /* Resolves value through the alias map and then walks backwards through its defining op, recursing each time: through the tied operand when the defining op matches the first dyn_cast (presumably the destination-style interface) and value casts to a result, or through the source of a cast, collapse or expand op. Returns the value where the walk stops, including the early return when the isa check on value succeeds (checked type not visible here, presumably a block argument). */ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { value = resolveAlias(value, knowledge); if (mlir::isa(value)) return value; mlir::Operation* definingOp = value.getDefiningOp(); if (!definingOp) return value; if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) { if (auto result = mlir::dyn_cast(value)) if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result)) return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge); } if (auto castOp = mlir::dyn_cast(definingOp)) return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge); if (auto collapseOp = mlir::dyn_cast(definingOp)) return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge); if (auto expandOp = mlir::dyn_cast(definingOp)) return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge); return value; } /* Forward declarations; both functions are defined further down. */ llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge); llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge); /* Returns the strides reported by type.getStridesAndOffset, or failure if that call fails or any stride is dynamic. The offset is computed but not returned. */ static llvm::FailureOr> getStaticMemRefStrides(mlir::MemRefType type) { llvm::SmallVector strides; int64_t offset = 0; if (failed(type.getStridesAndOffset(strides, offset))) return mlir::failure(); if (llvm::any_of(strides, mlir::ShapedType::isDynamic)) return mlir::failure(); return strides; } /* Folds a load from a constant global to an integer. Fails unless the loaded memref has a defining op of the expected kind (type argument not visible here, presumably memref.get_global), the global is found, is constant and has an initial value that casts to the expected attribute, the type has a static shape, the element type is index or integer, every load index resolves through resolveIndexValueImpl, the index count equals the rank, and the linearized index is in range. */ static llvm::FailureOr resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp, const StaticValueKnowledge* knowledge) { auto getGlobalOp = loadOp.getMemRef().getDefiningOp(); if (!getGlobalOp) return mlir::failure(); auto moduleOp = loadOp->getParentOfType(); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue()) return mlir::failure(); auto denseAttr = mlir::dyn_cast(*globalOp.getInitialValue()); auto globalType = mlir::dyn_cast(getGlobalOp.getType()); if (!denseAttr || 
!globalType || !globalType.hasStaticShape()) return mlir::failure(); auto elementType = denseAttr.getElementType(); if (!elementType.isIndex() && !elementType.isInteger()) return mlir::failure(); llvm::SmallVector indices; indices.reserve(loadOp.getIndices().size()); for (mlir::Value index : loadOp.getIndices()) { auto resolvedIndex = resolveIndexValueImpl(index, knowledge); if (failed(resolvedIndex)) return mlir::failure(); indices.push_back(*resolvedIndex); } if (indices.size() != static_cast(globalType.getRank())) return mlir::failure(); /* Linearize the resolved indices with row-major strides, then bounds-check before reading the element as a sign-extended integer. */ auto strides = computeRowMajorStrides(globalType.getShape()); int64_t linearIndex = linearizeIndex(indices, strides); if (linearIndex < 0 || linearIndex >= globalType.getNumElements()) return mlir::failure(); return denseAttr.getValues()[linearIndex].getSExtValue(); } /* Evaluates an arith cmpi predicate on two int64_t values. The signed predicates compare lhs and rhs directly; the unsigned ones compare after a static_cast of both operands (target type not visible here, presumably an unsigned 64-bit type). */ static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) { switch (predicate) { case mlir::arith::CmpIPredicate::eq: return lhs == rhs; case mlir::arith::CmpIPredicate::ne: return lhs != rhs; case mlir::arith::CmpIPredicate::slt: return lhs < rhs; case mlir::arith::CmpIPredicate::sle: return lhs <= rhs; case mlir::arith::CmpIPredicate::sgt: return lhs > rhs; case mlir::arith::CmpIPredicate::sge: return lhs >= rhs; case mlir::arith::CmpIPredicate::ult: return static_cast(lhs) < static_cast(rhs); case mlir::arith::CmpIPredicate::ule: return static_cast(lhs) <= static_cast(rhs); case mlir::arith::CmpIPredicate::ugt: return static_cast(lhs) > static_cast(rhs); case mlir::arith::CmpIPredicate::uge: return static_cast(lhs) >= static_cast(rhs); } llvm_unreachable("unknown cmpi predicate"); } /* Recursively evaluates a compiled index expression against knowledge. Returns failure for a null node, for a Symbol whose alias-resolved value has no entry in knowledge.indexValues, or when a required operand fails to evaluate. */ llvm::FailureOr evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) { if (!expr.node) return mlir::failure(); switch (expr.node->kind) { case CompiledIndexExprNode::Kind::Constant: return expr.node->constant; case CompiledIndexExprNode::Kind::Symbol: { auto value = resolveAlias(expr.node->symbol, 
&knowledge); auto iter = knowledge.indexValues.find(value); if (iter != knowledge.indexValues.end()) return iter->second; return mlir::failure(); } /* Binary kinds share one path: both operands are evaluated first, then combined in the inner switch. */ case CompiledIndexExprNode::Kind::Add: case CompiledIndexExprNode::Kind::Sub: case CompiledIndexExprNode::Kind::Mul: case CompiledIndexExprNode::Kind::DivUI: case CompiledIndexExprNode::Kind::DivSI: case CompiledIndexExprNode::Kind::RemUI: case CompiledIndexExprNode::Kind::RemSI: case CompiledIndexExprNode::Kind::MinUI: case CompiledIndexExprNode::Kind::CmpI: { auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge); auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge); if (failed(lhs) || failed(rhs)) return mlir::failure(); switch (expr.node->kind) { case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs; case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs; case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs; case CompiledIndexExprNode::Kind::DivUI: if (*rhs == 0) return mlir::failure(); return static_cast(static_cast(*lhs) / static_cast(*rhs)); /* Signed division fails on a zero divisor and on the min / -1 case. */ case CompiledIndexExprNode::Kind::DivSI: if (*rhs == 0 || (*lhs == std::numeric_limits::min() && *rhs == -1)) return mlir::failure(); return *lhs / *rhs; case CompiledIndexExprNode::Kind::RemUI: if (*rhs == 0) return mlir::failure(); return static_cast(static_cast(*lhs) % static_cast(*rhs)); /* Signed remainder fails on a zero divisor; min % -1 returns 0 without evaluating the % operator. */ case CompiledIndexExprNode::Kind::RemSI: if (*rhs == 0) return mlir::failure(); if (*lhs == std::numeric_limits::min() && *rhs == -1) return 0; return *lhs % *rhs; case CompiledIndexExprNode::Kind::MinUI: return static_cast(std::min(static_cast(*lhs), static_cast(*rhs))); case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 
1 : 0; default: llvm_unreachable("unexpected binary compiled index kind"); } } /* Select evaluates the condition and then only the chosen branch, so a failure in the other branch does not affect the result. */ case CompiledIndexExprNode::Kind::Select: { auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge); if (failed(condition)) return mlir::failure(); return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge); } /* ConstantGlobalLoad evaluates the index operands, linearizes them with the stored globalStrides and reads the element after a bounds check. */ case CompiledIndexExprNode::Kind::ConstantGlobalLoad: { if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue()) return mlir::failure(); auto denseAttr = mlir::dyn_cast(*expr.node->globalOp.getInitialValue()); auto globalType = mlir::dyn_cast(expr.node->globalOp.getType()); if (!denseAttr || !globalType) return mlir::failure(); llvm::SmallVector indices; indices.reserve(expr.node->operands.size()); for (const CompiledIndexExpr& operand : expr.node->operands) { auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge); if (failed(resolvedIndex)) return mlir::failure(); indices.push_back(*resolvedIndex); } int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides); if (linearIndex < 0 || linearIndex >= globalType.getNumElements()) return mlir::failure(); return denseAttr.getValues()[linearIndex].getSExtValue(); } } llvm_unreachable("unknown compiled index kind"); } /* Compiles a load from a constant global into a ConstantGlobalLoad node that stores the global op, row-major strides of its shape and one compiled expression per load index. The checks on the global mirror those in resolveConstantGlobalLoad; failure is also returned if any index fails to compile. */ llvm::FailureOr compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) { auto getGlobalOp = loadOp.getMemRef().getDefiningOp(); if (!getGlobalOp) return mlir::failure(); auto moduleOp = loadOp->getParentOfType(); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue()) return mlir::failure(); auto denseAttr = mlir::dyn_cast(*globalOp.getInitialValue()); auto globalType = mlir::dyn_cast(getGlobalOp.getType()); if (!denseAttr || !globalType || !globalType.hasStaticShape()) return mlir::failure(); auto elementType = denseAttr.getElementType(); if (!elementType.isIndex() && !elementType.isInteger()) return mlir::failure(); 
CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad; expr.globalOp = globalOp; expr.globalStrides = computeRowMajorStrides(globalType.getShape()); expr.operands.reserve(loadOp.getIndices().size()); for (mlir::Value index : loadOp.getIndices()) { auto compiledIndex = compileIndexValueImpl(index); if (failed(compiledIndex)) return mlir::failure(); expr.operands.push_back(*compiledIndex); } return makeCompiledIndexExpr(std::move(expr)); } /* Compiles the computation of value into a CompiledIndexExpr tree. Integer constants become Constant nodes. Values with no defining op, and values defined by an op that none of the branches below recognizes, become Symbol nodes that are looked up at evaluation time. Recognized arithmetic, compare, select and load ops are compiled recursively and fail if any operand fails. */ llvm::FailureOr compileIndexValueImpl(mlir::Value value) { if (auto constantOp = value.getDefiningOp()) { if (auto integerAttr = mlir::dyn_cast(constantOp.getValue())) { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Constant; expr.constant = integerAttr.getInt(); return makeCompiledIndexExpr(std::move(expr)); } } mlir::Operation* definingOp = value.getDefiningOp(); if (!definingOp) { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Symbol; expr.symbol = value; return makeCompiledIndexExpr(std::move(expr)); } /* Helper: compiles both operands and combines them in a node of the given kind. */ auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) { auto lhs = compileIndexValueImpl(lhsValue); auto rhs = compileIndexValueImpl(rhsValue); if (failed(lhs) || failed(rhs)) return llvm::FailureOr(mlir::failure()); CompiledIndexExprNode expr; expr.kind = kind; expr.operands = {*lhs, *rhs}; return llvm::FailureOr(makeCompiledIndexExpr(std::move(expr))); }; /* The cast op is looked through: its input is compiled in its place. */ if (auto indexCastOp = mlir::dyn_cast(definingOp)) return compileIndexValueImpl(indexCastOp.getIn()); if (auto addOp = mlir::dyn_cast(definingOp)) return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs()); if (auto subOp = mlir::dyn_cast(definingOp)) return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs()); if (auto mulOp = mlir::dyn_cast(definingOp)) return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs()); if (auto divOp = mlir::dyn_cast(definingOp)) 
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs()); if (auto divOp = mlir::dyn_cast(definingOp)) return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs()); if (auto minOp = mlir::dyn_cast(definingOp)) return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs()); if (auto remOp = mlir::dyn_cast(definingOp)) return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs()); if (auto remOp = mlir::dyn_cast(definingOp)) return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs()); /* The node built by buildBinaryExpr is copied so the predicate can be set on the copy. */ if (auto cmpOp = mlir::dyn_cast(definingOp)) { auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs()); if (failed(expr)) return mlir::failure(); auto exprNode = std::make_shared(*expr->node); exprNode->predicate = cmpOp.getPredicate(); return CompiledIndexExpr(exprNode); } /* Max is expressed as Select(CmpI uge(lhs, rhs), lhs, rhs), which picks lhs when lhs is greater or equal under the unsigned comparison. */ if (auto maxOp = mlir::dyn_cast(definingOp)) { auto lhs = compileIndexValueImpl(maxOp.getLhs()); auto rhs = compileIndexValueImpl(maxOp.getRhs()); if (failed(lhs) || failed(rhs)) return mlir::failure(); CompiledIndexExprNode cmpExpr; cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI; cmpExpr.predicate = mlir::arith::CmpIPredicate::uge; cmpExpr.operands = {*lhs, *rhs}; CompiledIndexExprNode selectExpr; selectExpr.kind = CompiledIndexExprNode::Kind::Select; selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs}; return makeCompiledIndexExpr(std::move(selectExpr)); } if (auto selectOp = mlir::dyn_cast(definingOp)) { auto condition = compileIndexValueImpl(selectOp.getCondition()); auto trueValue = compileIndexValueImpl(selectOp.getTrueValue()); auto falseValue = compileIndexValueImpl(selectOp.getFalseValue()); if (failed(condition) || failed(trueValue) || failed(falseValue)) return mlir::failure(); CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Select; expr.operands = {*condition, 
*trueValue, *falseValue}; return makeCompiledIndexExpr(std::move(expr)); } if (auto loadOp = mlir::dyn_cast(definingOp)) return compileConstantGlobalLoad(loadOp); /* Fallback: any other defining op is treated as an opaque symbol. */ CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Symbol; expr.symbol = value; return makeCompiledIndexExpr(std::move(expr)); } /* Eagerly folds value to an integer using knowledge, which may be null. Order of attempts: alias resolution, a direct hit in knowledge->indexValues, an integer constant, then recursive folding of the supported cast, arithmetic, compare, select and load ops. Returns failure for a value with no defining op, for an unsupported op, on division or remainder by zero, and on signed min / -1. */ llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { value = resolveAlias(value, knowledge); if (knowledge) { auto iter = knowledge->indexValues.find(value); if (iter != knowledge->indexValues.end()) return iter->second; } auto constantOp = value.getDefiningOp(); if (constantOp) { if (auto integerAttr = mlir::dyn_cast(constantOp.getValue())) return integerAttr.getInt(); } mlir::Operation* definingOp = value.getDefiningOp(); if (!definingOp) return mlir::failure(); if (auto indexCastOp = mlir::dyn_cast(definingOp)) return resolveIndexValueImpl(indexCastOp.getIn(), knowledge); if (auto addOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs)) return mlir::failure(); return *lhs + *rhs; } if (auto subOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs)) return mlir::failure(); return *lhs - *rhs; } if (auto mulOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs)) return mlir::failure(); return *lhs * *rhs; } if (auto divOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs) || *rhs == 0) return mlir::failure(); return static_cast(static_cast(*lhs) / static_cast(*rhs)); } if (auto divOp = 
mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs) || *rhs == 0) return mlir::failure(); if (*lhs == std::numeric_limits::min() && *rhs == -1) return mlir::failure(); return *lhs / *rhs; } if (auto minOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs)) return mlir::failure(); return static_cast(std::min(static_cast(*lhs), static_cast(*rhs))); } if (auto remOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs) || *rhs == 0) return mlir::failure(); return static_cast(static_cast(*lhs) % static_cast(*rhs)); } if (auto remOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs) || *rhs == 0) return mlir::failure(); if (*lhs == std::numeric_limits::min() && *rhs == -1) return 0; return *lhs % *rhs; } if (auto cmpOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge); if (failed(lhs) || failed(rhs)) return mlir::failure(); return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0; } /* Only the branch chosen by the resolved condition is folded. */ if (auto selectOp = mlir::dyn_cast(definingOp)) { auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge); if (failed(condition)) return mlir::failure(); return resolveIndexValueImpl(*condition != 0 ? 
selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge); } if (auto loadOp = mlir::dyn_cast(definingOp)) return resolveConstantGlobalLoad(loadOp, knowledge); return mlir::failure(); } /* Resolves an OpFoldResult to an integer. An attribute must cast to the expected integer attribute (otherwise failure) and yields its value; a value is folded through resolveIndexValueImpl. */ llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) { if (auto attr = mlir::dyn_cast(ofr)) { auto integerAttr = mlir::dyn_cast(attr); if (!integerAttr) return mlir::failure(); return integerAttr.getInt(); } return resolveIndexValueImpl(mlir::cast(ofr), knowledge); } /* Walks value back to a base value while accumulating a constant byte offset. The loop looks through tied operands of the defining op, loop results (forOp, presumably scf.for), subviews, and cast, collapse and expand ops, resolving aliases at every step. Returns the base together with byteOffset when one of the two isa checks succeeds (checked types not visible here). Returns failure when a value has no defining op, the op is unsupported, a subview is not statically shaped or not contiguous, or an offset, size or stride cannot be resolved. */ llvm::FailureOr resolveContiguousAddressImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { int64_t byteOffset = 0; value = resolveAlias(value, knowledge); while (true) { if (mlir::isa(value)) return ResolvedContiguousAddress {value, byteOffset}; mlir::Operation* definingOp = value.getDefiningOp(); if (!definingOp) return mlir::failure(); if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) { mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast(value)); if (!tiedOperand) return mlir::failure(); value = resolveAlias(tiedOperand->get(), knowledge); continue; } /* Loop result: inspect the yielded operand for this result. If it resolves to a block argument of the loop body with argument number k greater than 0 and k - 1 within the init args, continue from init arg k - 1; otherwise continue from the yielded value itself. */ if (auto forOp = mlir::dyn_cast(definingOp)) { auto result = mlir::dyn_cast(value); if (!result) return mlir::failure(); auto yieldOp = mlir::cast(forOp.getBody()->getTerminator()); mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); continue; } } value = yieldedValue; continue; } /* Subview: all offsets, sizes and strides must resolve to integers and describe a contiguous region of the source. */ if (auto subviewOp = mlir::dyn_cast(definingOp)) { auto sourceType = mlir::dyn_cast(subviewOp.getSource().getType()); auto subviewType = mlir::dyn_cast(subviewOp.getType()); if (!sourceType || !subviewType || !sourceType.hasStaticShape() 
|| !subviewType.hasStaticShape()) return mlir::failure(); llvm::SmallVector offsets; llvm::SmallVector sizes; llvm::SmallVector strides; offsets.reserve(subviewOp.getMixedOffsets().size()); sizes.reserve(subviewOp.getMixedSizes().size()); strides.reserve(subviewOp.getMixedStrides().size()); for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) { auto resolvedOffset = resolveOpFoldResult(offset, knowledge); if (failed(resolvedOffset)) return mlir::failure(); offsets.push_back(*resolvedOffset); } for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) { auto resolvedSize = resolveOpFoldResult(size, knowledge); if (failed(resolvedSize)) return mlir::failure(); sizes.push_back(*resolvedSize); } for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) { auto resolvedStride = resolveOpFoldResult(stride, knowledge); if (failed(resolvedStride)) return mlir::failure(); strides.push_back(*resolvedStride); } if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides)) return mlir::failure(); auto sourceStrides = getStaticMemRefStrides(sourceType); if (failed(sourceStrides)) return mlir::failure(); /* Offset contribution: offsets linearized with the source strides, times the element size in bytes. */ byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); value = resolveAlias(subviewOp.getSource(), knowledge); continue; } if (auto castOp = mlir::dyn_cast(definingOp)) { value = resolveAlias(castOp.getSource(), knowledge); continue; } if (auto collapseOp = mlir::dyn_cast(definingOp)) { value = resolveAlias(collapseOp.getSrc(), knowledge); continue; } if (auto expandOp = mlir::dyn_cast(definingOp)) { value = resolveAlias(expandOp.getSrc(), knowledge); continue; } if (mlir::isa(definingOp)) return ResolvedContiguousAddress {value, byteOffset}; return mlir::failure(); } } /* Compile-time counterpart of resolveContiguousAddressImpl. It walks value back to a base value without any StaticValueKnowledge (so no alias resolution) and returns the base together with a CompiledIndexExpr for the byte offset. Subview offsets may be values, which are compiled with compileIndexValueImpl; subview sizes and strides must be attributes. Constant contributions are accumulated in constantByteOffset and folded into the expression when a dynamic subview or the base is reached. */ llvm::FailureOr compileContiguousAddressExprImpl(mlir::Value value) { int64_t constantByteOffset = 0; CompiledIndexExpr byteOffsetExpr; { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Constant; expr.constant = 
0; byteOffsetExpr = makeCompiledIndexExpr(std::move(expr)); } while (true) { /* NOTE(review): this early return hands back byteOffsetExpr without adding constantByteOffset, unlike the return at the end of the loop - confirm that is intended. */ if (mlir::isa(value)) return CompiledAddressExpr {value, byteOffsetExpr}; mlir::Operation* definingOp = value.getDefiningOp(); if (!definingOp) return mlir::failure(); if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) { mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast(value)); if (!tiedOperand) return mlir::failure(); value = tiedOperand->get(); continue; } /* Loop result: same init-arg mapping as in resolveContiguousAddressImpl, but the yielded operand is used directly without resolveLoopCarriedAliasImpl. */ if (auto forOp = mlir::dyn_cast(definingOp)) { auto result = mlir::dyn_cast(value); if (!result) return mlir::failure(); auto yieldOp = mlir::cast(forOp.getBody()->getTerminator()); mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber()); if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1]; continue; } } value = yieldedValue; continue; } if (auto subviewOp = mlir::dyn_cast(definingOp)) { auto sourceType = mlir::dyn_cast(subviewOp.getSource().getType()); auto subviewType = mlir::dyn_cast(subviewOp.getType()); if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) return mlir::failure(); llvm::SmallVector staticSizes; staticSizes.reserve(subviewOp.getMixedSizes().size()); llvm::SmallVector staticStrides; staticStrides.reserve(subviewOp.getMixedStrides().size()); llvm::SmallVector staticOffsets; staticOffsets.reserve(subviewOp.getMixedOffsets().size()); /* Collect the attribute offsets; any offset that is a value switches to the dynamic path below. */ bool hasOnlyStaticOffsets = true; for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) if (auto attr = mlir::dyn_cast(offset)) staticOffsets.push_back(mlir::cast(attr).getInt()); else hasOnlyStaticOffsets = false; for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) { auto attr = mlir::dyn_cast(size); if (!attr) return mlir::failure(); 
staticSizes.push_back(mlir::cast(attr).getInt()); } for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) { auto attr = mlir::dyn_cast(stride); if (!attr) return mlir::failure(); staticStrides.push_back(mlir::cast(attr).getInt()); } if (!isContiguousSubviewWithDynamicOffsets( sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) { return mlir::failure(); } /* Static path: the offset is a plain constant and is added to constantByteOffset. */ if (hasOnlyStaticOffsets) { if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides)) return mlir::failure(); auto sourceStrides = getStaticMemRefStrides(sourceType); if (failed(sourceStrides)) return mlir::failure(); constantByteOffset += linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); } else { /* Dynamic path: build the sum over dimensions of offset * sourceStride * elementSize, starting from a Constant 0 node. */ auto sourceStrides = getStaticMemRefStrides(sourceType); if (failed(sourceStrides)) return mlir::failure(); CompiledIndexExpr offsetExpr; { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Constant; expr.constant = 0; offsetExpr = makeCompiledIndexExpr(std::move(expr)); } for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) { CompiledIndexExpr operandExpr; if (auto attr = mlir::dyn_cast(mixedOffset)) { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Constant; expr.constant = mlir::cast(attr).getInt() * sourceStride * getElementTypeSizeInBytes(subviewType.getElementType()); operandExpr = makeCompiledIndexExpr(std::move(expr)); } else { auto compiledOffset = compileIndexValueImpl(mlir::cast(mixedOffset)); if (failed(compiledOffset)) return mlir::failure(); CompiledIndexExpr scaleExpr; { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Constant; expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType()); scaleExpr = makeCompiledIndexExpr(std::move(expr)); } CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Mul; expr.operands = {*compiledOffset, 
scaleExpr}; operandExpr = makeCompiledIndexExpr(std::move(expr)); } CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Add; expr.operands = {offsetExpr, operandExpr}; offsetExpr = makeCompiledIndexExpr(std::move(expr)); } CompiledIndexExpr constantExpr; { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Constant; expr.constant = constantByteOffset; constantExpr = makeCompiledIndexExpr(std::move(expr)); } /* NOTE(review): byteOffsetExpr is overwritten here with constantExpr + offsetExpr. A non-constant expression accumulated from an earlier dynamic subview in the same chain does not appear to be carried over - confirm whether chained dynamic subviews can occur. */ CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Add; expr.operands = {constantExpr, offsetExpr}; byteOffsetExpr = makeCompiledIndexExpr(std::move(expr)); constantByteOffset = 0; } value = subviewOp.getSource(); continue; } if (auto castOp = mlir::dyn_cast(definingOp)) { value = castOp.getSource(); continue; } if (auto collapseOp = mlir::dyn_cast(definingOp)) { value = collapseOp.getSrc(); continue; } if (auto expandOp = mlir::dyn_cast(definingOp)) { value = expandOp.getSrc(); continue; } /* Base reached: fold any remaining constant offset into the expression, replacing a Constant 0 expression outright and otherwise adding to it. */ if (mlir::isa(definingOp)) { if (constantByteOffset != 0) { CompiledIndexExpr constantExpr; { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Constant; expr.constant = constantByteOffset; constantExpr = makeCompiledIndexExpr(std::move(expr)); } if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0) byteOffsetExpr = constantExpr; else { CompiledIndexExprNode expr; expr.kind = CompiledIndexExprNode::Kind::Add; expr.operands = {constantExpr, byteOffsetExpr}; byteOffsetExpr = makeCompiledIndexExpr(std::move(expr)); } } return CompiledAddressExpr {value, byteOffsetExpr}; } return mlir::failure(); } } } // namespace /* Public entry points follow: thin wrappers that forward to the Impl functions above, passing the address of knowledge where one is taken. */ llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) { return resolveIndexValueImpl(value, &knowledge); } llvm::FailureOr compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); } llvm::FailureOr resolveContiguousAddress(mlir::Value value, const StaticValueKnowledge& knowledge) { return 
resolveContiguousAddressImpl(value, &knowledge); } mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) { return resolveLoopCarriedAliasImpl(value, &knowledge); } llvm::FailureOr compileContiguousAddressExpr(mlir::Value value) { return compileContiguousAddressExprImpl(value); } /* Evaluates this expression tree against knowledge by delegating to evaluateCompiledIndexExpr. */ llvm::FailureOr CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const { return evaluateCompiledIndexExpr(*this, knowledge); } /* Evaluates the byteOffset expression against knowledge and pairs the result with base; fails if the offset cannot be evaluated. The lane argument is explicitly discarded and has no effect. */ llvm::FailureOr CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional lane) const { (void) lane; auto resolvedOffset = byteOffset.evaluate(knowledge); if (failed(resolvedOffset)) return mlir::failure(); return ResolvedContiguousAddress {base, *resolvedOffset}; } } // namespace onnx_mlir