48 lines
1.9 KiB
C++
48 lines
1.9 KiB
C++
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Translates an operand number on a compute-like op into an index into that
/// op's `inputs` list.
///
/// Only `spatial::SpatCompute` and `spatial::SpatComputeBatch` owners are
/// handled. The last `getInputs().size()` operands of the owner are treated as
/// its inputs.
///
/// @param owner          Operation whose operand is being queried.
/// @param operandNumber  Position of the operand in `owner`'s operand list.
/// @return The zero-based position inside the inputs list, or `std::nullopt`
///         when `owner` is not one of the handled ops, has no inputs, or the
///         operand lies before the first input operand.
std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) {
  // Work out how many inputs the owner carries; any other op kind has no answer.
  unsigned numInputs = 0;
  if (auto computeOp = dyn_cast<spatial::SpatCompute>(owner))
    numInputs = computeOp.getInputs().size();
  else if (auto batchOp = dyn_cast<spatial::SpatComputeBatch>(owner))
    numInputs = batchOp.getInputs().size();
  else
    return std::nullopt;

  if (numInputs == 0)
    return std::nullopt;

  // Inputs occupy the tail of the operand list, so the first one sits at
  // (total operands - number of inputs).
  const unsigned firstInputOperand = owner->getNumOperands() - numInputs;
  if (operandNumber >= firstInputOperand)
    return operandNumber - firstInputOperand;

  return std::nullopt;
}
|
|
|
|
/// Replaces every use of one input's body block argument with `replacement`,
/// then removes that input from the compute-like op: the entry in the op's
/// `inputs` operand list and the matching block argument are both erased.
///
/// @param rewriter     Pattern rewriter through which the IR is mutated.
/// @param owner        A `spatial::SpatCompute` or `spatial::SpatComputeBatch`
///                     operation; any other op trips the `cast` assertion.
/// @param inputIndex   Index into the owner's `inputs` list of the input to drop.
/// @param replacement  Value substituted for the input's block argument.
void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
                                           Operation* owner,
                                           unsigned inputIndex,
                                           Value replacement) {
  // Resolve the concrete op kind once; a null `compute` means the batch variant.
  auto compute = dyn_cast<spatial::SpatCompute>(owner);

  Block& body = owner->getRegion(0).front();
  BlockArgument bodyArgument = compute
                                 ? compute.getInputArgument(inputIndex)
                                 : cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
  // Capture the argument number before the argument is erased below.
  const unsigned bodyArgIndex = bodyArgument.getArgNumber();

  // Redirect the uses through the rewriter so that every user operation inside
  // the body is reported as modified to the pattern driver. Calling
  // Value::replaceAllUsesWith directly would change those ops silently.
  rewriter.replaceAllUsesWith(bodyArgument, replacement);

  // Dropping the operand and the block argument changes `owner` itself, so the
  // two erasures are bracketed as a single in-place modification of `owner`.
  rewriter.modifyOpInPlace(owner, [&]() {
    if (compute)
      compute.getInputsMutable().erase(inputIndex);
    else
      cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
    body.eraseArgument(bodyArgIndex);
  });
}
|
|
|
|
} // namespace onnx_mlir
|