#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"

#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

/// Folds a `SpatCompute` whose body does nothing but forward values.
///
/// When the body's single block contains only its terminator, each result of
/// the op is replaced by the value the terminator forwards for it:
///   - a block argument of the body is mapped back to the op operand with the
///     same index;
///   - any other value (i.e. one defined outside the body block) is used
///     directly.
///
/// @param adaptor  Operand view supplied by the folding framework (unused).
/// @param results  Receives exactly one OpFoldResult per op result on
///                 success; left untouched on failure.
/// @return success() if every result was folded, failure() otherwise.
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
  // A fold that returns success() with no results means "modified in place"
  // to MLIR's folding drivers. For an op without results, that would report a
  // change that never happened and keep the op on the rewrite worklist.
  if (getNumResults() == 0)
    return failure();

  Region& body = getBody();
  if (body.empty())
    return failure();
  Block& block = body.front();

  // Only a body consisting solely of its terminator is a pure pass-through.
  if (!llvm::hasSingleElement(block))
    return failure();
  Operation& terminator = block.front();
  // NOTE(review): assumes the only terminator a SpatCompute body can hold is
  // the Spatial yield op -- confirm against SpatialOps.td.
  if (!terminator.hasTrait<OpTrait::IsTerminator>())
    return failure();
  // The folder requires one fold result per op result.
  if (terminator.getNumOperands() != getNumResults())
    return failure();

  // Build into a local buffer so `results` is untouched if we bail out midway.
  llvm::SmallVector<OpFoldResult> folded;
  folded.reserve(terminator.getNumOperands());
  for (Value yieldedValue : terminator.getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(yieldedValue);
    if (!blockArg || blockArg.getOwner() != &block) {
      // Not an argument of this body: the value is defined outside the
      // block and can replace the result as-is.
      folded.push_back(yieldedValue);
      continue;
    }
    // NOTE(review): assumes body argument i corresponds to operand i of the
    // op -- confirm against the op's definition.
    unsigned argNumber = blockArg.getArgNumber();
    if (argNumber >= getNumOperands())
      return failure();
    folded.push_back(getOperand(argNumber));
  }

  results.append(folded.begin(), folded.end());
  return success();
}

} // namespace spatial
} // namespace onnx_mlir