centralize logic for materializing contiguous memory into bufferization

fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
This commit is contained in:
NiccoloN
2026-05-30 15:54:24 +02:00
parent 2d5b03c08f
commit ff36729140
29 changed files with 642 additions and 822 deletions
+22 -15
View File
@@ -616,31 +616,38 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size());
bool allStatic = true;
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
bool hasOnlyStaticOffsets = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
for (mlir::OpFoldResult size : subviewOp.getMixedSizes())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
hasOnlyStaticOffsets = false;
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
if (!attr)
return mlir::failure();
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
if (!attr)
return mlir::failure();
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
if (allStatic) {
if (!isContiguousSubviewWithDynamicOffsets(
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
return mlir::failure();
}
if (hasOnlyStaticOffsets) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
-2
View File
@@ -20,8 +20,6 @@ llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds,
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (mlir::isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
+52
View File
@@ -111,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
return true;
}
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides) {
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|| sourceShape.size() != staticStrides.size()) {
return false;
}
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
return false;
auto reversedTriples =
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
auto [_sourceDim, offset, _size] = triple;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
return true;
});
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
if (size > sourceDim - staticOffset)
return false;
}
++firstNonZeroOrDynamicOffset;
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
if (std::get<2>(*it) != 1)
return false;
}
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
auto [sourceDim, size] = pair;
return size != sourceDim;
});
if (firstDifferentSize != reversedSizes.end()) {
++firstDifferentSize;
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
if (std::get<1>(*it) != 1)
return false;
}
return true;
}
} // namespace onnx_mlir
+6
View File
@@ -1,6 +1,7 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
@@ -30,4 +31,9 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides);
} // namespace onnx_mlir
+22
View File
@@ -31,6 +31,19 @@ Value stripMemRefViewOps(Value value) {
}
}
Value stripMemRefAddressingOps(Value value) {
while (true) {
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
value = subviewOp.getSource();
continue;
}
Value strippedValue = stripMemRefViewOps(value);
if (strippedValue == value)
return value;
value = strippedValue;
}
}
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
@@ -81,4 +94,13 @@ FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo&
return staticOffsets;
}
bool isMemRefBaseAddressableValue(Value value) {
value = stripMemRefAddressingOps(value);
if (isa<BlockArgument>(value))
return true;
Operation* defOp = value.getDefiningOp();
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
}
} // namespace onnx_mlir
+4
View File
@@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
mlir::Value stripMemRefAddressingOps(mlir::Value value);
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
@@ -27,4 +29,6 @@ llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
bool isMemRefBaseAddressableValue(mlir::Value value);
} // namespace onnx_mlir
+1 -23
View File
@@ -47,28 +47,6 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
}
mlir::Value stripWeightViewOps(mlir::Value value) {
while (true) {
if (auto subviewOp = value.getDefiningOp<mlir::memref::SubViewOp>()) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = value.getDefiningOp<mlir::memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<mlir::memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<mlir::memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
@@ -159,7 +137,7 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
}
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
weight = stripWeightViewOps(weight);
weight = stripMemRefAddressingOps(weight);
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)