centralize logic for materializing contiguous memory into bufferization

fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
This commit is contained in:
NiccoloN
2026-05-30 15:54:24 +02:00
parent 2d5b03c08f
commit ff36729140
29 changed files with 642 additions and 822 deletions
-26
View File
@@ -144,32 +144,6 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
}];
}
def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";
let arguments = (ins
PimTensor:$deviceTarget,
PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceTargetMutable();
}
}];
let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
}];
}
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory";
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
@@ -10,8 +11,8 @@ using namespace bufferization;
namespace onnx_mlir::pim {
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -30,7 +31,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
.getOutput();
}
Value allocateContiguousMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
@@ -5,8 +5,9 @@
namespace onnx_mlir::pim {
mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
mlir::Value allocateContiguousMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
mlir::Value
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
mlir::Value value,
@@ -6,6 +6,8 @@ add_pim_library(OMPimBufferization
PimBufferizationPass.cpp
BufferizationUtils.hpp
BufferizationUtils.cpp
ContiguityPatterns.hpp
ContiguityPatterns.cpp
OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp
Common.hpp
@@ -0,0 +1,343 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
return false;
return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
}
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
if (extraOffset == 0)
return baseOffset;
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
assert(integerAttr && "expected integer offset attribute");
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
}
auto value = cast<Value>(baseOffset);
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
}
static Value buildSubviewChunk(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(info.offsets.size());
chunkSizes.reserve(info.sizes.size());
chunkStrides.reserve(info.strides.size());
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
}
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
}
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
ArrayRef<int64_t> shape,
ArrayRef<int64_t> strides,
PatternRewriter& rewriter) {
SmallVector<Value> indices;
indices.reserve(shape.size());
Value remaining = linearIndex;
for (auto [_dim, stride] : llvm::enumerate(strides)) {
auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
indices.push_back(index);
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
}
return indices;
}
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
auto integerAttr = cast<IntegerAttr>(attr);
if (integerAttr.getInt() == 0)
return extraOffset;
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
}
auto value = cast<Value>(baseOffset);
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
}
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
ArrayRef<Value> outerIndices,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(info.offsets.size());
chunkSizes.reserve(info.sizes.size());
chunkStrides.reserve(info.strides.size());
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
if (dim + 1 < info.sizes.size()) {
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(1));
}
else {
chunkOffsets.push_back(info.offsets[dim]);
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
}
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
}
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
}
static Value buildContiguousChunk(
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(copyShape.size());
chunkSizes.reserve(copyShape.size());
chunkStrides.reserve(copyShape.size());
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
chunkStrides.push_back(rewriter.getIndexAttr(1));
}
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
}
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
Value src,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
bool allowLoopRewrite,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(src.getType());
auto dstType = dyn_cast<MemRefType>(dst.getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
if (!hasByteSizedElementType(sourceType.getElementType()))
return failure();
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
&& dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
&& (splitDst || dstShapeMatchesCopyShape)) {
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices =
outerShape.empty() ? SmallVector<Value> {}
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
return success();
}
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
}
return success();
}
struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
auto status = rewriteSubviewCopyLikeOp(
copyOp,
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getTarget());
return success();
}
};
struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp(
copyOp,
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
*dstOffset,
*srcOffset,
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
resultType,
dstOffsetValue,
srcOffsetValue,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
return success();
}
};
struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp(
copyOp,
copyOp.getHostTarget(),
copyOp.getDeviceSource(),
*dstOffset,
*srcOffset,
copyOp.getSize(),
/*allowLoopRewrite=*/false,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(),
resultType,
dstOffsetValue,
srcOffsetValue,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
return success();
}
};
} // namespace
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
patterns.getContext());
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir::pim {
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir::pim
@@ -47,32 +47,6 @@ struct MemCopyHostToDevOpInterface
}
};
struct MemCopyHostToDevBatchOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
auto deviceTargetOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
if (failed(deviceTargetOpt))
return failure();
auto hostSourceOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
if (failed(hostSourceOpt))
return failure();
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
memCopyHostToDevOp,
deviceTargetOpt->getType(),
*deviceTargetOpt,
*hostSourceOpt,
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
memCopyHostToDevOp.getHostSourceOffsetAttr(),
memCopyHostToDevOp.getSizeAttr());
return success();
}
};
struct MemCopyDevToHostOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
LogicalResult bufferize(Operation* op,
@@ -151,8 +125,9 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
if (failed(outputBufferOpt))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimReceiveOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
rewriter, op, contiguousOutput.getType(), contiguousOutput, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
return success();
}
};
@@ -173,15 +148,16 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt))
return failure();
inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter));
inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter));
}
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimConcatOp>(
rewriter, op, outputBufferOpt->getType(), concatOp.getAxisAttr(), ValueRange(inputs), *outputBufferOpt);
rewriter, op, contiguousOutput.getType(), concatOp.getAxisAttr(), ValueRange(inputs), contiguousOutput);
return success();
}
};
@@ -206,7 +182,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreId());
return success();
@@ -431,8 +407,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
@@ -475,8 +451,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
@@ -514,9 +490,9 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
if (failed(outputBufferOpt))
return failure();
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
@@ -547,9 +523,9 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
if (failed(outputBufferOpt))
return failure();
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVVDMulOp>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
@@ -583,8 +559,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
return success();
@@ -599,7 +575,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimSendOp::attachInterface<SendOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
@@ -6,14 +6,15 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
@@ -40,6 +41,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
};
} // namespace
@@ -84,6 +86,20 @@ void PimBufferizationPass::runOnOperation() {
return;
}
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
// Later PIM codegen passes may verify this invariant but must not repair it.
RewritePatternSet contiguityPatterns(ctx);
populatePimContiguityNormalizationPatterns(contiguityPatterns);
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
signalPassFailure();
return;
}
if (failed(verifyContiguousRuntimeOperands(moduleOp))) {
signalPassFailure();
return;
}
annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug
@@ -108,6 +124,75 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
}
LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp moduleOp) const {
bool hasFailure = false;
moduleOp.walk([&](Operation* op) {
auto verifyOperand = [&](Value operand, unsigned operandIndex) {
if (!isa<BaseMemRefType>(operand.getType()))
return;
if (succeeded(resolveContiguousAddress(operand)) || succeeded(compileContiguousAddressExpr(operand)))
return;
op->emitOpError() << "operand #" << operandIndex
<< " is not backed by contiguous addressable storage after PIM bufferization";
hasFailure = true;
};
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
verifyOperand(memCopyOp.getTarget(), 0);
verifyOperand(memCopyOp.getSource(), 1);
return;
}
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
verifyOperand(loadOp.getDeviceTarget(), 2);
verifyOperand(loadOp.getHostSource(), 3);
return;
}
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
verifyOperand(storeOp.getHostTarget(), 2);
verifyOperand(storeOp.getDeviceSource(), 3);
return;
}
if (auto sendOp = dyn_cast<PimSendOp>(op)) {
verifyOperand(sendOp.getInput(), 0);
return;
}
if (auto receiveOp = dyn_cast<PimReceiveOp>(op)) {
verifyOperand(receiveOp.getOutputBuffer(), 0);
return;
}
if (auto concatOp = dyn_cast<PimConcatOp>(op)) {
verifyOperand(concatOp.getOutputBuffer(), 0);
for (auto inputAndIndex : llvm::enumerate(concatOp.getInputs()))
verifyOperand(inputAndIndex.value(), inputAndIndex.index() + 1);
return;
}
if (isa<PimTransposeOp,
PimVMMOp,
PimVVAddOp,
PimVVSubOp,
PimVVMulOp,
PimVVMaxOp,
PimVVDMulOp,
PimVAvgOp,
PimVReluOp,
PimVTanhOp,
PimVSigmOp,
PimVSoftmaxOp>(op)) {
for (auto operandAndIndex : llvm::enumerate(op->getOperands())) {
if (auto vmmOp = dyn_cast<PimVMMOp>(op); vmmOp && operandAndIndex.index() == 0)
continue;
verifyOperand(operandAndIndex.value(), operandAndIndex.index());
}
}
});
if (hasFailure) {
moduleOp.emitError("PIM bufferization must fully normalize executable runtime operand contiguity before codegen");
return failure();
}
return success();
}
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir
@@ -22,6 +22,10 @@ using namespace onnx_mlir::compact_asm;
namespace onnx_mlir {
namespace {
// This pass assumes bufferization has already normalized executable PIM
// operands. It only reuses compatible local allocations with non-overlapping
// lifetimes; it does not repair memory contiguity.
struct CoalescingReportRow {
uint64_t numCandidates = 0;
uint64_t numSkipped = 0;