#include "mlir/IR/BuiltinTypeInterfaces.h"

#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Returns true when no entry of `values` is the dynamic sentinel recognised
/// by `ShapedType::isDynamic`.
bool allStatic(ArrayRef<int64_t> values) {
  return llvm::none_of(values, [](int64_t v) { return ShapedType::isDynamic(v); });
}

/// Appends the constant integer value of every entry of `ofrs` to `out`.
/// Returns failure as soon as one entry is not a compile-time constant; in
/// that case `out` may already hold the values converted before it.
template <typename VectorT>
LogicalResult appendConstantInts(ArrayRef<OpFoldResult> ofrs, VectorT& out) {
  out.reserve(out.size() + ofrs.size());
  for (OpFoldResult ofr : ofrs) {
    auto constant = getConstantIntValue(ofr);
    if (!constant)
      return failure();
    out.push_back(*constant);
  }
  return success();
}

} // namespace

/// Follows a chain of `memref.cast` ops back to the first value that is not
/// produced by a `memref.cast`, and returns it.
Value stripMemRefCasts(Value value) {
  while (auto castOp = value.getDefiningOp<memref::CastOp>())
    value = castOp.getSource();
  return value;
}

/// Like `stripMemRefCasts`, but additionally looks through
/// `memref.collapse_shape` and `memref.expand_shape`, in any interleaving.
/// Returns the first value not produced by one of those three ops.
Value stripMemRefViewOps(Value value) {
  while (Operation* def = value.getDefiningOp()) {
    if (auto castOp = dyn_cast<memref::CastOp>(def))
      value = castOp.getSource();
    else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(def))
      value = collapseOp.getSrc();
    else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(def))
      value = expandOp.getSrc();
    else
      break;
  }
  return value;
}

/// Returns true when every static offset, size and stride recorded on
/// `subview` is a known constant (none is the dynamic sentinel).
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
  return allStatic(subview.getStaticOffsets()) && allStatic(subview.getStaticSizes()) &&
         allStatic(subview.getStaticStrides());
}

/// Describes the `memref.subview` that produces `value`, after looking through
/// casts and collapse/expand reshapes on `value`.
///
/// Fails when:
///  - no subview is found;
///  - the subview's source, after stripping casts, is not a statically shaped
///    MemRefType;
///  - the subview result type is not statically shaped;
///  - any size or stride is not a constant.
///
/// Offsets are copied as-is (`OpFoldResult`) and may still be dynamic; use
/// `getStaticSubviewOffsets` to require constant offsets.
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  auto subviewOp = stripMemRefViewOps(value).getDefiningOp<memref::SubViewOp>();
  if (!subviewOp)
    return failure();

  // Only casts are stripped on the source side. The subview's offsets, sizes
  // and strides are expressed in the coordinate system of its direct source,
  // and casts do not change the shape the way collapse/expand do.
  Value source = stripMemRefCasts(subviewOp.getSource());
  auto sourceType = dyn_cast<MemRefType>(source.getType());
  auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
  if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
    return failure();

  StaticSubviewInfo info;
  info.source = source;
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  info.sourceShape.assign(sourceShape.begin(), sourceShape.end());

  SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
  info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());

  if (failed(appendConstantInts(subviewOp.getMixedSizes(), info.sizes)) ||
      failed(appendConstantInts(subviewOp.getMixedStrides(), info.strides)))
    return failure();

  return info;
}

/// Converts the (possibly dynamic) offsets in `info` to constants. Fails if
/// any offset is not a compile-time constant.
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
  SmallVector<int64_t> staticOffsets;
  if (failed(appendConstantInts(info.offsets, staticOffsets)))
    return failure();
  return staticOffsets;
}

} // namespace onnx_mlir