Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
@@ -171,6 +172,16 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
@@ -226,6 +237,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
|
||||
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user