// NOTE(review): in this copy of the file the template-argument lists (`<...>`)
// appear to have been stripped throughout (e.g. `dyn_cast(type)`,
// `SmallVector values`, `std::optional getCompileTimeSourceImpl`), and the
// standard `#include` had lost its header name. The stripped spellings are
// reproduced verbatim below instead of being guessed, because several of the
// missing arguments (the `isa` source list, the concat / extract-rows op types,
// the optional's payload type) cannot be recovered from this file alone.
// Restore them from version control. Newly written code spells its template
// arguments out.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

#include <optional> // NOTE(review): header name was lost in this copy; std::optional/std::nullopt are used below.

#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Returns true when every entry of the slice's static stride list equals 1.
/// A dynamic stride is stored as a sentinel value (see the `ShapedType::isDynamic`
/// checks in extractSliceDenseElements), which never compares equal to 1.
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
  return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}

/// Returns true when every index operand of `extractOp` is recognised as a
/// constant by `matchConstantIndexValue`.
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
  return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
}

/// Returns true when every result of `op` is a shaped type with a fully static shape.
static bool isStaticTensorResult(Operation* op) {
  return llvm::all_of(op->getResultTypes(), [](Type type) {
    auto shapedType = dyn_cast(type);
    return shapedType && shapedType.hasStaticShape();
  });
}

/// Materialises `denseAttr` transposed by `perms`, where result dim `d` is source
/// dim `perms[d]`.
/// Fails when the attribute's type is not the expected tensor type, or when `perms`
/// is not a permutation of [0, rank).
static FailureOr transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef perms) {
  auto tensorType = dyn_cast(denseAttr.getType());
  if (!tensorType)
    return failure();
  int64_t rank = tensorType.getRank();
  if (static_cast(perms.size()) != rank)
    return failure();

  // Validate that `perms` is a permutation and build the transposed shape.
  llvm::SmallBitVector seen(rank);
  SmallVector transposedShape;
  transposedShape.reserve(rank);
  for (int64_t perm : perms) {
    if (perm < 0 || perm >= rank || seen.test(perm))
      return failure();
    seen.set(perm);
    transposedShape.push_back(tensorType.getShape()[perm]);
  }

  auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue());

  SmallVector originalValues(denseAttr.getValues());
  SmallVector transposedValues(originalValues.size());
  SmallVector originalStrides = computeRowMajorStrides(tensorType.getShape());
  SmallVector transposedStrides = computeRowMajorStrides(transposedShape);
  SmallVector originalIndices(rank);
  for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
    // Decompose the row-major linear index into per-dimension source coordinates.
    int64_t remaining = static_cast(linearIndex);
    for (int64_t dim = 0; dim < rank; ++dim) {
      originalIndices[dim] = remaining / originalStrides[dim];
      remaining %= originalStrides[dim];
    }
    // Result coordinate d is source coordinate perms[d].
    int64_t transposedLinearIndex = 0;
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
    transposedValues[transposedLinearIndex] = value;
  }
  return DenseElementsAttr::get(transposedType, transposedValues);
}

/// Re-types `denseAttr` as `resultType`, keeping the row-major element order.
/// Fails when either type is missing or the element counts differ.
static FailureOr reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
  auto sourceType = dyn_cast(denseAttr.getType());
  if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
    return failure();
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue());
  SmallVector values(denseAttr.getValues());
  return DenseElementsAttr::get(resultType, values);
}

