// Static address/index resolution helpers: they walk SSA def-use chains
// backwards to find the base value a view-like value is derived from, and to
// fold index arithmetic into compile-time integers.
//
// NOTE(review): in this view of the file, the template argument lists of
// mlir::isa / mlir::dyn_cast / mlir::cast / static_cast / llvm::FailureOr /
// llvm::SmallVector / getDefiningOp / lookupSymbol are not visible. The
// comments below therefore refer to operations by the local variable they are
// bound to rather than by a concrete op class -- confirm the concrete types
// against the original file.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"

namespace onnx_mlir {

// Looks up, inside `moduleOp`, the symbol whose name is `getGlobalOp.getName()`.
// Returns a default-constructed (null) GlobalOp when either handle is null.
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
  if (!moduleOp || !getGlobalOp)
    return {};
  return moduleOp.lookupSymbol(getGlobalOp.getName());
}

namespace {

// Follows the `knowledge->aliases` map transitively, starting at `value`,
// until reaching a value that is not a key of the map. With a null `knowledge`
// the input is returned unchanged.
// NOTE(review): there is no cycle guard; an alias map containing a cycle would
// make this loop spin forever -- confirm producers never create one.
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
  if (!knowledge)
    return value;
  auto iter = knowledge->aliases.find(value);
  while (iter != knowledge->aliases.end()) {
    value = iter->second;
    iter = knowledge->aliases.find(value);
  }
  return value;
}

// Recursively strips "pass-through" producers from `value`:
//   * a defining op that casts to the `dpsDefiningOp` interface is replaced by
//     the operand tied to the result (when one exists);
//   * the ops bound to `castOp`, `collapseOp` and `expandOp` are replaced by
//     their source operand.
// Recursion stops at the first value that matches the leading `isa` check
// (presumably a block argument -- confirm), has no defining op, or is produced
// by any other kind of op; that value is returned.
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);
  if (mlir::isa(value))
    return value;
  mlir::Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return value;
  if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) {
    // Only follow the tie when the value is a result that actually has a tied
    // operand; otherwise fall through to the remaining checks below.
    if (auto result = mlir::dyn_cast(value))
      if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
        return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
  }
  if (auto castOp = mlir::dyn_cast(definingOp))
    return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
  if (auto collapseOp = mlir::dyn_cast(definingOp))
    return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
  if (auto expandOp = mlir::dyn_cast(definingOp))
    return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
  return value;
}

// Forward declaration: defined after resolveIndexValueImpl, which it calls.
llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);

// Tries to fold `value` to a compile-time integer.
// Resolution order:
//   1. alias resolution, then a direct hit in `knowledge->indexValues`;
//   2. a constant-producing defining op carrying an integer attribute;
//   3. the op bound to `indexCastOp` (resolved through its input);
//   4. the binary ops bound to `addOp`, `subOp`, `mulOp`, `divOp`, `remOp`,
//      each folded from its recursively resolved operands.
// Returns failure when the value has no defining op, any operand cannot be
// resolved, a divisor/modulus is zero, or the defining op is none of the above.
llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);
  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }
  auto constantOp = value.getDefiningOp();
  if (constantOp) {
    // A constant whose value attribute is not an integer attribute is not
    // rejected here; control continues to the generic handling below.
    if (auto integerAttr = mlir::dyn_cast(constantOp.getValue()))
      return integerAttr.getInt();
  }
  mlir::Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return mlir::failure();
  if (auto indexCastOp = mlir::dyn_cast(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
  if (auto addOp = mlir::dyn_cast(definingOp)) {
    auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return mlir::failure();
    return *lhs + *rhs;
  }
  if (auto subOp = mlir::dyn_cast(definingOp)) {
    auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return mlir::failure();
    return *lhs - *rhs;
  }
  if (auto mulOp = mlir::dyn_cast(definingOp)) {
    auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return mlir::failure();
    return *lhs * *rhs;
  }
  if (auto divOp = mlir::dyn_cast(definingOp)) {
    auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
    // Division by a zero divisor is reported as "not resolvable".
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return mlir::failure();
    // Both operands are converted before dividing and the quotient is
    // converted back (presumably via an unsigned type to model an unsigned
    // division -- the cast target types are not visible here; confirm).
    return static_cast(static_cast(*lhs) / static_cast(*rhs));
  }
  if (auto remOp = mlir::dyn_cast(definingOp)) {
    auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
    // Same zero guard and operand conversion as the division case above.
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return mlir::failure();
    return static_cast(static_cast(*lhs) % static_cast(*rhs));
  }
  return mlir::failure();
}

// Folds an OpFoldResult to an integer: when `ofr` holds an attribute it must be
// an integer attribute (otherwise failure); when it holds anything else it is
// cast to a value and resolved through resolveIndexValueImpl.
llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  if (auto attr = mlir::dyn_cast(ofr)) {
    auto integerAttr = mlir::dyn_cast(attr);
    if (!integerAttr)
      return mlir::failure();
    return integerAttr.getInt();
  }
  return resolveIndexValueImpl(mlir::cast(ofr), knowledge);
}

