diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index 0dbf870..783eea5 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -33,7 +33,7 @@ struct DenseWeightView { }; FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) { - SmallVector subviews; + SmallVector viewOps; mlir::Value current = weight; memref::GetGlobalOp getGlobalOp; @@ -46,7 +46,7 @@ FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value if (auto subview = dyn_cast(defOp)) { if (!hasAllStaticSubviewParts(subview)) return failure(); - subviews.push_back(subview); + viewOps.push_back(subview); current = subview.getSource(); continue; } @@ -54,6 +54,24 @@ FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value current = cast.getSource(); continue; } + if (auto collapse = dyn_cast(defOp)) { + auto srcType = dyn_cast(collapse.getSrc().getType()); + auto resultType = dyn_cast(collapse.getResult().getType()); + if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + viewOps.push_back(collapse); + current = collapse.getSrc(); + continue; + } + if (auto expand = dyn_cast(defOp)) { + auto srcType = dyn_cast(expand.getSrc().getType()); + auto resultType = dyn_cast(expand.getResult().getType()); + if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + viewOps.push_back(expand); + current = expand.getSrc(); + continue; + } return failure(); } @@ -70,16 +88,39 @@ FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); view.strides = computeRowMajorStrides(view.shape); - for (memref::SubViewOp subview : llvm::reverse(subviews)) { - SmallVector nextStrides; - nextStrides.reserve(subview.getStaticStrides().size()); - for (auto [offset, stride, sourceStride] : - llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) { - view.offset += offset * sourceStride; - nextStrides.push_back(stride * sourceStride); + for (Operation* viewOp : llvm::reverse(viewOps)) { + if (auto subview = dyn_cast(viewOp)) { + SmallVector nextStrides; + nextStrides.reserve(subview.getStaticStrides().size()); + for (auto [offset, stride, sourceStride] : + llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) { + view.offset += offset * sourceStride; + nextStrides.push_back(stride * sourceStride); + } + view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); + view.strides = std::move(nextStrides); + continue; + } + + // Collapse/expand are accepted only as contiguous static reshapes of a + // dense global view, so a row-major stride recomputation preserves layout. + if (auto collapse = dyn_cast(viewOp)) { + if (view.strides != computeRowMajorStrides(view.shape)) + return failure(); + auto resultType = cast(collapse.getResult().getType()); + view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); + view.strides = computeRowMajorStrides(view.shape); + continue; + } + + if (auto expand = dyn_cast(viewOp)) { + if (view.strides != computeRowMajorStrides(view.shape)) + return failure(); + auto resultType = cast(expand.getResult().getType()); + view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); + view.strides = computeRowMajorStrides(view.shape); + continue; } - view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); - view.strides = std::move(nextStrides); } return view; diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index d5b3b43..46a9eed 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -7,8 +7,8 @@ #include "llvm/ADT/STLExtras.h" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" -#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -97,6 +97,22 @@ static bool isConstantGlobalView(Value value) { value = cast.getSource(); continue; } + if (auto collapse = dyn_cast(defOp)) { + auto srcType = dyn_cast(collapse.getSrc().getType()); + auto resultType = dyn_cast(collapse.getResult().getType()); + if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape()) + return false; + value = collapse.getSrc(); + continue; + } + if (auto expand = dyn_cast(defOp)) { + auto srcType = dyn_cast(expand.getSrc().getType()); + auto resultType = dyn_cast(expand.getResult().getType()); + if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape()) + return false; + value = expand.getSrc(); + continue; + } return false; } } diff --git a/validation/operations/README.md b/validation/operations/README.md index 3b1d979..2f5b977 100644 --- a/validation/operations/README.md +++ b/validation/operations/README.md @@ -23,6 +23,8 @@ python3 validation/operations/gen_tests.py | Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads | | With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias | | Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input | +| Grouped two groups | `conv/grouped_two_groups` | [1,4,4,4] | [1,4,4,4] | 1x1 | 1 | none | yes | group=2 channel partitioning | +| Depthwise grouped | `conv/depthwise_grouped` | [1,3,4,4] | [1,3,2,2] | 3x3 | 1 | none | no | group=3, one input channel per group | ## Gemm @@ -37,6 +39,13 @@ python3 validation/operations/gen_tests.py | Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices | | transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined | +## MatMul + +| Test | Directory | A input | B weight | Output | Notes | +|------------|---------------------|---------|----------|---------|------------------------------------| +| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path | +| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch MatMul rewrite path | + ## Gemv | Test | Directory | Input | W (weight) | Output | Bias | Notes | @@ -115,6 +124,18 @@ python3 validation/operations/gen_tests.py | Axis 1 | `gather/axis1` | [3,4] | [2] | [3,2] | 1 | Select two columns | | Axis 0 matrix indices| `gather/axis0_matrix_indices` | [4,3] | [2,2] | [2,2,3] | 0 | Gather rows with 2D indices | +## Concat + +| Test | Directory | Input(s) | Output | Axis | Notes | +|--------------|-----------------------|---------------------------|-----------|------|-----------------------------| +| Channel axis | `concat/channel_axis` | A:[1,1,2,2], B:[1,2,2,2] | [1,3,2,2] | 1 | Runtime NCHW channel concat | + +## Reshape + +| Test | Directory | Input | Output | Notes | +|-----------|---------------------|-------|--------|----------------------------------------------| +| Same rank | `reshape/same_rank` | [2,3] | [3,2] | Runtime tensor with static shape initializer | + ## Add | Test | Directory | Input(s) | Output | Notes | diff --git a/validation/operations/concat/channel_axis/concat_channel_axis.onnx b/validation/operations/concat/channel_axis/concat_channel_axis.onnx new file mode 100644 index 0000000..134da55 Binary files /dev/null and b/validation/operations/concat/channel_axis/concat_channel_axis.onnx differ diff --git a/validation/operations/conv/depthwise_grouped/conv_depthwise_grouped.onnx b/validation/operations/conv/depthwise_grouped/conv_depthwise_grouped.onnx new file mode 100644 index 0000000..84da811 Binary files /dev/null and b/validation/operations/conv/depthwise_grouped/conv_depthwise_grouped.onnx differ diff --git a/validation/operations/conv/grouped_two_groups/conv_grouped_two_groups.onnx b/validation/operations/conv/grouped_two_groups/conv_grouped_two_groups.onnx new file mode 100644 index 0000000..162214e Binary files /dev/null and b/validation/operations/conv/grouped_two_groups/conv_grouped_two_groups.onnx differ diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index 7e0309c..53e5821 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -154,6 +154,33 @@ def conv_large_spatial(): save_model(model, "conv/large_spatial", "conv_large_spatial.onnx") +def conv_grouped_two_groups(): + """Grouped Conv with two groups, pointwise kernels, and bias.""" + rng = np.random.default_rng(59) + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4]) + W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 2, 1, 1)).astype(np.float32), name="W") + B = numpy_helper.from_array(rng.uniform(-1, 1, (4,)).astype(np.float32), name="B") + node = helper.make_node("Conv", ["X", "W", "B"], ["Y"], + kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0], group=2) + graph = helper.make_graph([node], "conv_grouped_two_groups", [X], [Y], initializer=[W, B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/grouped_two_groups", "conv_grouped_two_groups.onnx") + + +def conv_depthwise_grouped(): + """Depthwise-style grouped Conv with one input channel per group.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2]) + W = numpy_helper.from_array( + np.random.default_rng(60).uniform(-1, 1, (3, 1, 3, 3)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0], group=3) + graph = helper.make_graph([node], "conv_depthwise_grouped", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/depthwise_grouped", "conv_depthwise_grouped.onnx") + + # --------------------------------------------------------------------------- # GEMM tests # --------------------------------------------------------------------------- @@ -252,6 +279,33 @@ def gemm_transB_with_bias(): save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx") +# --------------------------------------------------------------------------- +# MatMul tests +# --------------------------------------------------------------------------- + +def matmul_basic(): + """Direct 2D MatMul with constant RHS.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + B = numpy_helper.from_array(np.random.default_rng(49).uniform(-1, 1, (3, 4)).astype(np.float32), name="B") + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_basic", [A], [Y], initializer=[B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/basic", "matmul_basic.onnx") + + +def matmul_batched_3d(): + """Batched 3D MatMul with matching batch dimensions.""" + rng = np.random.default_rng(50) + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4]) + B = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 4)).astype(np.float32), name="B") + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_batched_3d", [A], [Y], initializer=[B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/batched_3d", "matmul_batched_3d.onnx") + + # --------------------------------------------------------------------------- # Pooling tests # --------------------------------------------------------------------------- @@ -607,6 +661,36 @@ def gather_axis0_matrix_indices(): save_model(model, "gather/axis0_matrix_indices", "gather_axis0_matrix_indices.onnx") +# --------------------------------------------------------------------------- +# Concat tests +# --------------------------------------------------------------------------- + +def concat_channel_axis(): + """Concat two runtime NCHW tensors along the channel axis.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1, 2, 2]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 2, 2, 2]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2]) + node = helper.make_node("Concat", ["A", "B"], ["Y"], axis=1) + graph = helper.make_graph([node], "concat_channel_axis", [A, B], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "concat/channel_axis", "concat_channel_axis.onnx") + + +# --------------------------------------------------------------------------- +# Reshape tests +# --------------------------------------------------------------------------- + +def reshape_same_rank(): + """Runtime tensor Reshape with a static shape initializer and unchanged rank.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2]) + shape = make_int64_initializer("shape", [3, 2]) + node = helper.make_node("Reshape", ["X", "shape"], ["Y"]) + graph = helper.make_graph([node], "reshape_same_rank", [X], [Y], initializer=[shape]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "reshape/same_rank", "reshape_same_rank.onnx") + + # --------------------------------------------------------------------------- # Add tests # --------------------------------------------------------------------------- @@ -757,6 +841,12 @@ if __name__ == "__main__": conv_with_bias_3x3() conv_batch_2() conv_large_spatial() + conv_grouped_two_groups() + conv_depthwise_grouped() + + print("\nGenerating MatMul tests:") + matmul_basic() + matmul_batched_3d() print("\nGenerating Pooling tests:") maxpool_basic() @@ -802,6 +892,12 @@ if __name__ == "__main__": gather_axis1() gather_axis0_matrix_indices() + print("\nGenerating Concat tests:") + concat_channel_axis() + + print("\nGenerating Reshape tests:") + reshape_same_rank() + print("\nGenerating Add tests:") add_basic() add_broadcast_row() diff --git a/validation/operations/matmul/basic/matmul_basic.onnx b/validation/operations/matmul/basic/matmul_basic.onnx new file mode 100644 index 0000000..4d0b920 Binary files /dev/null and b/validation/operations/matmul/basic/matmul_basic.onnx differ diff --git a/validation/operations/matmul/batched_3d/matmul_batched_3d.onnx b/validation/operations/matmul/batched_3d/matmul_batched_3d.onnx new file mode 100644 index 0000000..d58cd3b Binary files /dev/null and b/validation/operations/matmul/batched_3d/matmul_batched_3d.onnx differ diff --git a/validation/operations/reshape/same_rank/reshape_same_rank.onnx b/validation/operations/reshape/same_rank/reshape_same_rank.onnx new file mode 100644 index 0000000..7942128 Binary files /dev/null and b/validation/operations/reshape/same_rank/reshape_same_rank.onnx differ