/// Folds `extractSliceOp` applied to the constant `denseAttr`.
/// Requirements, otherwise fails:
///  - both the source and the result types are statically shaped;
///  - all offsets and sizes are static and every stride is a static 1;
///  - the slice lies inside the source.
/// Rank-reducing slices, whose result type drops unit dimensions, are supported.
static FailureOr extractSliceDenseElements(DenseElementsAttr denseAttr, tensor::ExtractSliceOp extractSliceOp) {
  auto sourceType = dyn_cast(denseAttr.getType());
  auto resultType = dyn_cast(extractSliceOp.getType());
  if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();

  ArrayRef offsets = extractSliceOp.getStaticOffsets();
  ArrayRef sizes = extractSliceOp.getStaticSizes();
  ArrayRef strides = extractSliceOp.getStaticStrides();
  if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
    return failure();

  // Offsets and sizes are expressed in the *source* rank. The result type may be
  // rank-reduced, so all index arithmetic below is done in source coordinates over
  // the un-reduced slice shape `sizes`, never over the result type's dimensions.
  const int64_t sourceRank = sourceType.getRank();
  if (static_cast<int64_t>(offsets.size()) != sourceRank || static_cast<int64_t>(sizes.size()) != sourceRank)
    return failure();
  int64_t sliceElementCount = 1;
  for (int64_t dim = 0; dim < sourceRank; ++dim) {
    // Guard against reading outside `sourceValues` if an out-of-bounds slice slips through.
    if (offsets[dim] < 0 || offsets[dim] + sizes[dim] > sourceType.getDimSize(dim))
      return failure();
    sliceElementCount *= sizes[dim];
  }
  if (sliceElementCount != resultType.getNumElements())
    return failure();

  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue());

  SmallVector sourceValues(denseAttr.getValues());
  SmallVector sourceStrides = computeRowMajorStrides(sourceType.getShape());
  // Row-major strides of the un-reduced slice shape, one entry per source dimension.
  auto sliceStrides = computeRowMajorStrides(sizes);
  SmallVector resultValues;
  resultValues.reserve(resultType.getNumElements());
  for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
    int64_t remaining = linearIndex;
    int64_t sourceLinearIndex = 0;
    for (int64_t dim = 0; dim < sourceRank; ++dim) {
      const int64_t sliceIndex = remaining / sliceStrides[dim];
      remaining %= sliceStrides[dim];
      sourceLinearIndex += (offsets[dim] + sliceIndex) * sourceStrides[dim];
    }
    resultValues.push_back(sourceValues[sourceLinearIndex]);
  }
  return DenseElementsAttr::get(resultType, resultValues);
}

/// Returns the dense payload when `value` is produced directly by one of the two
/// constant ops checked below, otherwise null.
/// The concrete op types are not visible in this copy; presumably the arith and
/// ONNX constants, so confirm.
/// The second op's value attribute may be absent, hence `dyn_cast_or_null`.
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
  if (auto constantOp = value.getDefiningOp())
    return dyn_cast(constantOp.getValue());
  if (auto constantOp = value.getDefiningOp())
    return dyn_cast_or_null(constantOp.getValueAttr());
  return nullptr;
}

