reformat code

This commit is contained in:
NiccoloN
2026-03-23 21:25:51 +01:00
parent 93e20c1dfc
commit a4f3eed3e0
12 changed files with 101 additions and 107 deletions

View File

@@ -244,7 +244,7 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress{value, byteOffset};
return ResolvedContiguousAddress {value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
@@ -293,7 +293,7 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress{value, byteOffset};
return ResolvedContiguousAddress {value, byteOffset};
return failure();
}

View File

@@ -389,8 +389,8 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
size_t totalElements = srcType.getNumElements();
// Read permutation. Destination dim i corresponds to source dim perm[i].
SmallVector<int64_t> perm =
map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
[](auto attr) -> int64_t { return attr.getInt(); });
// Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank);

View File

@@ -70,9 +70,8 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
});
target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();

View File

@@ -58,8 +58,14 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
loc,
rhsRowType,
rhsSlice,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
@@ -89,10 +95,15 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.setInsertionPointAfter(concatComputeOp);
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
Value result = ONNXTransposeOp::create(
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
loc,
gemmExpandedType,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result);
return success();

View File

@@ -114,8 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<Reshape>(ctx);
}
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
} // namespace onnx_mlir

View File

@@ -70,8 +70,8 @@ Operation* getEarliestUserWithinBlock(mlir::Value value) {
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
auto operandsAndUses =
map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
return {operand, std::distance(operand.use_begin(), operand.use_end())};
});
return {operand, std::distance(operand.use_begin(), operand.use_end())};
});
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}

View File

@@ -186,7 +186,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
};
template <typename OpTy>
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
struct BinaryDstOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}

View File

@@ -1,5 +1,4 @@
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
@@ -107,9 +106,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
return info;
}
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth) {
int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)

View File

@@ -1,10 +1,9 @@
#include "Patterns.hpp"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;

View File

@@ -1,12 +1,11 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -478,10 +477,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
} // namespace
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantMemCpPattern>(patterns.getContext());
patterns
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, FoldConstantMemCpPattern>(
patterns.getContext());
}
} // namespace onnx_mlir

View File

@@ -1,6 +1,5 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -20,10 +19,12 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
const bool splitSrc =
succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst =
succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
@@ -62,13 +63,13 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset =
srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
const int64_t dstByteOffset =
dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : dst,
splitSrc ? srcSubview->source : src,
@@ -87,30 +88,25 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
auto status = rewriteSubviewCopyLikeOp(
copyOp,
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
@@ -123,30 +119,25 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
auto status = rewriteSubviewCopyLikeOp(
copyOp,
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getSize(),
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
@@ -201,11 +192,13 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
}
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
auto newGlobal =
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
auto newGetGlobal =
memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
markWeightAlways(newGetGlobal);
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());

View File

@@ -98,16 +98,15 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(totalBytes)));
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(
rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
cachedByType[originalType] = hostToDevCopy.getResult();
operand.set(hostToDevCopy.getResult());
@@ -127,8 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace
std::unique_ptr<Pass> createMaterializeConstantsPass() {
return std::make_unique<MaterializeConstantsPass>();
}
std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
} // namespace onnx_mlir