#include #include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; namespace onnx_mlir { std::optional getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) { auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional { if (inputCount == 0) return std::nullopt; unsigned inputBegin = op->getNumOperands() - inputCount; if (operandNumber < inputBegin) return std::nullopt; return operandNumber - inputBegin; }; if (auto compute = dyn_cast(owner)) return getInputIndex(owner, compute.getInputs().size()); if (auto computeBatch = dyn_cast(owner)) return getInputIndex(owner, computeBatch.getInputs().size()); return std::nullopt; } void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, Operation* owner, unsigned inputIndex, Value replacement) { Block& body = owner->getRegion(0).front(); BlockArgument bodyArgument; if (auto compute = dyn_cast(owner)) { auto computeArg = compute.getInputArgument(inputIndex); assert(computeArg && "expected compute input block argument"); bodyArgument = *computeArg; } else { auto batchArg = cast(owner).getInputArgument(inputIndex); assert(batchArg && "expected compute_batch input block argument"); bodyArgument = *batchArg; } unsigned bodyArgIndex = bodyArgument.getArgNumber(); rewriter.startOpModification(owner); bodyArgument.replaceAllUsesWith(replacement); if (auto compute = dyn_cast(owner)) compute.getInputsMutable().erase(inputIndex); else cast(owner).getInputsMutable().erase(inputIndex); body.eraseArgument(bodyArgIndex); rewriter.finalizeOpModification(owner); } } // namespace onnx_mlir