reformat code

This commit is contained in:
NiccoloN
2026-03-23 21:25:51 +01:00
parent 93e20c1dfc
commit a4f3eed3e0
12 changed files with 101 additions and 107 deletions

View File

@@ -244,7 +244,7 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
while (true) { while (true) {
if (isa<BlockArgument>(value)) if (isa<BlockArgument>(value))
return ResolvedContiguousAddress{value, byteOffset}; return ResolvedContiguousAddress {value, byteOffset};
Operation* definingOp = value.getDefiningOp(); Operation* definingOp = value.getDefiningOp();
if (!definingOp) if (!definingOp)
@@ -293,7 +293,7 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
} }
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp)) if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress{value, byteOffset}; return ResolvedContiguousAddress {value, byteOffset};
return failure(); return failure();
} }

View File

@@ -389,8 +389,8 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
size_t totalElements = srcType.getNumElements(); size_t totalElements = srcType.getNumElements();
// Read permutation. Destination dim i corresponds to source dim perm[i]. // Read permutation. Destination dim i corresponds to source dim perm[i].
SmallVector<int64_t> perm = SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); }); [](auto attr) -> int64_t { return attr.getInt(); });
// Destination shape: dstShape[i] = srcShape[perm[i]] // Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank); SmallVector<int64_t> dstShape(rank);

View File

@@ -70,9 +70,8 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>(); target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) { target.addDynamicallyLegalOp<ONNXMatMulOp>(
return cast<ShapedType>(op.getY().getType()).getRank() != 2; [](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
});
target.addIllegalOp<ONNXGemmOp>(); target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>(); target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>(); target.addIllegalOp<ONNXLRNOp>();

View File

@@ -58,8 +58,14 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice = Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides); tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create( Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}}); loc,
rhsRowType,
rhsSlice,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
auto gemmOp = ONNXGemmOp::create(rewriter, auto gemmOp = ONNXGemmOp::create(rewriter,
loc, loc,
@@ -89,10 +95,15 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.setInsertionPointAfter(concatComputeOp); rewriter.setInsertionPointAfter(concatComputeOp);
Value gemmOut = concatComputeOp.getResult(0); Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create( Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}}); loc,
Value result = ONNXTransposeOp::create( gemmExpandedType,
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1})); gemmOut,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, result);
return success(); return success();

View File

@@ -114,8 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} // namespace } // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
patterns.insert<Reshape>(ctx);
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -186,7 +186,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
}; };
template <typename OpTy> template <typename OpTy>
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> { struct BinaryDstOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }

View File

@@ -1,5 +1,4 @@
#include "Common.hpp" #include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;
@@ -107,9 +106,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
return info; return info;
} }
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, int64_t
ArrayRef<int64_t> outerIndices, getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices; SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size()); sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim) for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)

View File

@@ -1,10 +1,9 @@
#include "Patterns.hpp"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory> #include <memory>
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -1,12 +1,11 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -478,10 +477,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
} // namespace } // namespace
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) { void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConstantTransposePattern, patterns
FoldConstantAllocPattern, .add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, FoldConstantMemCpPattern>(
FoldConstantCoreMapPattern, patterns.getContext());
FoldConstantMemCpPattern>(patterns.getContext());
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,6 +1,5 @@
#include "../Common.hpp" #include "../Common.hpp"
#include "../Patterns.hpp" #include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -20,9 +19,11 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
CreateCopyOp createCopyOp) { CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src); auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst); auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview) const bool splitSrc =
succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview) const bool splitDst =
succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst) if (!splitSrc && !splitDst)
return failure(); return failure();
@@ -62,13 +63,13 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
rewriter.setInsertionPoint(copyOp); rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices = SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides); outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = srcOffset const int64_t srcByteOffset =
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) srcOffset
: linearIndex * sliceBytes); + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
const int64_t dstByteOffset = dstOffset const int64_t dstByteOffset =
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) dstOffset
: linearIndex * sliceBytes); + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType, createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : dst, splitDst ? dstSubview->source : dst,
splitSrc ? srcSubview->source : src, splitSrc ? srcSubview->source : src,
@@ -87,22 +88,17 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
if (!copyOp->getParentOfType<pim::PimCoreOp>()) if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure(); return failure();
auto status = auto status = rewriteSubviewCopyLikeOp(
rewriteSubviewCopyLikeOp(copyOp, copyOp,
copyOp.getTarget(), copyOp.getTarget(),
copyOp.getSource(), copyOp.getSource(),
copyOp.getTargetOffset(), copyOp.getTargetOffset(),
copyOp.getSourceOffset(), copyOp.getSourceOffset(),
copyOp.getSize(), copyOp.getSize(),
rewriter, rewriter,
[&](MemRefType resultType, [&](
Value dst, MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value src, pim::PimMemCopyOp::create(rewriter,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyOp::create(
rewriter,
copyOp.getLoc(), copyOp.getLoc(),
resultType, resultType,
dst, dst,
@@ -123,22 +119,17 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto status = auto status = rewriteSubviewCopyLikeOp(
rewriteSubviewCopyLikeOp(copyOp, copyOp,
copyOp.getDeviceTarget(), copyOp.getDeviceTarget(),
copyOp.getHostSource(), copyOp.getHostSource(),
copyOp.getDeviceTargetOffset(), copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(), copyOp.getHostSourceOffset(),
copyOp.getSize(), copyOp.getSize(),
rewriter, rewriter,
[&](MemRefType resultType, [&](
Value dst, MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value src, pim::PimMemCopyHostToDevOp::create(rewriter,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(
rewriter,
copyOp.getLoc(), copyOp.getLoc(),
resultType, resultType,
dst, dst,
@@ -201,11 +192,13 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
} }
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview"); auto newGlobal =
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal); markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp); rewriter.setInsertionPoint(subviewOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName()); auto newGetGlobal =
memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
markWeightAlways(newGetGlobal); markWeightAlways(newGetGlobal);
rewriter.replaceOp(subviewOp, newGetGlobal.getResult()); rewriter.replaceOp(subviewOp, newGetGlobal.getResult());

View File

@@ -98,16 +98,15 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
if (contiguousType != originalType) if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc); deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter, auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(
rewriter,
op.getLoc(), op.getLoc(),
originalType, originalType,
deviceDst, deviceDst,
getGlobalOp.getResult(), getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr( rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
static_cast<int32_t>(resolvedAddress->byteOffset)), rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
rewriter.getI32IntegerAttr(
static_cast<int32_t>(totalBytes)));
cachedByType[originalType] = hostToDevCopy.getResult(); cachedByType[originalType] = hostToDevCopy.getResult();
operand.set(hostToDevCopy.getResult()); operand.set(hostToDevCopy.getResult());
@@ -127,8 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace } // namespace
std::unique_ptr<Pass> createMaterializeConstantsPass() { std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
return std::make_unique<MaterializeConstantsPass>();
}
} // namespace onnx_mlir } // namespace onnx_mlir