fix MatMul pattern non-contiguous extract_slices
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
@@ -137,6 +139,35 @@ bool isSpatialMvmVmmWeightUse(OpOperand& use) {
|
||||
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
||||
}
|
||||
|
||||
bool hasOnlySpatialMvmVmmWeightUses(Value value) {
|
||||
SmallPtrSet<Value, 8> visited;
|
||||
auto walkUses = [&](Value currentValue, auto& self) -> bool {
|
||||
if (!visited.insert(currentValue).second)
|
||||
return true;
|
||||
if (currentValue.use_empty())
|
||||
return false;
|
||||
|
||||
return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
|
||||
if (isSpatialMvmVmmWeightUse(use))
|
||||
return true;
|
||||
|
||||
Operation* user = use.getOwner();
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
|
||||
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
|
||||
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
|
||||
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
|
||||
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
||||
|
||||
return false;
|
||||
});
|
||||
};
|
||||
|
||||
return walkUses(value, walkUses);
|
||||
}
|
||||
|
||||
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
|
||||
Reference in New Issue
Block a user