#include "mlir/IR/ValueRange.h"

#include "llvm/ADT/STLExtras.h"

#include <cstdint>
#include <iterator>
#include <limits>
#include <utility>

#include "Common.hpp"

using namespace llvm;
using namespace mlir;

namespace onnx_mlir {

/// Returns an i32 integer attribute holding the size in bytes of `value`.
///
/// `value` must have a ShapedType (enforced by `cast`). The size is narrowed to
/// 32 bits; debug builds assert instead of silently truncating sizes that do
/// not fit.
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
  const auto sizeInBytes = static_cast<int64_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType())));
  assert(sizeInBytes <= std::numeric_limits<int32_t>::max()
         && "Tensor size in bytes does not fit in a 32-bit integer attribute");
  return builder.getI32IntegerAttr(static_cast<int32_t>(sizeInBytes));
}

/// Returns the user of `value` that is placed first in its block.
///
/// Users are compared with `Operation::isBeforeInBlock`, so all users must live
/// in the same block. `value` must have at least one user.
Operation* getEarliestUserWithinBlock(mlir::Value value) {
  auto users = value.getUsers();
  assert(!users.empty() && "Expected value to have at least one user");
  Operation* earliestUser = *users.begin();
  for (Operation* user : users)
    if (user->isBeforeInBlock(earliestUser))
      earliestUser = user;
  return earliestUser;
}

/// Returns the operands of `operation`, ordered by ascending number of uses.
///
/// Uses are counted per use (OpOperand), so a value used twice by one operation
/// counts twice. Duplicate operands are kept. Operands with equal use counts
/// keep their original operand order (stable sort), which makes the result,
/// and any choice derived from it, deterministic.
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
  using OperandAndUses = std::pair<mlir::Value, std::ptrdiff_t>;
  auto operandsAndUses = llvm::map_to_vector(operation->getOperands(), [](mlir::Value operand) -> OperandAndUses {
    return {operand, std::distance(operand.use_begin(), operand.use_end())};
  });
  // Stable sort: llvm::sort gives no ordering guarantee for ties (and shuffles
  // under EXPENSIVE_CHECKS), which would make the selected operand vary.
  llvm::stable_sort(operandsAndUses,
                    [](const OperandAndUses& lhs, const OperandAndUses& rhs) { return lhs.second < rhs.second; });
  return llvm::map_to_vector(operandsAndUses, [](const OperandAndUses& entry) { return entry.first; });
}

/// Returns true if `value` may still be used after `operation`.
///
/// A user counts as "later" when it is placed after `operation` in the same
/// block, or when it lives in a different block (including nested regions).
/// The cross-block case is a conservative answer: the relative order of
/// operations in different blocks is not analysed here.
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
  Block* block = operation->getBlock();
  return llvm::any_of(value.getUsers(), [&](Operation* user) {
    return user->getBlock() != block || operation->isBeforeInBlock(user);
  });
}

/// Returns a tensor that can serve as the output of the single-result
/// `operation`.
///
/// Prefers an existing operand whose type equals the result type and that has
/// no user after `operation` (see hasLaterUserInBlock). Among those, the one
/// with the fewest uses is chosen, ties broken by operand order. If no operand
/// qualifies, a `tensor.empty` with the result's shape and element type is
/// created immediately before `operation`.
///
/// Side effect: in the allocation case the rewriter's insertion point is left
/// just before `operation` and is not restored.
/// NOTE(review): a non-tensor ShapedType result (e.g. memref) would get a
/// tensor.empty of a different type; confirm callers only pass tensor results.
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
  assert(operation->getNumResults() == 1 && "Only support operations with a single result");
  Type resultType = operation->getResult(0).getType();
  assert(isa<ShapedType>(resultType) && "Only support result ShapedType as result type");

  SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
  auto bestOperand = llvm::find_if(operands, [&](mlir::Value operand) {
    return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
  });
  if (bestOperand != operands.end())
    return *bestOperand;

  auto resultShapedType = cast<ShapedType>(resultType);
  // Insert before `operation` so the new tensor dominates its consumer. The
  // insertion point is intentionally left here (existing behavior callers may
  // depend on).
  rewriter.setInsertionPoint(operation);
  return tensor::EmptyOp::create(
    rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}

} // namespace onnx_mlir