85 lines
2.8 KiB
C++
85 lines
2.8 KiB
C++
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
Value stripMemRefCasts(Value value) {
|
|
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
|
value = castOp.getSource();
|
|
return value;
|
|
}
|
|
|
|
Value stripMemRefViewOps(Value value) {
|
|
while (true) {
|
|
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
|
value = castOp.getSource();
|
|
continue;
|
|
}
|
|
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
|
value = collapseOp.getSrc();
|
|
continue;
|
|
}
|
|
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
|
value = expandOp.getSrc();
|
|
continue;
|
|
}
|
|
return value;
|
|
}
|
|
}
|
|
|
|
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
|
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
|
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
|
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
|
}
|
|
|
|
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
|
value = stripMemRefViewOps(value);
|
|
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
|
if (!subviewOp)
|
|
return failure();
|
|
|
|
auto source = stripMemRefCasts(subviewOp.getSource());
|
|
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
|
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
|
return failure();
|
|
|
|
StaticSubviewInfo info;
|
|
info.source = source;
|
|
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
|
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
|
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
|
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
|
auto staticSize = getConstantIntValue(size);
|
|
if (!staticSize)
|
|
return failure();
|
|
info.sizes.push_back(*staticSize);
|
|
}
|
|
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
|
auto staticStride = getConstantIntValue(stride);
|
|
if (!staticStride)
|
|
return failure();
|
|
info.strides.push_back(*staticStride);
|
|
}
|
|
return info;
|
|
}
|
|
|
|
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
|
SmallVector<int64_t> staticOffsets;
|
|
staticOffsets.reserve(info.offsets.size());
|
|
for (OpFoldResult offset : info.offsets) {
|
|
auto staticOffset = getConstantIntValue(offset);
|
|
if (!staticOffset)
|
|
return failure();
|
|
staticOffsets.push_back(*staticOffset);
|
|
}
|
|
return staticOffsets;
|
|
}
|
|
|
|
} // namespace onnx_mlir
|