/// Recursively rebuilds a dense constant for `value` by folding it through
/// transposes, collapse/expand reshapes and unit-stride extract_slices that
/// ultimately read a direct dense constant.
/// Returns null when any step cannot be folded, or when an op is reached twice
/// (tracked via `visited`).
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value,
                                                              llvm::SmallPtrSetImpl& visited) {
  auto* definingOp = value.getDefiningOp();
  if (!definingOp || !visited.insert(definingOp).second)
    return nullptr;

  // Rebuild dense attributes through view-only host-foldable chains so later
  // lowering stages can still recognize grouped/sliced constants.
  if (auto denseAttr = getDirectDenseConstantAttr(value))
    return denseAttr;

  // Transpose with a `data` operand and an (array) `perm` attribute; presumably the
  // ONNX Transpose op.
  if (auto transposeOp = dyn_cast(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
    if (!inputAttr)
      return nullptr;
    SmallVector perm;
    if (auto permAttr = transposeOp.getPermAttr()) {
      perm.reserve(permAttr.size());
      for (IntegerAttr attr : permAttr.getAsRange())
        perm.push_back(attr.getInt());
    } else {
      // `perm` is optional in ONNX Transpose and defaults to reversing the dims;
      // dereferencing the null attribute here would crash.
      for (int64_t dim = inputAttr.getType().getRank() - 1; dim >= 0; --dim)
        perm.push_back(dim);
    }
    auto transposedAttr = transposeDenseElements(inputAttr, perm);
    return succeeded(transposedAttr) ? *transposedAttr : nullptr;
  }

  // Transpose with an `input` operand and an explicit permutation list.
  if (auto transposeOp = dyn_cast(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
    if (!inputAttr)
      return nullptr;
    SmallVector perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
    auto transposedAttr = transposeDenseElements(inputAttr, perm);
    return succeeded(transposedAttr) ? *transposedAttr : nullptr;
  }

  if (auto collapseShapeOp = dyn_cast(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast(collapseShapeOp.getType()));
    return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
  }

  if (auto expandShapeOp = dyn_cast(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast(expandShapeOp.getType()));
    return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
  }

  if (auto extractSliceOp = dyn_cast(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
    if (!inputAttr)
      return nullptr;
    auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
    return succeeded(slicedAttr) ? *slicedAttr : nullptr;
  }

  return nullptr;
}

/// Walks backwards from `op` through ops whose result can be derived at compile time:
///  - constant-index extracts;
///  - transposes and reshapes;
///  - unit-stride slices;
///  - splats and row extraction;
///  - concatenations.
/// On success, returns the source op reached together with `chainLength`, the number
/// of such hops taken. For a concatenation, the operand with the longest chain is
/// reported. Returns nullopt as soon as any step is not compile-time computable.
///
/// An op already in `visited` is reported as its own source without recursing
/// again, which lets a concatenation reuse an operand or share a sub-chain.
/// NOTE(review): the accepted source kinds (`isa` list) are not visible in this copy;
/// presumably constant ops. Confirm.
static std::optional getCompileTimeSourceImpl(Operation* op,
                                              llvm::SmallPtrSetImpl& visited,
                                              size_t chainLength = 0) {
  if (!op)
    return std::nullopt;
  if (!visited.insert(op).second)
    return {{op, chainLength}};
  if (isa(op))
    return {{op, chainLength}};

  chainLength += 1;

  // Scalar extract: only foldable when every index is a known constant.
  if (auto extractOp = dyn_cast(op))
    return hasConstantIndices(extractOp)
               ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
               : std::nullopt;

  // Every remaining case produces tensors whose shape must be fully static.
  if (!isStaticTensorResult(op))
    return std::nullopt;

  if (auto transposeOp = dyn_cast(op))
    return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
  if (auto transposeOp = dyn_cast(op))
    return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
  if (auto collapseShapeOp = dyn_cast(op))
    return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
  if (auto expandShapeOp = dyn_cast(op))
    return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
  if (auto extractSliceOp = dyn_cast(op))
    return hasStaticUnitStrides(extractSliceOp)
               ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
               : std::nullopt;
  if (auto splatOp = dyn_cast(op))
    return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
  if (auto extractRowsOp = dyn_cast(op))
    return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);

  // Concatenation: every operand must be compile-time computable; keep the longest chain.
  if (auto concatOp = dyn_cast(op)) {
    std::optional res = {};
    for (auto operandValue : concatOp.getOperands()) {
      auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
      if (!partialRes)
        return std::nullopt;
      if (!res || res->chainLength < partialRes->chainLength)
        res = partialRes;
    }
    return res;
  }

  return std::nullopt;
}

} // namespace

/// Public entry point: see getCompileTimeSourceImpl. Each call uses a fresh visited set.
std::optional getCompileTimeSource(Operation* op) {
  llvm::SmallPtrSet visited;
  return getCompileTimeSourceImpl(op, visited);
}

/// True when `value` is produced by a compile-time computable chain.
/// Block arguments have no defining op; getCompileTimeSource(nullptr) yields nullopt,
/// so they return false.
bool isCompileTimeComputable(Value value) { return getCompileTimeSource(value.getDefiningOp()).has_value(); }

/// True when `op` heads a compile-time computable chain.
bool isCompileTimeOp(Operation* op) { return getCompileTimeSource(op).has_value(); }

/// Returns a dense constant for `value` folded through view-only ops, or null.
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
  llvm::SmallPtrSet visited;
  return getHostConstantDenseElementsAttrImpl(value, visited);
}

} // namespace onnx_mlir