From a4f3eed3e0441e553233dfd90a9303f503828cdb Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 23 Mar 2026 21:25:51 +0100 Subject: [PATCH] reformat code --- src/PIM/Common/PimCommon.cpp | 4 +- src/PIM/Compiler/PimCodeGen.cpp | 4 +- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 5 +- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 23 +++- .../ONNXToSpatial/Patterns/Tensor/Reshape.cpp | 4 +- src/PIM/Conversion/SpatialToPim/Common.cpp | 4 +- .../OpBufferizationInterfaces.cpp | 3 +- src/PIM/Pass/Pim/ConstantFolding/Common.cpp | 6 +- .../ConstantFolding/ConstantFoldingPass.cpp | 3 +- .../Pim/ConstantFolding/Patterns/Constant.cpp | 12 +- .../Pim/ConstantFolding/Patterns/Subview.cpp | 117 ++++++++---------- src/PIM/Pass/Pim/MaterializeConstantsPass.cpp | 23 ++-- 12 files changed, 101 insertions(+), 107 deletions(-) diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index 6876d38..2de9861 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -244,7 +244,7 @@ FailureOr resolveContiguousAddress(Value value) { while (true) { if (isa(value)) - return ResolvedContiguousAddress{value, byteOffset}; + return ResolvedContiguousAddress {value, byteOffset}; Operation* definingOp = value.getDefiningOp(); if (!definingOp) @@ -293,7 +293,7 @@ FailureOr resolveContiguousAddress(Value value) { } if (isa(definingOp)) - return ResolvedContiguousAddress{value, byteOffset}; + return ResolvedContiguousAddress {value, byteOffset}; return failure(); } diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 192d15a..7c1f6e9 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -389,8 +389,8 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const { size_t totalElements = srcType.getNumElements(); // Read permutation. Destination dim i corresponds to source dim perm[i]. - SmallVector perm = - map_to_vector(transposeOp.getPermutation().getAsRange(), [](auto attr) -> int64_t { return attr.getInt(); }); + SmallVector perm = map_to_vector(transposeOp.getPermutation().getAsRange(), + [](auto attr) -> int64_t { return attr.getInt(); }); // Destination shape: dstShape[i] = srcShape[perm[i]] SmallVector dstShape(rank); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 165d9f1..960939f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -70,9 +70,8 @@ void ONNXToSpatialPass::runOnOperation() { ConversionTarget target(*ctx); target.addLegalDialect(); - target.addDynamicallyLegalOp([](ONNXMatMulOp op) { - return cast(op.getY().getType()).getRank() != 2; - }); + target.addDynamicallyLegalOp( + [](ONNXMatMulOp op) { return cast(op.getY().getType()).getRank() != 2; }); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index d8dda1d..2ec810e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -58,8 +58,14 @@ struct MatMulRank3ToGemm : OpRewritePattern { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value rhsSlice = tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides); - Value rhsRow = tensor::CollapseShapeOp::create( - rewriter, loc, rhsRowType, rhsSlice, SmallVector{{0}, {1, 2}}); + Value rhsRow = tensor::CollapseShapeOp::create(rewriter, + loc, + rhsRowType, + rhsSlice, + SmallVector { + {0}, + {1, 2} + }); auto gemmOp = ONNXGemmOp::create(rewriter, loc, @@ -89,10 +95,15 @@ struct MatMulRank3ToGemm : OpRewritePattern { rewriter.setInsertionPointAfter(concatComputeOp); Value gemmOut = concatComputeOp.getResult(0); - Value gemmExpanded = tensor::ExpandShapeOp::create( - rewriter, loc, gemmExpandedType, gemmOut, SmallVector{{0, 1}, {2}}); - Value result = ONNXTransposeOp::create( - rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1})); + Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter, + loc, + gemmExpandedType, + gemmOut, + SmallVector { + {0, 1}, + {2} + }); + Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1})); rewriter.replaceOp(matmulOp, result); return success(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp index dcf200d..befa1c5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp @@ -114,8 +114,6 @@ struct Reshape : OpConversionPattern { } // namespace -void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); -} +void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 073a8b6..96aa454 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -70,8 +70,8 @@ Operation* getEarliestUserWithinBlock(mlir::Value value) { SmallVector getOpOperandsSortedByUses(Operation* operation) { auto operandsAndUses = map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair { - return {operand, std::distance(operand.use_begin(), operand.use_end())}; - }); + return {operand, std::distance(operand.use_begin(), operand.use_end())}; + }); sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; }); return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; }); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index f150af0..780d1e2 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -186,7 +186,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel -struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel, OpTy> { +struct BinaryDstOpBufferizeInterface +: DstBufferizableOpInterfaceExternalModel, OpTy> { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); } diff --git a/src/PIM/Pass/Pim/ConstantFolding/Common.cpp b/src/PIM/Pass/Pim/ConstantFolding/Common.cpp index 14df695..16f121d 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Common.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Common.cpp @@ -1,5 +1,4 @@ #include "Common.hpp" - #include "src/Accelerators/PIM/Common/PimCommon.hpp" using namespace mlir; @@ -107,9 +106,8 @@ FailureOr getStaticSubviewInfo(Value value) { return info; } -int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, - ArrayRef outerIndices, - int64_t elementByteWidth) { +int64_t +getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef outerIndices, int64_t elementByteWidth) { SmallVector sourceIndices; sourceIndices.reserve(info.sourceShape.size()); for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim) diff --git a/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp b/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp index 83ac773..e237de4 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp @@ -1,10 +1,9 @@ -#include "Patterns.hpp" - #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include +#include "Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" using namespace mlir; diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp index 0c2e207..531d5c3 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp @@ -1,12 +1,11 @@ -#include "../Common.hpp" -#include "../Patterns.hpp" - #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/Matchers.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "../Common.hpp" +#include "../Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -478,10 +477,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { } // namespace void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) { - patterns.add(patterns.getContext()); + patterns + .add( + patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp index 6ee63a4..32e5f43 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp @@ -1,6 +1,5 @@ #include "../Common.hpp" #include "../Patterns.hpp" - #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -20,10 +19,12 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, CreateCopyOp createCopyOp) { auto srcSubview = getStaticSubviewInfo(src); auto dstSubview = getStaticSubviewInfo(dst); - const bool splitSrc = succeeded(srcSubview) - && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); - const bool splitDst = succeeded(dstSubview) - && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); + const bool splitSrc = + succeeded(srcSubview) + && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); + const bool splitDst = + succeeded(dstSubview) + && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); if (!splitSrc && !splitDst) return failure(); @@ -62,13 +63,13 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, rewriter.setInsertionPoint(copyOp); for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { SmallVector outerIndices = - outerShape.empty() ? SmallVector{} : delinearizeIndex(linearIndex, outerShape, outerStrides); - const int64_t srcByteOffset = srcOffset - + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) - : linearIndex * sliceBytes); - const int64_t dstByteOffset = dstOffset - + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) - : linearIndex * sliceBytes); + outerShape.empty() ? SmallVector {} : delinearizeIndex(linearIndex, outerShape, outerStrides); + const int64_t srcByteOffset = + srcOffset + + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes); + const int64_t dstByteOffset = + dstOffset + + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes); createCopyOp(splitDst ? cast(dstSubview->source.getType()) : dstType, splitDst ? dstSubview->source : dst, splitSrc ? srcSubview->source : src, @@ -87,30 +88,25 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern if (!copyOp->getParentOfType()) return failure(); - auto status = - rewriteSubviewCopyLikeOp(copyOp, - copyOp.getTarget(), - copyOp.getSource(), - copyOp.getTargetOffset(), - copyOp.getSourceOffset(), - copyOp.getSize(), - rewriter, - [&](MemRefType resultType, - Value dst, - Value src, - int64_t dstByteOffset, - int64_t srcByteOffset, - int64_t sliceBytes) { - pim::PimMemCopyOp::create( - rewriter, - copyOp.getLoc(), - resultType, - dst, - src, - rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - }); + auto status = rewriteSubviewCopyLikeOp( + copyOp, + copyOp.getTarget(), + copyOp.getSource(), + copyOp.getTargetOffset(), + copyOp.getSourceOffset(), + copyOp.getSize(), + rewriter, + [&]( + MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { + pim::PimMemCopyOp::create(rewriter, + copyOp.getLoc(), + resultType, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); if (failed(status)) return failure(); @@ -123,30 +119,25 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - }); + auto status = rewriteSubviewCopyLikeOp( + copyOp, + copyOp.getDeviceTarget(), + copyOp.getHostSource(), + copyOp.getDeviceTargetOffset(), + copyOp.getHostSourceOffset(), + copyOp.getSize(), + rewriter, + [&]( + MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { + pim::PimMemCopyHostToDevOp::create(rewriter, + copyOp.getLoc(), + resultType, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); if (failed(status)) return failure(); @@ -201,11 +192,13 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern(resolvedAddress->byteOffset)), - rewriter.getI32IntegerAttr( - static_cast(totalBytes))); + auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create( + rewriter, + op.getLoc(), + originalType, + deviceDst, + getGlobalOp.getResult(), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(static_cast(resolvedAddress->byteOffset)), + rewriter.getI32IntegerAttr(static_cast(totalBytes))); cachedByType[originalType] = hostToDevCopy.getResult(); operand.set(hostToDevCopy.getResult()); @@ -127,8 +126,6 @@ struct MaterializeConstantsPass : PassWrapper createMaterializeConstantsPass() { - return std::make_unique(); -} +std::unique_ptr createMaterializeConstantsPass() { return std::make_unique(); } } // namespace onnx_mlir