36 lines
902 B
C++
36 lines
902 B
C++
#include "mlir/IR/Block.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace spatial {
|
|
|
|
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
|
Block& block = getBody().front();
|
|
if (!llvm::hasSingleElement(block))
|
|
return failure();
|
|
|
|
auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
|
|
if (!yieldOp)
|
|
return failure();
|
|
|
|
for (Value yieldedValue : yieldOp.getOperands()) {
|
|
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
|
if (blockArg.getOwner() == &block) {
|
|
results.push_back(getOperand(blockArg.getArgNumber()));
|
|
continue;
|
|
}
|
|
}
|
|
results.push_back(yieldedValue);
|
|
}
|
|
return success();
|
|
}
|
|
|
|
} // namespace spatial
|
|
} // namespace onnx_mlir
|