new ops tests for matmul, grouped conv, concat and reshape
Validate Operations / validate-operations (push) Has been cancelled

related fixes
This commit is contained in:
NiccoloN
2026-05-14 15:54:06 +02:00
parent d09e76c8f9
commit fe244d5aa1
10 changed files with 186 additions and 12 deletions
+52 -11
View File
@@ -33,7 +33,7 @@ struct DenseWeightView {
};
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
SmallVector<Operation*> viewOps;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
@@ -46,7 +46,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
viewOps.push_back(subview);
current = subview.getSource();
continue;
}
@@ -54,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
current = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(collapse);
current = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(expand);
current = expand.getSrc();
continue;
}
return failure();
}
@@ -70,16 +88,39 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
for (Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
continue;
}
// Collapse/expand are accepted only as contiguous static reshapes of a
// dense global view, so a row-major stride recomputation preserves layout.
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(collapse.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(expand.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
+17 -1
View File
@@ -7,8 +7,8 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -97,6 +97,22 @@ static bool isConstantGlobalView(Value value) {
value = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return false;
value = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return false;
value = expand.getSrc();
continue;
}
return false;
}
}