centralize logic for materializing contiguous memory into bufferization
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
This commit is contained in:
@@ -616,31 +616,38 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
llvm::SmallVector<int64_t> staticSizes;
|
||||
staticSizes.reserve(subviewOp.getMixedSizes().size());
|
||||
llvm::SmallVector<int64_t> staticStrides;
|
||||
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
||||
bool allStatic = true;
|
||||
llvm::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
bool hasOnlyStaticOffsets = true;
|
||||
|
||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
|
||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
|
||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
hasOnlyStaticOffsets = false;
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
|
||||
if (!attr)
|
||||
return mlir::failure();
|
||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
}
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
|
||||
if (!attr)
|
||||
return mlir::failure();
|
||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
}
|
||||
|
||||
if (allStatic) {
|
||||
if (!isContiguousSubviewWithDynamicOffsets(
|
||||
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (hasOnlyStaticOffsets) {
|
||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||
return mlir::failure();
|
||||
|
||||
|
||||
@@ -20,8 +20,6 @@ llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds,
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (mlir::isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
|
||||
@@ -111,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides) {
|
||||
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|
||||
|| sourceShape.size() != staticStrides.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto reversedTriples =
|
||||
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
|
||||
|
||||
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
|
||||
auto [_sourceDim, offset, _size] = triple;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
|
||||
return true;
|
||||
});
|
||||
|
||||
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
|
||||
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
|
||||
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
|
||||
if (size > sourceDim - staticOffset)
|
||||
return false;
|
||||
}
|
||||
|
||||
++firstNonZeroOrDynamicOffset;
|
||||
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
|
||||
if (std::get<2>(*it) != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
|
||||
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
|
||||
auto [sourceDim, size] = pair;
|
||||
return size != sourceDim;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != reversedSizes.end()) {
|
||||
++firstDifferentSize;
|
||||
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
|
||||
if (std::get<1>(*it) != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
@@ -30,4 +31,9 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -31,6 +31,19 @@ Value stripMemRefViewOps(Value value) {
|
||||
}
|
||||
}
|
||||
|
||||
Value stripMemRefAddressingOps(Value value) {
|
||||
while (true) {
|
||||
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
Value strippedValue = stripMemRefViewOps(value);
|
||||
if (strippedValue == value)
|
||||
return value;
|
||||
value = strippedValue;
|
||||
}
|
||||
}
|
||||
|
||||
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
@@ -81,4 +94,13 @@ FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo&
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
bool isMemRefBaseAddressableValue(Value value) {
|
||||
value = stripMemRefAddressingOps(value);
|
||||
if (isa<BlockArgument>(value))
|
||||
return true;
|
||||
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefAddressingOps(mlir::Value value);
|
||||
|
||||
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
@@ -27,4 +29,6 @@ llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
bool isMemRefBaseAddressableValue(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -47,28 +47,6 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
||||
}
|
||||
|
||||
mlir::Value stripWeightViewOps(mlir::Value value) {
|
||||
while (true) {
|
||||
if (auto subviewOp = value.getDefiningOp<mlir::memref::SubViewOp>()) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = value.getDefiningOp<mlir::memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<mlir::memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<mlir::memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
@@ -159,7 +137,7 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
|
||||
weight = stripWeightViewOps(weight);
|
||||
weight = stripMemRefAddressingOps(weight);
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
|
||||
Reference in New Issue
Block a user