huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
return shapedType && shapedType.hasStaticShape();
|
||||
});
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return isHostFoldableValue(transposeOp.getData());
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return isHostFoldableValue(collapseShapeOp.getSrc());
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||
return isHostFoldableValue(expandShapeOp.getSrc());
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isHostFoldableValue(Value value) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(definingOp, visited);
|
||||
}
|
||||
|
||||
bool isHostFoldableOp(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user