From 11916a259572b329e16028b03cd28b68e46f25ab Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 23 Mar 2026 15:36:58 +0100 Subject: [PATCH] refactor Pim constant folding pass share contiguous address resolution in PimCommon group patterns in subdir for each pass with pattern files --- src/PIM/CMakeLists.txt | 5 +- src/PIM/Common/PimCommon.cpp | 63 +++ src/PIM/Common/PimCommon.hpp | 7 + src/PIM/Compiler/PimCodeGen.cpp | 62 +-- .../Conversion/ONNXToSpatial/CMakeLists.txt | 18 +- .../{ONNXToSpatialCommon.cpp => Common.cpp} | 2 +- .../{ONNXToSpatialCommon.hpp => Common.hpp} | 0 .../ONNXToSpatial/ONNXToSpatialPass.cpp | 2 +- ...ONNXToSpatialPatterns.hpp => Patterns.hpp} | 0 .../{ => Patterns}/Math/Conv.cpp | 0 .../{ => Patterns}/Math/Gemm.cpp | 2 +- .../{ => Patterns}/Math/MatMul.cpp | 2 +- .../{ => Patterns}/NN/Pooling.cpp | 2 +- .../{ => Patterns}/NN/ReduceMean.cpp | 2 +- .../Tensor/ONNXConcatToTensorConcat.cpp | 2 +- .../Tensor/ONNXReshapeToTensorReshape.cpp | 2 +- .../Tensor/RemoveUnusedHelperOps.cpp | 2 +- .../Utils/AnnotateReplication.cpp | 2 +- .../ONNXToSpatial/Utils/SpatialReducer.hpp | 2 +- .../Conversion/SpatialToPim/CMakeLists.txt | 2 +- .../{SpatialToPimCommon.cpp => Common.cpp} | 2 +- .../{SpatialToPimCommon.hpp => Common.hpp} | 0 .../SpatialToPim/SpatialToPimPass.cpp | 4 +- .../SpatialToPim/SpatialToPimPatterns.hpp | 12 - src/PIM/Dialect/Spatial/SpatialOps.cpp | 2 +- src/PIM/Pass/PimConstantFolding/Common.cpp | 121 +++++ src/PIM/Pass/PimConstantFolding/Common.hpp | 41 ++ src/PIM/Pass/PimConstantFolding/Patterns.hpp | 11 + .../Patterns/ConstantPatterns.cpp} | 419 +----------------- .../Patterns/SubviewPatterns.cpp | 223 ++++++++++ .../PimConstantFoldingPass.cpp | 53 +++ src/PIM/Pass/PimHostVerificationPass.cpp | 65 ++- 32 files changed, 616 insertions(+), 516 deletions(-) rename src/PIM/Conversion/ONNXToSpatial/{ONNXToSpatialCommon.cpp => Common.cpp} (99%) rename src/PIM/Conversion/ONNXToSpatial/{ONNXToSpatialCommon.hpp => Common.hpp} (100%) rename src/PIM/Conversion/ONNXToSpatial/{ONNXToSpatialPatterns.hpp => Patterns.hpp} (100%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/Math/Conv.cpp (100%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/Math/Gemm.cpp (99%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/Math/MatMul.cpp (98%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/NN/Pooling.cpp (99%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/NN/ReduceMean.cpp (98%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/Tensor/ONNXConcatToTensorConcat.cpp (91%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/Tensor/ONNXReshapeToTensorReshape.cpp (97%) rename src/PIM/Conversion/ONNXToSpatial/{ => Patterns}/Tensor/RemoveUnusedHelperOps.cpp (93%) rename src/PIM/Conversion/SpatialToPim/{SpatialToPimCommon.cpp => Common.cpp} (98%) rename src/PIM/Conversion/SpatialToPim/{SpatialToPimCommon.hpp => Common.hpp} (100%) delete mode 100644 src/PIM/Conversion/SpatialToPim/SpatialToPimPatterns.hpp create mode 100644 src/PIM/Pass/PimConstantFolding/Common.cpp create mode 100644 src/PIM/Pass/PimConstantFolding/Common.hpp create mode 100644 src/PIM/Pass/PimConstantFolding/Patterns.hpp rename src/PIM/Pass/{PimConstantFoldingPass.cpp => PimConstantFolding/Patterns/ConstantPatterns.cpp} (51%) create mode 100644 src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp create mode 100644 src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt index 9a62b7f..1d310d3 100644 --- a/src/PIM/CMakeLists.txt +++ b/src/PIM/CMakeLists.txt @@ -20,7 +20,10 @@ add_onnx_mlir_library(OMPIMAccel Pass/CountInstructionPass.cpp Pass/EmitPimJsonPass.cpp Pass/MessagePass.cpp - Pass/PimConstantFoldingPass.cpp + Pass/PimConstantFolding/Common.cpp + Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp + Pass/PimConstantFolding/PimConstantFoldingPass.cpp + Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp Pass/PimHostVerificationPass.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index f13ccca..6876d38 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -1,3 +1,6 @@ +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + #include "llvm/Support/raw_os_ostream.h" #include @@ -236,4 +239,64 @@ bool isMemoryContiguous(ArrayRef srcShape, return true; } +FailureOr resolveContiguousAddress(Value value) { + int64_t byteOffset = 0; + + while (true) { + if (isa(value)) + return ResolvedContiguousAddress{value, byteOffset}; + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return failure(); + + if (auto dpsDefiningOp = dyn_cast(definingOp)) { + OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast(value)); + if (!tiedOperand) + return failure(); + value = tiedOperand->get(); + continue; + } + + if (auto subviewOp = dyn_cast(definingOp)) { + auto sourceType = dyn_cast(subviewOp.getSource().getType()); + auto subviewType = dyn_cast(subviewOp.getType()); + if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) + return failure(); + + ArrayRef offsets = subviewOp.getStaticOffsets(); + ArrayRef sizes = subviewOp.getStaticSizes(); + ArrayRef strides = subviewOp.getStaticStrides(); + if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic) + || llvm::is_contained(strides, ShapedType::kDynamic)) + return failure(); + if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides)) + return failure(); + + auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); + byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; + value = subviewOp.getSource(); + continue; + } + + if (auto castOp = dyn_cast(definingOp)) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = dyn_cast(definingOp)) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = dyn_cast(definingOp)) { + value = expandOp.getSrc(); + continue; + } + + if (isa(definingOp)) + return ResolvedContiguousAddress{value, byteOffset}; + + return failure(); + } +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index 89a6aa6..0b9fadd 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -17,6 +17,11 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; namespace onnx_mlir { +struct ResolvedContiguousAddress { + mlir::Value base; + int64_t byteOffset = 0; +}; + std::string getOutputDir(); void createDirectory(const std::string& directory); @@ -48,4 +53,6 @@ bool isMemoryContiguous(llvm::ArrayRef srcShape, llvm::ArrayRef sizes, llvm::ArrayRef strides); +llvm::FailureOr resolveContiguousAddress(mlir::Value value); + } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 28babf4..c50f897 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -14,7 +14,7 @@ #include #include "Common/PimCommon.hpp" -#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -86,48 +86,9 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { } size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const { - size_t offset = 0; - while (true) { - auto definingOp = value.getDefiningOp(); - if (!definingOp) - break; - if (auto dpsDefiningOp = dyn_cast(definingOp)) { - OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast(value)); - if (!tiedOperand) - break; - value = tiedOperand->get(); - } - else if (auto subviewDefiningOp = dyn_cast(definingOp)) { - auto source = subviewDefiningOp.getSource(); - auto srcShape = source.getType().getShape(); - auto subviewOffsets = subviewDefiningOp.getStaticOffsets(); - auto subviewSizes = subviewDefiningOp.getStaticSizes(); - auto subviewStrides = subviewDefiningOp.getStaticStrides(); - assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides)); - for (unsigned i = 0; i < subviewOffsets.size(); i++) { - size_t localOffset = subviewOffsets[i]; - for (unsigned j = i + 1; j < subviewSizes.size(); j++) - localOffset *= subviewSizes[j]; - offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8; - } - value = source; - } - else if (auto castOp = dyn_cast(definingOp)) { - value = castOp.getSource(); - } - else if (auto collapseOp = dyn_cast(definingOp)) { - value = collapseOp.getSrc(); - } - else if (auto expandOp = dyn_cast(definingOp)) { - value = expandOp.getSrc(); - } - else - break; - } - - auto iter = memEntriesMap.find(value); - if (iter == memEntriesMap.end()) { - errs() << "Missing mem entry for value: "; + auto resolvedAddress = resolveContiguousAddress(value); + if (failed(resolvedAddress)) { + errs() << "Failed to resolve contiguous address for value: "; value.print(errs()); errs() << "\n"; if (auto* definingOp = value.getDefiningOp()) { @@ -135,10 +96,23 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const { definingOp->print(errs()); errs() << "\n"; } + llvm_unreachable("Failed to resolve contiguous address"); + } + + auto iter = memEntriesMap.find(resolvedAddress->base); + if (iter == memEntriesMap.end()) { + errs() << "Missing mem entry for value: "; + resolvedAddress->base.print(errs()); + errs() << "\n"; + if (auto* definingOp = resolvedAddress->base.getDefiningOp()) { + errs() << "Defining op:\n"; + definingOp->print(errs()); + errs() << "\n"; + } llvm_unreachable("Missing mem entry"); } - return iter->second.address + offset; + return iter->second.address + resolvedAddress->byteOffset; } json::Object PimCodeGen::createEmptyOffset() { diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index ea95ce8..a77cde2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -3,19 +3,19 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(ONNXToSpatialIncGen) add_onnx_mlir_library(OMONNXToSpatial - Math/Gemm.cpp - Math/Conv.cpp - Math/MatMul.cpp - NN/Pooling.cpp - NN/ReduceMean.cpp - Tensor/ONNXConcatToTensorConcat.cpp - Tensor/ONNXReshapeToTensorReshape.cpp - Tensor/RemoveUnusedHelperOps.cpp + Patterns/Math/Gemm.cpp + Patterns/Math/Conv.cpp + Patterns/Math/MatMul.cpp + Patterns/NN/Pooling.cpp + Patterns/NN/ReduceMean.cpp + Patterns/Tensor/ONNXConcatToTensorConcat.cpp + Patterns/Tensor/ONNXReshapeToTensorReshape.cpp + Patterns/Tensor/RemoveUnusedHelperOps.cpp Utils/SpatialReducer.cpp Utils/WeightSubdivider.cpp Utils/AnnotateReplication.cpp ONNXToSpatialPass.cpp - ONNXToSpatialCommon.cpp + Common.cpp DEPENDS ONNXToSpatialIncGen diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp b/src/PIM/Conversion/ONNXToSpatial/Common.cpp similarity index 99% rename from src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp rename to src/PIM/Conversion/ONNXToSpatial/Common.cpp index e633c34..03cc524 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common.cpp @@ -15,7 +15,7 @@ #include #include -#include "ONNXToSpatialCommon.hpp" +#include "Common.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp similarity index 100% rename from src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp rename to src/PIM/Conversion/ONNXToSpatial/Common.hpp diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 02f92c8..b6faf7d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -11,7 +11,7 @@ #include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp similarity index 100% rename from src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns.hpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp similarity index 100% rename from src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp similarity index 99% rename from src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 0ca061c..ab67abf 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -11,7 +11,7 @@ #include #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp similarity index 98% rename from src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index a37cd74..d8dda1d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -4,7 +4,7 @@ #include "llvm/ADT/SmallVector.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pooling.cpp similarity index 99% rename from src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pooling.cpp index cc0b3c5..6b4c104 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pooling.cpp @@ -17,7 +17,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/ReduceMean.cpp similarity index 98% rename from src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/NN/ReduceMean.cpp index e55693d..859fa7b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/ReduceMean.cpp @@ -1,6 +1,6 @@ #include "mlir/Transforms/DialectConversion.h" -#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXConcatToTensorConcat.cpp similarity index 91% rename from src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXConcatToTensorConcat.cpp index d640073..a73c547 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXConcatToTensorConcat.cpp @@ -1,7 +1,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXReshapeToTensorReshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXReshapeToTensorReshape.cpp similarity index 97% rename from src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXReshapeToTensorReshape.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXReshapeToTensorReshape.cpp index 629b0ce..ca023aa 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXReshapeToTensorReshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXReshapeToTensorReshape.cpp @@ -3,7 +3,7 @@ #include "llvm/ADT/SmallVector.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/RemoveUnusedHelperOps.cpp similarity index 93% rename from src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/RemoveUnusedHelperOps.cpp index 6fb9bf7..d609c14 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/RemoveUnusedHelperOps.cpp @@ -1,7 +1,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp index c9092b1..289eda4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp @@ -1,7 +1,7 @@ #include #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp index 6739d48..15b26d7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp @@ -7,7 +7,7 @@ #include #include -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index 33b42d8..171109b 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -4,7 +4,7 @@ add_public_tablegen_target(SpatialToPimIncGen) add_onnx_mlir_library(OMSpatialToPim SpatialToPimPass.cpp - SpatialToPimCommon.cpp + Common.cpp DEPENDS SpatialToPimIncGen diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp similarity index 98% rename from src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.cpp rename to src/PIM/Conversion/SpatialToPim/Common.cpp index 6c9c528..073a8b6 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -5,7 +5,7 @@ #include #include -#include "SpatialToPimCommon.hpp" +#include "Common.hpp" using namespace llvm; using namespace mlir; diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp similarity index 100% rename from src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp rename to src/PIM/Conversion/SpatialToPim/Common.hpp diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 2ffdb7b..f614179 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -19,9 +19,9 @@ #include #include -#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPatterns.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPatterns.hpp deleted file mode 100644 index 365efe9..0000000 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPatterns.hpp +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once -#include "mlir/IR/PatternMatch.h" - -namespace onnx_mlir { - -namespace spatial { - -// TODO: Add here eventual patterns - -} - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 82dae7e..de1c1be 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -24,7 +24,7 @@ #include #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" diff --git a/src/PIM/Pass/PimConstantFolding/Common.cpp b/src/PIM/Pass/PimConstantFolding/Common.cpp new file mode 100644 index 0000000..14df695 --- /dev/null +++ b/src/PIM/Pass/PimConstantFolding/Common.cpp @@ -0,0 +1,121 @@ +#include "Common.hpp" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +Value stripMemRefCasts(Value value) { + while (auto castOp = value.getDefiningOp()) + value = castOp.getSource(); + return value; +} + +Value stripMemRefViewOps(Value value) { + while (true) { + if (auto castOp = value.getDefiningOp()) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = value.getDefiningOp()) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = value.getDefiningOp()) { + value = expandOp.getSrc(); + continue; + } + return value; + } +} + +memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, + Location loc, + MemRefType globalType, + DenseElementsAttr denseAttr, + StringRef nameStem, + IntegerAttr alignment) { + auto globalName = nameStem.str(); + unsigned suffix = 0; + while (moduleOp.lookupSymbol(globalName)) + globalName = (nameStem + "_" + std::to_string(++suffix)).str(); + + auto visibility = StringAttr::get(moduleOp.getContext(), "private"); + OpBuilder moduleBuilder(moduleOp.getBodyRegion()); + moduleBuilder.setInsertionPointToStart(moduleOp.getBody()); + return memref::GlobalOp::create(moduleBuilder, + loc, + globalName, + visibility, + globalType, + denseAttr, + /*constant=*/true, + alignment); +} + +FailureOr getDenseGlobalValue(ModuleOp moduleOp, Value value) { + value = stripMemRefCasts(value); + + auto getGlobalOp = value.getDefiningOp(); + if (!getGlobalOp) + return failure(); + + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue()) + return failure(); + + auto denseAttr = dyn_cast(*globalOp.getInitialValue()); + if (!denseAttr) + return failure(); + return denseAttr; +} + +FailureOr getStaticSubviewInfo(Value value) { + value = stripMemRefViewOps(value); + auto subviewOp = value.getDefiningOp(); + if (!subviewOp) + return failure(); + + auto source = stripMemRefCasts(subviewOp.getSource()); + auto sourceType = dyn_cast(source.getType()); + auto subviewType = dyn_cast(subviewOp.getType()); + if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) + return failure(); + + StaticSubviewInfo info; + info.source = source; + info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end()); + for (OpFoldResult offset : subviewOp.getMixedOffsets()) { + auto staticOffset = getConstantIntValue(offset); + if (!staticOffset) + return failure(); + info.offsets.push_back(*staticOffset); + } + for (OpFoldResult size : subviewOp.getMixedSizes()) { + auto staticSize = getConstantIntValue(size); + if (!staticSize) + return failure(); + info.sizes.push_back(*staticSize); + } + for (OpFoldResult stride : subviewOp.getMixedStrides()) { + auto staticStride = getConstantIntValue(stride); + if (!staticStride) + return failure(); + info.strides.push_back(*staticStride); + } + return info; +} + +int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, + ArrayRef outerIndices, + int64_t elementByteWidth) { + SmallVector sourceIndices; + sourceIndices.reserve(info.sourceShape.size()); + for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim) + sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]); + sourceIndices.push_back(info.offsets.back()); + return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimConstantFolding/Common.hpp b/src/PIM/Pass/PimConstantFolding/Common.hpp new file mode 100644 index 0000000..0483b6a --- /dev/null +++ b/src/PIM/Pass/PimConstantFolding/Common.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace onnx_mlir { + +struct StaticSubviewInfo { + mlir::Value source; + llvm::SmallVector sourceShape; + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; +}; + +mlir::Value stripMemRefCasts(mlir::Value value); + +mlir::Value stripMemRefViewOps(mlir::Value value); + +mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp, + mlir::Location loc, + mlir::MemRefType globalType, + mlir::DenseElementsAttr denseAttr, + llvm::StringRef nameStem, + mlir::IntegerAttr alignment = {}); + +llvm::FailureOr getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value); + +llvm::FailureOr getStaticSubviewInfo(mlir::Value value); + +int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, + llvm::ArrayRef outerIndices, + int64_t elementByteWidth); + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimConstantFolding/Patterns.hpp b/src/PIM/Pass/PimConstantFolding/Patterns.hpp new file mode 100644 index 0000000..a0d9894 --- /dev/null +++ b/src/PIM/Pass/PimConstantFolding/Patterns.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + +void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns); + +void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns); + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimConstantFoldingPass.cpp b/src/PIM/Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp similarity index 51% rename from src/PIM/Pass/PimConstantFoldingPass.cpp rename to src/PIM/Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp index bf70814..3580589 100644 --- a/src/PIM/Pass/PimConstantFoldingPass.cpp +++ b/src/PIM/Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp @@ -1,19 +1,11 @@ -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "../Common.hpp" +#include "../Patterns.hpp" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/Matchers.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" - -#include #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -21,73 +13,14 @@ using namespace mlir; namespace onnx_mlir { - namespace { -static Value stripMemRefCasts(Value value) { - while (auto castOp = value.getDefiningOp()) - value = castOp.getSource(); - return value; -} - -static Value stripMemRefViewOps(Value value) { - while (true) { - if (auto castOp = value.getDefiningOp()) { - value = castOp.getSource(); - continue; - } - if (auto collapseOp = value.getDefiningOp()) { - value = collapseOp.getSrc(); - continue; - } - if (auto expandOp = value.getDefiningOp()) { - value = expandOp.getSrc(); - continue; - } - return value; - } -} - -static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, - Location loc, - MemRefType globalType, - DenseElementsAttr denseAttr, - StringRef nameStem, - IntegerAttr alignment = {}) { - auto globalName = nameStem.str(); - unsigned suffix = 0; - while (moduleOp.lookupSymbol(globalName)) - globalName = (nameStem + "_" + std::to_string(++suffix)).str(); - - auto visibility = StringAttr::get(moduleOp.getContext(), "private"); - OpBuilder moduleBuilder(moduleOp.getBodyRegion()); - moduleBuilder.setInsertionPointToStart(moduleOp.getBody()); - return memref::GlobalOp::create(moduleBuilder, - loc, - globalName, - visibility, - globalType, - denseAttr, - /*constant=*/true, - alignment); -} - -static FailureOr getDenseGlobalValue(ModuleOp moduleOp, Value value) { - value = stripMemRefCasts(value); - - auto getGlobalOp = value.getDefiningOp(); - if (!getGlobalOp) - return failure(); - - auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue()) - return failure(); - - auto denseAttr = dyn_cast(*globalOp.getInitialValue()); - if (!denseAttr) - return failure(); - return denseAttr; -} +struct ConstantSubviewCopy { + DenseElementsAttr source; + SmallVector offsets; + SmallVector strides; + Operation* copyOp = nullptr; +}; static FailureOr transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef perms) { auto tensorType = dyn_cast(denseAttr.getType()); @@ -144,13 +77,6 @@ static FailureOr transposeDenseElements(DenseElementsAttr den return DenseElementsAttr::get(transposedType, transposedValues); } -struct ConstantSubviewCopy { - DenseElementsAttr source; - SmallVector offsets; - SmallVector strides; - Operation* copyOp = nullptr; -}; - static FailureOr getConstantMapYield(linalg::MapOp mapOp) { if (!mapOp.getInputs().empty()) return failure(); @@ -213,204 +139,6 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { } }; -struct StaticSubviewInfo { - Value source; - SmallVector sourceShape; - SmallVector offsets; - SmallVector sizes; - SmallVector strides; -}; - -static FailureOr getStaticSubviewInfo(Value value) { - value = stripMemRefViewOps(value); - auto subviewOp = value.getDefiningOp(); - if (!subviewOp) - return failure(); - - auto source = stripMemRefCasts(subviewOp.getSource()); - auto sourceType = dyn_cast(source.getType()); - auto subviewType = dyn_cast(subviewOp.getType()); - if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) - return failure(); - - StaticSubviewInfo info; - info.source = source; - info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end()); - for (OpFoldResult offset : subviewOp.getMixedOffsets()) { - auto staticOffset = getConstantIntValue(offset); - if (!staticOffset) - return failure(); - info.offsets.push_back(*staticOffset); - } - for (OpFoldResult size : subviewOp.getMixedSizes()) { - auto staticSize = getConstantIntValue(size); - if (!staticSize) - return failure(); - info.sizes.push_back(*staticSize); - } - for (OpFoldResult stride : subviewOp.getMixedStrides()) { - auto staticStride = getConstantIntValue(stride); - if (!staticStride) - return failure(); - info.strides.push_back(*staticStride); - } - return info; -} - -static int64_t -getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef outerIndices, int64_t elementByteWidth) { - SmallVector sourceIndices; - sourceIndices.reserve(info.sourceShape.size()); - for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim) - sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]); - sourceIndices.push_back(info.offsets.back()); - return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth; -} - -struct RewriteCoreSubviewCopyPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override { - if (!copyOp->getParentOfType()) - return failure(); - - auto srcSubview = getStaticSubviewInfo(copyOp.getSrc()); - auto dstSubview = getStaticSubviewInfo(copyOp.getDst()); - const bool splitSrc = succeeded(srcSubview) - && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); - const bool splitDst = succeeded(dstSubview) - && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); - if (!splitSrc && !splitDst) - return failure(); - - auto sourceType = dyn_cast(copyOp.getSrc().getType()); - auto dstType = dyn_cast(copyOp.getDst().getType()); - if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape()) - return failure(); - if (sourceType.getElementType() != dstType.getElementType()) - return failure(); - - if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) - return failure(); - if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })) - return failure(); - - ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); - if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) - return failure(); - - const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8; - if (elementByteWidth <= 0) - return failure(); - - const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; - if (copyOp.getSize() != totalBytes) - return failure(); - - const int64_t sliceBytes = copyShape.back() * elementByteWidth; - if (sliceBytes <= 0) - return failure(); - - SmallVector outerShape(copyShape.begin(), copyShape.end() - 1); - auto outerStrides = computeRowMajorStrides(outerShape); - const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); - - rewriter.setInsertionPoint(copyOp); - for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { - SmallVector outerIndices = - outerShape.empty() ? SmallVector{} : delinearizeIndex(linearIndex, outerShape, outerStrides); - const int64_t srcByteOffset = copyOp.getSrcOffset() - + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) - : linearIndex * sliceBytes); - const int64_t dstByteOffset = copyOp.getDstOffset() - + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) - : linearIndex * sliceBytes); - pim::PimMemCopyOp::create(rewriter, - copyOp.getLoc(), - splitDst ? cast(dstSubview->source.getType()) : dstType, - splitDst ? dstSubview->source : copyOp.getDst(), - splitSrc ? srcSubview->source : copyOp.getSrc(), - rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - } - - rewriter.replaceOp(copyOp, copyOp.getDst()); - return success(); - } -}; - -struct RewriteHostSubviewLoadPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { - auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc()); - auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst()); - const bool splitSrc = succeeded(srcSubview) - && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); - const bool splitDst = succeeded(dstSubview) - && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); - if (!splitSrc && !splitDst) - return failure(); - - auto sourceType = dyn_cast(copyOp.getHostSrc().getType()); - auto dstType = dyn_cast(copyOp.getDeviceDst().getType()); - if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape()) - return failure(); - if (sourceType.getElementType() != dstType.getElementType()) - return failure(); - - if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) - return failure(); - if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })) - return failure(); - - ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); - if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) - return failure(); - - const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8; - if (elementByteWidth <= 0) - return failure(); - - const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; - if (copyOp.getSize() != totalBytes) - return failure(); - - const int64_t sliceBytes = copyShape.back() * elementByteWidth; - if (sliceBytes <= 0) - return failure(); - - SmallVector outerShape(copyShape.begin(), copyShape.end() - 1); - auto outerStrides = computeRowMajorStrides(outerShape); - const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); - - rewriter.setInsertionPoint(copyOp); - for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { - SmallVector outerIndices = - outerShape.empty() ? SmallVector{} : delinearizeIndex(linearIndex, outerShape, outerStrides); - const int64_t srcByteOffset = copyOp.getHostSrcOffset() - + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) - : linearIndex * sliceBytes); - const int64_t dstByteOffset = copyOp.getDeviceDstOffset() - + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) - : linearIndex * sliceBytes); - pim::PimMemCopyHostToDevOp::create( - rewriter, - copyOp.getLoc(), - splitDst ? cast(dstSubview->source.getType()) : dstType, - splitDst ? dstSubview->source : copyOp.getDeviceDst(), - splitSrc ? srcSubview->source : copyOp.getHostSrc(), - rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - } - - rewriter.replaceOp(copyOp, copyOp.getDeviceDst()); - return success(); - } -}; - static FailureOr foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) { auto allocType = dyn_cast(allocOp.getType()); if (!allocType || !allocType.hasStaticShape()) @@ -563,17 +291,15 @@ struct FoldConstantTransposePattern final : OpRewritePatterngetUsers().empty() @@ -672,11 +398,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override { - // Only match top-level memcp (not inside pim.core) if (copyOp->getParentOfType()) return failure(); - // dst must be an alloc with static shape auto allocOp = copyOp.getDst().getDefiningOp(); if (!allocOp) return failure(); @@ -684,11 +408,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { if (!allocType || !allocType.hasStaticShape()) return failure(); - // The copy must cover the full destination (offsets both zero) if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0) return failure(); - // Resolve the source through an optional subview to a get_global auto srcSubview = getStaticSubviewInfo(copyOp.getSrc()); Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc()); @@ -700,14 +422,12 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { if (failed(denseAttr)) return failure(); - // Build the folded dense attribute DenseElementsAttr foldedAttr; if (succeeded(srcSubview)) { - // Extract the sub-tensor from the source constant auto sourceType = dyn_cast(denseAttr->getType()); if (!sourceType || !sourceType.hasStaticShape()) return failure(); - if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; })) + if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) return failure(); auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); @@ -729,14 +449,12 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); } else { - // Direct copy from a global — just reuse its dense attribute auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); if (resultTensorType != denseAttr->getType()) return failure(); foldedAttr = *denseAttr; } - // Verify that the alloc's remaining users are supported ops. bool allLiveUsersAreCores = true; for (Operation* user : allocOp->getUsers()) { if (user == copyOp) @@ -769,110 +487,13 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { } }; -struct FoldConstantCoreSubviewPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override { - // Only handle subviews whose users are all pim.core ops. - if (subviewOp.use_empty()) - return failure(); - if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa(user); })) - return failure(); - - // Source must resolve to a constant get_global. - auto moduleOp = subviewOp->getParentOfType(); - if (!moduleOp) - return failure(); - auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource())); - if (failed(denseAttr)) - return failure(); - - // Static subview info. - auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult()); - if (failed(subviewInfo)) - return failure(); - if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; })) - return failure(); - - auto sourceType = dyn_cast(denseAttr->getType()); - if (!sourceType || !sourceType.hasStaticShape()) - return failure(); - - // Build the contiguous result type. - auto elementType = cast(subviewOp.getType()).getElementType(); - auto resultMemRefType = MemRefType::get( - SmallVector(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType); - auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType); - const int64_t numResultElements = resultTensorType.getNumElements(); - - // Extract the sub-tensor. - auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); - auto resultStrides = computeRowMajorStrides(resultTensorType.getShape()); - SmallVector sourceValues(denseAttr->getValues()); - SmallVector resultValues(numResultElements); - for (int64_t i = 0; i < numResultElements; ++i) { - auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); - SmallVector sourceIndices; - sourceIndices.reserve(resultIndices.size()); - for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices)) - sourceIndices.push_back(off + idx); - resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)]; - } - auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); - - auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview"); - markWeightAlways(newGlobal); - - rewriter.setInsertionPoint(subviewOp); - auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName()); - markWeightAlways(newGetGlobal); - - rewriter.replaceOp(subviewOp, newGetGlobal.getResult()); - return success(); - } -}; - -struct PimConstantFoldingPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass) - - StringRef getArgument() const override { return "pim-constant-folding-pass"; } - StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; } - - LogicalResult initialize(MLIRContext* context) override { - RewritePatternSet owningPatterns(context); - for (auto* dialect : context->getLoadedDialects()) - dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) - op.getCanonicalizationPatterns(owningPatterns, context); - owningPatterns - .add( - context); - patterns = std::make_shared(std::move(owningPatterns)); - return success(); - } - - void runOnOperation() override { - GreedyRewriteConfig config; - config.enableFolding(); - if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) { - signalPassFailure(); - return; - } - - dumpModule(getOperation(), "pim2_folded"); - } - - std::shared_ptr patterns; -}; - } // namespace -std::unique_ptr createPimConstantFoldingPass() { return std::make_unique(); } +void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} } // namespace onnx_mlir diff --git a/src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp b/src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp new file mode 100644 index 0000000..772f57b --- /dev/null +++ b/src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp @@ -0,0 +1,223 @@ +#include "../Common.hpp" +#include "../Patterns.hpp" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +template +static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, + Value dst, + Value src, + int64_t dstOffset, + int64_t srcOffset, + int64_t size, + PatternRewriter& rewriter, + CreateCopyOp createCopyOp) { + auto srcSubview = getStaticSubviewInfo(src); + auto dstSubview = getStaticSubviewInfo(dst); + const bool splitSrc = succeeded(srcSubview) + && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); + const bool splitDst = succeeded(dstSubview) + && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); + if (!splitSrc && !splitDst) + return failure(); + + auto sourceType = dyn_cast(src.getType()); + auto dstType = dyn_cast(dst.getType()); + if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape()) + return failure(); + if (sourceType.getElementType() != dstType.getElementType()) + return failure(); + + if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) + return failure(); + if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })) + return failure(); + + ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); + if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) + return failure(); + + const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8; + if (elementByteWidth <= 0) + return failure(); + + const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; + if (size != totalBytes) + return failure(); + + const int64_t sliceBytes = copyShape.back() * elementByteWidth; + if (sliceBytes <= 0) + return failure(); + + SmallVector outerShape(copyShape.begin(), copyShape.end() - 1); + auto outerStrides = computeRowMajorStrides(outerShape); + const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); + + rewriter.setInsertionPoint(copyOp); + for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { + SmallVector outerIndices = + outerShape.empty() ? SmallVector{} : delinearizeIndex(linearIndex, outerShape, outerStrides); + const int64_t srcByteOffset = srcOffset + + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) + : linearIndex * sliceBytes); + const int64_t dstByteOffset = dstOffset + + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) + : linearIndex * sliceBytes); + createCopyOp(splitDst ? cast(dstSubview->source.getType()) : dstType, + splitDst ? dstSubview->source : dst, + splitSrc ? srcSubview->source : src, + dstByteOffset, + srcByteOffset, + sliceBytes); + } + + return success(); +} + +struct RewriteCoreSubviewCopyPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override { + if (!copyOp->getParentOfType()) + return failure(); + + auto status = + rewriteSubviewCopyLikeOp(copyOp, + copyOp.getDst(), + copyOp.getSrc(), + copyOp.getDstOffset(), + copyOp.getSrcOffset(), + copyOp.getSize(), + rewriter, + [&](MemRefType resultType, + Value dst, + Value src, + int64_t dstByteOffset, + int64_t srcByteOffset, + int64_t sliceBytes) { + pim::PimMemCopyOp::create( + rewriter, + copyOp.getLoc(), + resultType, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); + if (failed(status)) + return failure(); + + rewriter.replaceOp(copyOp, copyOp.getDst()); + return success(); + } +}; + +struct RewriteHostSubviewLoadPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { + auto status = + rewriteSubviewCopyLikeOp(copyOp, + copyOp.getDeviceDst(), + copyOp.getHostSrc(), + copyOp.getDeviceDstOffset(), + copyOp.getHostSrcOffset(), + copyOp.getSize(), + rewriter, + [&](MemRefType resultType, + Value dst, + Value src, + int64_t dstByteOffset, + int64_t srcByteOffset, + int64_t sliceBytes) { + pim::PimMemCopyHostToDevOp::create( + rewriter, + copyOp.getLoc(), + resultType, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); + if (failed(status)) + return failure(); + + rewriter.replaceOp(copyOp, copyOp.getDeviceDst()); + return success(); + } +}; + +struct FoldConstantCoreSubviewPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override { + if (subviewOp.use_empty()) + return failure(); + if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa(user); })) + return failure(); + + auto moduleOp = subviewOp->getParentOfType(); + if (!moduleOp) + return failure(); + auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource())); + if (failed(denseAttr)) + return failure(); + + auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult()); + if (failed(subviewInfo)) + return failure(); + if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; })) + return failure(); + + auto sourceType = dyn_cast(denseAttr->getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return failure(); + + auto elementType = cast(subviewOp.getType()).getElementType(); + auto resultMemRefType = + MemRefType::get(SmallVector(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType); + auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType); + const int64_t numResultElements = resultTensorType.getNumElements(); + + auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); + auto resultStrides = computeRowMajorStrides(resultTensorType.getShape()); + SmallVector sourceValues(denseAttr->getValues()); + SmallVector resultValues(numResultElements); + for (int64_t i = 0; i < numResultElements; ++i) { + auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); + SmallVector sourceIndices; + sourceIndices.reserve(resultIndices.size()); + for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices)) + sourceIndices.push_back(off + idx); + resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)]; + } + auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); + + auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview"); + markWeightAlways(newGlobal); + + rewriter.setInsertionPoint(subviewOp); + auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName()); + markWeightAlways(newGetGlobal); + + rewriter.replaceOp(subviewOp, newGetGlobal.getResult()); + return success(); + } +}; + +} // namespace + +void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) { + patterns.add( + patterns.getContext()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp b/src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp new file mode 100644 index 0000000..490c038 --- /dev/null +++ b/src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp @@ -0,0 +1,53 @@ +#include "Patterns.hpp" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +struct PimConstantFoldingPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass) + + StringRef getArgument() const override { return "pim-constant-folding-pass"; } + StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; } + + LogicalResult initialize(MLIRContext* context) override { + RewritePatternSet owningPatterns(context); + for (auto* dialect : context->getLoadedDialects()) + dialect->getCanonicalizationPatterns(owningPatterns); + for (RegisteredOperationName op : context->getRegisteredOperations()) + op.getCanonicalizationPatterns(owningPatterns, context); + + populateConstantFoldingConstantPatterns(owningPatterns); + populateConstantFoldingSubviewPatterns(owningPatterns); + + patterns = std::make_shared(std::move(owningPatterns)); + return success(); + } + + void runOnOperation() override { + GreedyRewriteConfig config; + config.enableFolding(); + if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) { + signalPassFailure(); + return; + } + + dumpModule(getOperation(), "pim2_folded"); + } + + std::shared_ptr patterns; +}; + +} // namespace + +std::unique_ptr createPimConstantFoldingPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimHostVerificationPass.cpp b/src/PIM/Pass/PimHostVerificationPass.cpp index c82f9ef..827fa56 100644 --- a/src/PIM/Pass/PimHostVerificationPass.cpp +++ b/src/PIM/Pass/PimHostVerificationPass.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" @@ -26,37 +27,12 @@ static bool isAddressOnlyHostOp(Operation* op) { spatial::SpatChannelNewOp>(op); } -static bool isHostAddressableValue(Value value) { - while (true) { - if (auto blockArg = dyn_cast(value)) - return isa(blockArg.getOwner()->getParentOp()); - - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) - return false; - - if (isa(definingOp)) - return true; - - if (auto subviewOp = dyn_cast(definingOp)) { - value = subviewOp.getSource(); - continue; - } - if (auto castOp = dyn_cast(definingOp)) { - value = castOp.getSource(); - continue; - } - if (auto collapseOp = dyn_cast(definingOp)) { - value = collapseOp.getSrc(); - continue; - } - if (auto expandOp = dyn_cast(definingOp)) { - value = expandOp.getSrc(); - continue; - } - +static bool isCodegenAddressableValue(Value value) { + auto resolvedAddress = resolveContiguousAddress(value); + if (failed(resolvedAddress)) return false; - } + return isa(resolvedAddress->base) + || isa(resolvedAddress->base.getDefiningOp()); } struct PimHostVerificationPass : PassWrapper> { @@ -80,7 +56,7 @@ struct PimHostVerificationPass : PassWrapper(&op)) { - if (failed(verifyCoreWeights(moduleOp, coreOp))) + if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp))) hasFailure = true; continue; } @@ -139,8 +115,27 @@ private: static LogicalResult verifyReturnOp(func::ReturnOp returnOp) { bool hasFailure = false; for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) { - if (!isHostAddressableValue(operand)) { - returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage"; + if (!isCodegenAddressableValue(operand)) { + returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage"; + hasFailure = true; + } + } + return success(!hasFailure); + } + + static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) { + bool hasFailure = false; + for (Operation& op : coreOp.getBody().front()) { + if (isa(op)) + continue; + + for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { + if (!isa(operand.getType())) + continue; + if (succeeded(resolveContiguousAddress(operand))) + continue; + + op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; hasFailure = true; } } @@ -160,10 +155,10 @@ private: } static LogicalResult verifyAddressOnlySource(Operation* op, Value source) { - if (isHostAddressableValue(source)) + if (isCodegenAddressableValue(source)) return success(); - op->emitOpError("depends on a value that still requires host-side execution"); + op->emitOpError("depends on a value that is not backed by contiguous addressable storage"); return failure(); } };