reformat code

This commit is contained in:
NiccoloN
2026-03-23 21:25:51 +01:00
parent 93e20c1dfc
commit a4f3eed3e0
12 changed files with 101 additions and 107 deletions

View File

@@ -1,5 +1,4 @@
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
@@ -107,9 +106,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
return info;
}
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth) {
int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)

View File

@@ -1,10 +1,9 @@
#include "Patterns.hpp"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;

View File

@@ -1,12 +1,11 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -478,10 +477,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
} // namespace
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantMemCpPattern>(patterns.getContext());
patterns
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, FoldConstantMemCpPattern>(
patterns.getContext());
}
} // namespace onnx_mlir

View File

@@ -1,6 +1,5 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -20,10 +19,12 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
const bool splitSrc =
succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst =
succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
@@ -62,13 +63,13 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset =
srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
const int64_t dstByteOffset =
dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : dst,
splitSrc ? srcSubview->source : src,
@@ -87,30 +88,25 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
auto status = rewriteSubviewCopyLikeOp(
copyOp,
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
@@ -123,30 +119,25 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
auto status = rewriteSubviewCopyLikeOp(
copyOp,
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getSize(),
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
@@ -201,11 +192,13 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
}
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
auto newGlobal =
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
auto newGetGlobal =
memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
markWeightAlways(newGetGlobal);
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());

View File

@@ -98,16 +98,15 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(totalBytes)));
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(
rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
cachedByType[originalType] = hostToDevCopy.getResult();
operand.set(hostToDevCopy.getResult());
@@ -127,8 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace
std::unique_ptr<Pass> createMaterializeConstantsPass() {
return std::make_unique<MaterializeConstantsPass>();
}
std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
} // namespace onnx_mlir