42 lines
1.4 KiB
C++
42 lines
1.4 KiB
C++
#pragma once
|
|
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
struct StaticSubviewInfo {
|
|
mlir::Value source;
|
|
llvm::SmallVector<int64_t> sourceShape;
|
|
llvm::SmallVector<int64_t> offsets;
|
|
llvm::SmallVector<int64_t> sizes;
|
|
llvm::SmallVector<int64_t> strides;
|
|
};
|
|
|
|
/// Maps an SSA value to another mlir::Value.
///
/// NOTE(review): going only by the name, this presumably walks back through
/// memref cast-like ops that define `value` and returns the underlying
/// operand. The definition is not visible here, so confirm which ops are
/// peeled and what is returned when `value` is not produced by a cast.
mlir::Value stripMemRefCasts(mlir::Value value);
|
|
|
|
/// Maps an SSA value to another mlir::Value.
///
/// NOTE(review): going only by the name, this presumably walks back through
/// memref view-like ops (e.g. subview / reinterpret-style ops) that define
/// `value` and returns the base memref. Confirm the exact op set and the
/// result for values with no such producer against the definition.
mlir::Value stripMemRefViewOps(mlir::Value value);
|
|
|
|
/// Produces a memref::GlobalOp for a constant payload.
///
/// @param moduleOp   Module the global is presumably inserted into (confirm).
/// @param loc        Location presumably attached to the created op.
/// @param globalType MemRef type for the global.
/// @param denseAttr  Constant data; presumably used as the global's initial
///                   value.
/// @param nameStem   Stem for the global's symbol name. NOTE(review): whether
///                   a uniquing suffix is appended is not visible here.
/// @param alignment  Optional alignment attribute. It defaults to a null
///                   IntegerAttr (`{}`), presumably meaning "no explicit
///                   alignment".
/// @return The resulting global op. NOTE(review): confirm whether an existing
///         identical global may be reused instead of creating a new one.
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
                                          mlir::Location loc,
                                          mlir::MemRefType globalType,
                                          mlir::DenseElementsAttr denseAttr,
                                          llvm::StringRef nameStem,
                                          mlir::IntegerAttr alignment = {});
|
|
|
|
/// Attempts to recover a DenseElementsAttr associated with `value`, using
/// `moduleOp` as context. Returns failure() when it cannot.
///
/// NOTE(review): presumably this resolves `value` to a memref global declared
/// in `moduleOp` and returns that global's dense initial value. The exact
/// conditions that yield failure (non-constant global, no initializer,
/// intervening view ops, ...) are defined elsewhere; confirm.
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
|
|
|
/// Attempts to describe `value` as a StaticSubviewInfo. Returns failure()
/// when it cannot.
///
/// NOTE(review): presumably this succeeds only when `value` is produced by a
/// subview whose offsets, sizes, strides and source shape are all static.
/// Confirm the exact acceptance criteria against the definition.
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
|
|
|
/// Computes an int64_t offset for a chunk of the view described by `info`.
///
/// @param info             Static view description; taken by const reference
///                         and not modified.
/// @param outerIndices     Index values selecting the chunk. NOTE(review):
///                         presumably one per leading (outer) dimension of the
///                         view; confirm the expected length.
/// @param elementByteWidth Size of one element in bytes; presumably used to
///                         scale an element offset into bytes.
/// @return Offset in bytes, per the function name. NOTE(review): confirm
///         whether it is relative to the start of `info.source`.
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
                                   llvm::ArrayRef<int64_t> outerIndices,
                                   int64_t elementByteWidth);
|
|
|
|
} // namespace onnx_mlir
|