// Walks backwards from `value` to a base value while accumulating the byte
// offset introduced by the subview ops encountered on the way.
// Each loop iteration either terminates (success with
// ResolvedContiguousAddress{base, byteOffset}, or failure) or replaces `value`
// by the alias-resolved value it is derived from.
// Failure is returned when: a value without defining op is reached that did not
// match the leading `isa` check, a tied operand or loop result cannot be
// determined, a subview has non-static shapes or unresolvable
// offsets/sizes/strides, the subview is rejected by isMemoryContiguous, or the
// defining op is of an unhandled kind.
llvm::FailureOr resolveContiguousAddressImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  int64_t byteOffset = 0;
  value = resolveAlias(value, knowledge);
  while (true) {
    // Terminal case 1: values matching this check are returned as the base
    // (presumably block arguments -- confirm the elided type).
    if (mlir::isa(value))
      return ResolvedContiguousAddress {value, byteOffset};
    mlir::Operation* definingOp = value.getDefiningOp();
    if (!definingOp)
      return mlir::failure();
    if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) {
      // Unlike resolveLoopCarriedAliasImpl, a missing tied operand is a hard
      // failure here rather than a fall-through.
      mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast(value));
      if (!tiedOperand)
        return mlir::failure();
      value = resolveAlias(tiedOperand->get(), knowledge);
      continue;
    }
    if (auto forOp = mlir::dyn_cast(definingOp)) {
      auto result = mlir::dyn_cast(value);
      if (!result)
        return mlir::failure();
      // Look at what the loop body's terminator yields for this result and
      // strip pass-through producers from it.
      auto yieldOp = mlir::cast(forOp.getBody()->getTerminator());
      mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
      if (auto blockArgument = mlir::dyn_cast(yieldedValue)) {
        // When the yielded value is an argument of this loop's own body with
        // index >= 1, continue from the init arg at index (argNumber - 1).
        // Argument 0 is skipped (presumably the induction variable -- confirm).
        if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
            && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
          value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
          continue;
        }
      }
      // Otherwise keep walking from the yielded value itself (no alias
      // resolution is applied on this path).
      value = yieldedValue;
      continue;
    }
    if (auto subviewOp = mlir::dyn_cast(definingOp)) {
      auto sourceType = mlir::dyn_cast(subviewOp.getSource().getType());
      auto subviewType = mlir::dyn_cast(subviewOp.getType());
      // Both the source and the result must have fully static shapes.
      if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
        return mlir::failure();
      llvm::SmallVector offsets;
      llvm::SmallVector sizes;
      llvm::SmallVector strides;
      offsets.reserve(subviewOp.getMixedOffsets().size());
      sizes.reserve(subviewOp.getMixedSizes().size());
      strides.reserve(subviewOp.getMixedStrides().size());
      // Every offset/size/stride, static or dynamic, must fold to an integer.
      for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
        auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
        if (failed(resolvedOffset))
          return mlir::failure();
        offsets.push_back(*resolvedOffset);
      }
      for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
        auto resolvedSize = resolveOpFoldResult(size, knowledge);
        if (failed(resolvedSize))
          return mlir::failure();
        sizes.push_back(*resolvedSize);
      }
      for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
        auto resolvedStride = resolveOpFoldResult(stride, knowledge);
        if (failed(resolvedStride))
          return mlir::failure();
        strides.push_back(*resolvedStride);
      }
      if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
        return mlir::failure();
      // The element offset is linearized against strides computed from the
      // source shape by computeRowMajorStrides, then scaled to bytes as
      // (elements * bitWidth) / 8 using integer division.
      // NOTE(review): for element widths that are not a multiple of 8 bits the
      // division truncates -- confirm such types cannot reach this point.
      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
      value = resolveAlias(subviewOp.getSource(), knowledge);
      continue;
    }
    // The ops bound to castOp / collapseOp / expandOp contribute no offset; the
    // walk simply continues from their source operand.
    if (auto castOp = mlir::dyn_cast(definingOp)) {
      value = resolveAlias(castOp.getSource(), knowledge);
      continue;
    }
    if (auto collapseOp = mlir::dyn_cast(definingOp)) {
      value = resolveAlias(collapseOp.getSrc(), knowledge);
      continue;
    }
    if (auto expandOp = mlir::dyn_cast(definingOp)) {
      value = resolveAlias(expandOp.getSrc(), knowledge);
      continue;
    }
    // Terminal case 2: defining ops matching this check are accepted as the
    // base (presumably allocation / global-access ops -- the op list is not
    // visible here; confirm).
    if (mlir::isa(definingOp))
      return ResolvedContiguousAddress {value, byteOffset};
    return mlir::failure();
  }
}

} // namespace

// Public entry point: folds `value` to an integer without any extra knowledge.
llvm::FailureOr resolveIndexValue(mlir::Value value) {
  return resolveIndexValueImpl(value, nullptr);
}

// Public entry point: folds `value` to an integer, consulting `knowledge` for
// aliases and already-known index values.
llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
  return resolveIndexValueImpl(value, &knowledge);
}

// Public entry point: resolves the base value and byte offset of `value`
// without any extra knowledge.
llvm::FailureOr resolveContiguousAddress(mlir::Value value) {
  return resolveContiguousAddressImpl(value, nullptr);
}

// Public entry point: resolves the base value and byte offset of `value`,
// consulting `knowledge` for aliases and known index values.
llvm::FailureOr resolveContiguousAddress(mlir::Value value, const StaticValueKnowledge& knowledge) {
  return resolveContiguousAddressImpl(value, &knowledge);
}

// Public entry point: strips pass-through producers from `value` using
// `knowledge` for alias resolution (see resolveLoopCarriedAliasImpl).
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
  return resolveLoopCarriedAliasImpl(value, &knowledge);
}

} // namespace onnx_mlir