refactor Pim constant folding pass

share contiguous address resolution in PimCommon
group patterns in subdir for each pass with pattern files
This commit is contained in:
NiccoloN
2026-03-23 15:36:58 +01:00
parent 670d6ce94f
commit 11916a2595
32 changed files with 616 additions and 516 deletions

View File

@@ -1,3 +1,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
@@ -236,4 +239,64 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
return true;
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
int64_t byteOffset = 0;
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress{value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = tiedOperand->get();
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|| llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress{value, byteOffset};
return failure();
}
}
} // namespace onnx_mlir

View File

@@ -17,6 +17,11 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
@@ -48,4 +53,6 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
} // namespace onnx_mlir