This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
#include "mlir/IR/Block.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
Block& block = getBody().front();
|
||||
if (!llvm::hasSingleElement(block))
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
|
||||
if (!yieldOp)
|
||||
return failure();
|
||||
|
||||
for (Value yieldedValue : yieldOp.getOperands()) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||
if (blockArg.getOwner() == &block) {
|
||||
results.push_back(getOperand(blockArg.getArgNumber()));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push_back(yieldedValue);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user