ff36729140
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
72 lines
2.4 KiB
C++
72 lines
2.4 KiB
C++
#include "mlir/IR/ValueRange.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
|
|
#include <cassert>
|
|
|
|
#include "Common.hpp"
|
|
|
|
using namespace llvm;
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Builds an i32 IntegerAttr holding the size in bytes of `value`'s shaped
/// type, as computed by getShapedTypeSizeInBytes (defined outside this file,
/// presumably in Common.hpp).
///
/// `value` must have a ShapedType (enforced by `cast`). Because the attribute
/// is only 32 bits wide, sizes that do not fit in a non-negative int32_t are
/// rejected by an assertion in debug builds instead of being silently
/// truncated into a wrong (possibly negative) attribute value.
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
  auto sizeInBytes = getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()));
  auto sizeInBytesI32 = static_cast<int32_t>(sizeInBytes);
  // The round-trip comparison detects truncation of wider values; the sign
  // check catches unsigned sizes >= 2^31 that wrap to a negative int32_t yet
  // still convert back to the same unsigned value.
  assert(static_cast<decltype(sizeInBytes)>(sizeInBytesI32) == sizeInBytes && sizeInBytesI32 >= 0
         && "Tensor size in bytes does not fit in a non-negative int32 attribute");
  return builder.getI32IntegerAttr(sizeInBytesI32);
}
|
|
|
|
/// Returns the user of `value` that comes first in program order within its
/// block.
///
/// Preconditions:
///  - `value` has at least one user (asserted);
///  - all users share one parent block, since the ordering relies on
///    Operation::isBeforeInBlock, which compares positions inside one block.
/// An operation that uses `value` more than once is simply visited once per
/// use; that does not change the result.
Operation* getEarliestUserWithinBlock(mlir::Value value) {
  auto userRange = value.getUsers();
  assert(!userRange.empty());

  Operation* earliest = *userRange.begin();
  for (Operation* candidate : userRange) {
    if (!candidate->isBeforeInBlock(earliest))
      continue;
    earliest = candidate;
  }
  return earliest;
}
|
|
|
|
/// Returns the operands of `operation` ordered by ascending number of uses.
///
/// Uses are counted over every use of each operand value (its full use list),
/// not only uses inside `operation`'s block. An operand that appears several
/// times in `operation` is returned once per occurrence.
///
/// A stable sort is used so that operands with equal use counts keep their
/// original operand order. llvm::sort makes no guarantee about the relative
/// order of equal elements (and deliberately shuffles its input in
/// EXPENSIVE_CHECKS builds), which would make callers that pick the first
/// suitable operand nondeterministic.
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
  auto operandsAndUses =
      map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
        return {operand, static_cast<size_t>(std::distance(operand.use_begin(), operand.use_end()))};
      });
  llvm::stable_sort(operandsAndUses, [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });
  return map_to_vector(operandsAndUses, [](const auto& operandAndUses) { return operandAndUses.first; });
}
|
|
|
|
/// Returns true if `value` may still be used after `operation`.
///
/// A user counts as "later" when it sits in `operation`'s block and is
/// ordered after `operation`. Users located in any other block are
/// conservatively treated as later too, because their position relative to
/// `operation` is not examined. Returns false when `value` has no users, or
/// when every user is in `operation`'s block and none is ordered after it.
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
  auto* anchorBlock = operation->getBlock();
  return llvm::any_of(value.getUsers(), [&](Operation* user) {
    return user->getBlock() != anchorBlock || operation->isBeforeInBlock(user);
  });
}
|
|
|
|
/// Picks a value to serve as the output tensor for the single result of
/// `operation`: either one of its operands, or a newly created `tensor.empty`.
///
/// An operand is eligible when its type is identical to the result type and
/// hasLaterUserInBlock reports no user after `operation`. Candidates are
/// examined in the order produced by getOpOperandsSortedByUses (fewest uses
/// first), and the first eligible one is returned.
///
/// If no operand is eligible, the rewriter's insertion point is moved to just
/// before `operation` (and is NOT restored afterwards), and a `tensor.empty`
/// with the result's static shape and element type is created there.
///
/// Preconditions (asserted): `operation` has exactly one result, its type is
/// a ShapedType, and — when an allocation is needed — that type has a static
/// shape, because the empty tensor is built from the static shape alone.
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
  assert(operation->getNumResults() == 1 && "Only support operations with a single result");
  mlir::Value result = operation->getResult(0);
  Type resultType = result.getType();
  assert(isa<ShapedType>(resultType) && "Only support result ShapedType as result type");

  // Try to reuse an operand that has the exact result type and is dead after
  // `operation`.
  SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
  auto bestOperand = llvm::find_if(operands, [operation, resultType](mlir::Value operand) {
    return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
  });
  if (bestOperand != operands.end())
    return *bestOperand;

  // Fall back to allocating a fresh tensor. Only the static shape is passed to
  // tensor.empty (no dynamic size operands), so dynamic or unranked result
  // types cannot be materialized here.
  auto resultShapedType = cast<ShapedType>(resultType);
  assert(resultShapedType.hasStaticShape() && "Allocating an output tensor requires a static result shape");
  rewriter.setInsertionPoint(operation);
  return tensor::EmptyOp::create(
      rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
|
|
|
|
} // namespace onnx_mlir
|