compact memory contiguity with for loops
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-31 18:47:59 +02:00
parent ab63498f3f
commit b678e55d3c
14 changed files with 550 additions and 331 deletions
+5 -3
View File
@@ -176,10 +176,10 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region within the same memory space";
let arguments = (ins
Index:$targetOffset,
Index:$sourceOffset,
PimTensor:$target,
PimTensor:$source,
I32Attr:$targetOffset,
I32Attr:$sourceOffset,
I32Attr:$size
);
@@ -194,7 +194,9 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
`[` $targetOffset `,` $sourceOffset `]`
`(` $target `,` $source `)` attr-dict
`:` type($target) `,` type($source) `->` type($output)
}];
}
@@ -19,14 +19,15 @@ Value materializeContiguousInputMemRef(Value memrefValue, Location loc, Rewriter
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0);
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
zeroOffset,
zeroOffset,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
@@ -1,7 +1,3 @@
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen)
add_pim_library(OMPimBufferization
PimBufferizationPass.cpp
BufferizationUtils.hpp
@@ -15,9 +11,6 @@ add_pim_library(OMPimBufferization
EXCLUDE_FROM_OM_LIBS
DEPENDS
PimBufferizationIncGen
LINK_LIBS PUBLIC
OMPimCommon
PimOps
@@ -3,220 +3,360 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
return false;
// One dynamic contribution to a byte offset: `value` multiplied by the
// constant `scale` when the offset is materialized.
struct ByteOffsetTerm {
// Runtime SSA value taking part in the offset.
Value value;
// Constant multiplier applied to `value`.
int64_t scale = 0;
};
return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
// Symbolic byte offset: a statically known constant plus a sum of scaled
// runtime values.
struct ByteOffsetExpr {
// Statically known part of the offset.
int64_t constant = 0;
// Dynamic parts of the offset; each one contributes `value * scale`.
SmallVector<ByteOffsetTerm> terms;
};
// Describes one side (target or source) of a copy after its view-like
// producers have been looked through.
struct CopyEndpointPlan {
// Underlying value reached by walking through cast / collapse_shape /
// expand_shape / reinterpret_cast / subview producers.
Value base;
// Memref type the endpoint had before any view was stripped.
MemRefType originalType;
// Memref type of `base`.
MemRefType baseType;
// Byte offset to apply on `base`: the copy's own offset operand plus the
// contributions of every subview that was looked through.
ByteOffsetExpr offset;
};
// Parameters of a loop-based copy: the non-contiguous leading dimensions are
// iterated and one contiguous chunk is copied per iteration.
struct CopyLoopPlan {
// Leading dimensions of the logical copy shape that are not part of the
// contiguous suffix; their product is the loop trip count.
SmallVector<int64_t> outerShape;
// Number of bytes moved by each per-iteration copy.
int64_t chunkBytes = 0;
// Byte offsets of the first chunk on the target and source side.
ByteOffsetExpr targetBaseOffset;
ByteOffsetExpr sourceBaseOffset;
// Per-outer-dimension strides, already converted to bytes, used to advance
// the target and source offsets from the delinearized loop indices.
SmallVector<int64_t> targetOuterByteStrides;
SmallVector<int64_t> sourceOuterByteStrides;
};
// Result of analyzing a copy-like op: either a single direct copy or a loop
// of chunk copies, together with the resolved endpoints.
struct CopyRewritePlan {
// Direct: one copy of `directBytes` bytes suffices.
// Loop: the copy is split into chunks described by `loop`.
enum class Kind {
Direct,
Loop
} kind = Kind::Direct;
// Resolved target and source endpoints (base value and byte offset).
CopyEndpointPlan target;
CopyEndpointPlan source;
// Total byte count of the copy; filled in for Kind::Direct.
int64_t directBytes = 0;
// Loop parameters; filled in for Kind::Loop.
CopyLoopPlan loop;
};
// Returns true when `value` is the result of one of the memref view-like ops
// (subview, reinterpret_cast, collapse_shape, expand_shape, cast).
// A value without a defining op is never considered view-like.
static bool isViewLike(Value value) {
  Operation* producer = value.getDefiningOp();
  if (!producer)
    return false;

  return isa<memref::SubViewOp,
             memref::ReinterpretCastOp,
             memref::CollapseShapeOp,
             memref::ExpandShapeOp,
             memref::CastOp>(producer);
}
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
if (extraOffset == 0)
return baseOffset;
template <typename CopyOp>
static bool isNormalizedCopyLikeOp(CopyOp copyOp, Value target, Value source, Value targetOffset, Value sourceOffset) {
auto targetType = dyn_cast<MemRefType>(target.getType());
auto sourceType = dyn_cast<MemRefType>(source.getType());
return targetType && sourceType && !isViewLike(target) && !isViewLike(source) && targetOffset.getType().isIndex()
&& sourceOffset.getType().isIndex() && copyOp.getSize() > 0;
}
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
assert(integerAttr && "expected integer offset attribute");
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
// Records `value * scale` as a dynamic term of `expr`. Terms with a zero scale
// contribute nothing and are dropped.
static void appendTerm(ByteOffsetExpr& expr, Value value, int64_t scale) {
  if (scale == 0)
    return;

  ByteOffsetTerm term;
  term.value = value;
  term.scale = scale;
  expr.terms.push_back(term);
}
// Returns the strides of `type` when its layout can be expressed as strides
// plus offset and every stride is static; fails otherwise. The layout offset
// is queried but not reported.
static FailureOr<SmallVector<int64_t>> getStaticMemRefStrides(MemRefType type) {
  int64_t unusedOffset = 0;
  SmallVector<int64_t> staticStrides;

  // Short-circuit keeps the dynamic-stride scan from running on a failed query.
  if (failed(type.getStridesAndOffset(staticStrides, unusedOffset))
      || llvm::any_of(staticStrides, ShapedType::isDynamic))
    return failure();

  return staticStrides;
}
// Returns the total size of `type` in bytes, or failure when the shape is not
// static or the element type is not byte-sized.
static FailureOr<int64_t> getShapedByteSize(MemRefType type) {
  const bool sizeIsComputable = type.hasStaticShape() && hasByteSizedElementType(type.getElementType());
  if (!sizeIsComputable)
    return failure();

  int64_t byteCount = static_cast<int64_t>(getShapedTypeSizeInBytes(type));
  return byteCount;
}
// Infers the logical shape being copied. The shape is taken from whichever
// endpoint's full static byte size equals `size`, with the target preferred.
// Fails when the endpoints are not both static, differ in element type or
// rank, neither byte size equals `size`, or both match but disagree on shape.
static FailureOr<SmallVector<int64_t>>
inferLogicalCopyShape(MemRefType targetType, MemRefType sourceType, int64_t size) {
  const bool bothStatic = targetType.hasStaticShape() && sourceType.hasStaticShape();
  if (!bothStatic)
    return failure();

  const bool sameElementType = targetType.getElementType() == sourceType.getElementType();
  const bool sameRank = targetType.getRank() == sourceType.getRank();
  if (!sameElementType || !sameRank)
    return failure();

  auto targetBytes = getShapedByteSize(targetType);
  auto sourceBytes = getShapedByteSize(sourceType);
  if (failed(targetBytes) || failed(sourceBytes))
    return failure();

  ArrayRef<int64_t> targetShape = targetType.getShape();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  const bool targetCoversCopy = *targetBytes == size;
  const bool sourceCoversCopy = *sourceBytes == size;

  // Neither endpoint describes exactly `size` bytes.
  if (!targetCoversCopy && !sourceCoversCopy)
    return failure();

  // Both endpoints claim to describe the copy but disagree on its shape.
  if (targetCoversCopy && sourceCoversCopy && targetShape != sourceShape)
    return failure();

  ArrayRef<int64_t> chosenShape = targetCoversCopy ? targetShape : sourceShape;
  return SmallVector<int64_t>(chosenShape.begin(), chosenShape.end());
}
// Counts how many trailing dimensions of `type` are laid out densely with
// respect to `copyShape`: starting at the innermost dimension, a dimension
// belongs to the suffix while its stride equals the product of the copy
// extents of the dimensions after it. Fails when the type is not static, the
// element type is not byte-sized, the ranks differ, or the strides are not
// statically known.
static FailureOr<int64_t> getContiguousSuffixRank(MemRefType type, ArrayRef<int64_t> copyShape) {
  if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType()))
    return failure();

  const int64_t rank = type.getRank();
  if (rank != static_cast<int64_t>(copyShape.size()))
    return failure();

  auto strides = getStaticMemRefStrides(type);
  if (failed(strides))
    return failure();

  int64_t suffixRank = 0;
  int64_t denseStride = 1;
  while (suffixRank < rank) {
    const int64_t dim = rank - 1 - suffixRank;
    if ((*strides)[dim] != denseStride)
      break;
    denseStride *= copyShape[dim];
    ++suffixRank;
  }
  return suffixRank;
}
static FailureOr<CopyEndpointPlan> analyzeCopyEndpoint(Value value, Value initialByteOffset, MemRefType logicalType) {
if (!logicalType.hasStaticShape() || !hasByteSizedElementType(logicalType.getElementType()))
return failure();
CopyEndpointPlan endpoint;
endpoint.base = value;
endpoint.originalType = logicalType;
appendTerm(endpoint.offset, initialByteOffset, 1);
while (true) {
if (auto castOp = endpoint.base.getDefiningOp<memref::CastOp>()) {
endpoint.base = castOp.getSource();
continue;
}
if (auto collapseOp = endpoint.base.getDefiningOp<memref::CollapseShapeOp>()) {
endpoint.base = collapseOp.getSrc();
continue;
}
if (auto expandOp = endpoint.base.getDefiningOp<memref::ExpandShapeOp>()) {
endpoint.base = expandOp.getSrc();
continue;
}
if (auto reinterpretOp = endpoint.base.getDefiningOp<memref::ReinterpretCastOp>()) {
endpoint.base = reinterpretOp.getSource();
continue;
}
auto subviewOp = endpoint.base.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
break;
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
if (!sourceType || !sourceType.hasStaticShape() || !hasByteSizedElementType(sourceType.getElementType()))
return failure();
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return failure();
int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
for (auto [offset, stride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
int64_t byteScale = stride * elementByteWidth;
if (auto attr = dyn_cast<Attribute>(offset)) {
endpoint.offset.constant += cast<IntegerAttr>(attr).getInt() * byteScale;
continue;
}
appendTerm(endpoint.offset, cast<Value>(offset), byteScale);
}
endpoint.base = subviewOp.getSource();
}
auto value = cast<Value>(baseOffset);
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
endpoint.baseType = dyn_cast<MemRefType>(endpoint.base.getType());
if (!endpoint.baseType)
return failure();
return endpoint;
}
static Value buildSubviewChunk(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(info.offsets.size());
chunkSizes.reserve(info.sizes.size());
chunkStrides.reserve(info.strides.size());
static FailureOr<CopyRewritePlan>
analyzeCopyRewrite(Value target, Value source, Value targetOffset, Value sourceOffset, int64_t size) {
auto targetType = dyn_cast<MemRefType>(target.getType());
auto sourceType = dyn_cast<MemRefType>(source.getType());
if (!targetType || !sourceType || size <= 0)
return failure();
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
auto logicalCopyShape = inferLogicalCopyShape(targetType, sourceType, size);
if (failed(logicalCopyShape))
return failure();
auto targetPlan = analyzeCopyEndpoint(target, targetOffset, targetType);
auto sourcePlan = analyzeCopyEndpoint(source, sourceOffset, sourceType);
if (failed(targetPlan) || failed(sourcePlan))
return failure();
auto targetSuffixRank = getContiguousSuffixRank(targetType, *logicalCopyShape);
auto sourceSuffixRank = getContiguousSuffixRank(sourceType, *logicalCopyShape);
if (failed(targetSuffixRank) || failed(sourceSuffixRank))
return failure();
CopyRewritePlan plan;
plan.target = *targetPlan;
plan.source = *sourcePlan;
int64_t contiguousSuffixRank = std::min(*targetSuffixRank, *sourceSuffixRank);
if (contiguousSuffixRank == static_cast<int64_t>(logicalCopyShape->size())) {
plan.kind = CopyRewritePlan::Kind::Direct;
plan.directBytes = size;
return plan;
}
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
auto targetStrides = getStaticMemRefStrides(targetType);
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(targetStrides) || failed(sourceStrides))
return failure();
int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(targetType.getElementType()));
plan.kind = CopyRewritePlan::Kind::Loop;
plan.loop.targetBaseOffset = plan.target.offset;
plan.loop.sourceBaseOffset = plan.source.offset;
plan.loop.outerShape.assign(logicalCopyShape->begin(), logicalCopyShape->end() - contiguousSuffixRank);
SmallVector<int64_t> chunkShape(logicalCopyShape->end() - contiguousSuffixRank, logicalCopyShape->end());
plan.loop.chunkBytes = getNumElements(chunkShape) * elementByteWidth;
for (int64_t stride : ArrayRef<int64_t>(*targetStrides).take_front(plan.loop.outerShape.size()))
plan.loop.targetOuterByteStrides.push_back(stride * elementByteWidth);
for (int64_t stride : ArrayRef<int64_t>(*sourceStrides).take_front(plan.loop.outerShape.size()))
plan.loop.sourceOuterByteStrides.push_back(stride * elementByteWidth);
if (plan.loop.chunkBytes <= 0)
return failure();
return plan;
}
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
ArrayRef<int64_t> shape,
ArrayRef<int64_t> strides,
PatternRewriter& rewriter) {
// Thin forwarding wrapper around getOrCreateIndexConstant: returns an index
// constant holding `value`, with `anchorOp` passed through as the anchor.
static Value createIndexConstant(PatternRewriter& rewriter, Operation* anchorOp, int64_t value) {
return getOrCreateIndexConstant(rewriter, anchorOp, value);
}
// Returns `lhs + rhs`, skipping the arith.addi when either operand is a known
// constant zero (the other operand is returned unchanged in that case).
static Value addIndexValues(PatternRewriter& rewriter, Location loc, Value lhs, Value rhs) {
  auto isConstantZero = [](Value operand) -> bool {
    auto constant = getConstantIntValue(operand);
    return constant && *constant == 0;
  };

  if (isConstantZero(lhs))
    return rhs;
  if (isConstantZero(rhs))
    return lhs;

  auto sum = arith::AddIOp::create(rewriter, loc, lhs, rhs);
  return sum.getResult();
}
// Returns `value * scale`. A scale of 0 yields an index constant zero and a
// scale of 1 returns `value` itself; any other scale emits an arith.muli
// against an index constant.
static Value mulIndexValue(PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value value, int64_t scale) {
  switch (scale) {
  case 0:
    return createIndexConstant(rewriter, anchorOp, 0);
  case 1:
    return value;
  default:
    break;
  }

  Value factor = createIndexConstant(rewriter, anchorOp, scale);
  auto product = arith::MulIOp::create(rewriter, loc, value, factor);
  return product.getResult();
}
// Emits IR computing `expr`: starts from its constant part and adds each
// scaled dynamic term in order.
static Value
materializeByteOffset(PatternRewriter& rewriter, Location loc, Operation* anchorOp, const ByteOffsetExpr& expr) {
  Value accumulated = createIndexConstant(rewriter, anchorOp, expr.constant);
  for (const ByteOffsetTerm& term : expr.terms) {
    Value scaledTerm = mulIndexValue(rewriter, loc, anchorOp, term.value, term.scale);
    accumulated = addIndexValues(rewriter, loc, accumulated, scaledTerm);
  }
  return accumulated;
}
static SmallVector<Value> materializeDelinearizedIndices(
PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value linearIndex, ArrayRef<int64_t> shape) {
SmallVector<Value> indices;
indices.reserve(shape.size());
if (shape.empty())
return indices;
auto rowMajorStrides = computeRowMajorStrides(shape);
Value remaining = linearIndex;
for (auto [_dim, stride] : llvm::enumerate(strides)) {
auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
for (auto [dim, stride] : llvm::enumerate(rowMajorStrides)) {
if (dim + 1 == rowMajorStrides.size()) {
indices.push_back(remaining);
break;
}
Value strideValue = createIndexConstant(rewriter, anchorOp, stride);
Value index = arith::DivUIOp::create(rewriter, loc, remaining, strideValue);
indices.push_back(index);
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
remaining = arith::RemUIOp::create(rewriter, loc, remaining, strideValue);
}
return indices;
}
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
auto integerAttr = cast<IntegerAttr>(attr);
if (integerAttr.getInt() == 0)
return extraOffset;
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
}
auto value = cast<Value>(baseOffset);
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
}
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
ArrayRef<Value> outerIndices,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(info.offsets.size());
chunkSizes.reserve(info.sizes.size());
chunkStrides.reserve(info.strides.size());
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
if (dim + 1 < info.sizes.size()) {
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(1));
}
else {
chunkOffsets.push_back(info.offsets[dim]);
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
}
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
}
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
}
static Value buildContiguousChunk(
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(copyShape.size());
chunkSizes.reserve(copyShape.size());
chunkStrides.reserve(copyShape.size());
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
chunkStrides.push_back(rewriter.getIndexAttr(1));
}
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
// Emits IR computing the byte offset of one chunk: the materialized
// `baseOffset` plus, for every outer dimension, its index multiplied by the
// matching byte stride. `outerIndices` and `outerByteStrides` must have the
// same length (enforced by zip_equal).
static Value materializeOuterByteOffset(PatternRewriter& rewriter,
                                        Location loc,
                                        Operation* anchorOp,
                                        const ByteOffsetExpr& baseOffset,
                                        ArrayRef<Value> outerIndices,
                                        ArrayRef<int64_t> outerByteStrides) {
  Value chunkOffset = materializeByteOffset(rewriter, loc, anchorOp, baseOffset);
  for (auto [outerIndex, byteStride] : llvm::zip_equal(outerIndices, outerByteStrides)) {
    Value contribution = mulIndexValue(rewriter, loc, anchorOp, outerIndex, byteStride);
    chunkOffset = addIndexValues(rewriter, loc, chunkOffset, contribution);
  }
  return chunkOffset;
}
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
Value src,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
bool allowLoopRewrite,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
if (!splitSrc && !splitDst)
static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
Value target,
Value source,
Value targetOffset,
Value sourceOffset,
Value replacementValue,
CreateCopyOp createCopyOp,
PatternRewriter& rewriter) {
if (isNormalizedCopyLikeOp(copyOp, target, source, targetOffset, sourceOffset))
return failure();
auto sourceType = dyn_cast<MemRefType>(src.getType());
auto dstType = dyn_cast<MemRefType>(dst.getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
auto plan = analyzeCopyRewrite(target, source, targetOffset, sourceOffset, copyOp.getSize());
if (failed(plan))
return failure();
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
if (!hasByteSizedElementType(sourceType.getElementType()))
return failure();
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
&& dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
&& (splitDst || dstShapeMatchesCopyShape)) {
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices =
outerShape.empty() ? SmallVector<Value> {}
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
Location loc = copyOp.getLoc();
Operation* anchorOp = copyOp.getOperation();
if (plan->kind == CopyRewritePlan::Kind::Direct) {
Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset);
Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset);
auto newCopyOp = createCopyOp(loc,
plan->target.base,
plan->source.base,
newTargetOffset,
newSourceOffset,
static_cast<int32_t>(plan->directBytes));
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.replaceOp(copyOp, replacementValue);
return success();
}
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
}
Value c0 = createIndexConstant(rewriter, anchorOp, 0);
Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape));
Value cStep = createIndexConstant(rewriter, anchorOp, 1);
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices =
materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape);
Value loopTargetOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
Value loopSourceOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
auto newCopyOp = createCopyOp(loc,
plan->target.base,
plan->source.base,
loopTargetOffset,
loopSourceOffset,
static_cast<int32_t>(plan->loop.chunkBytes));
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.setInsertionPointAfter(loop);
rewriter.replaceOp(copyOp, replacementValue);
return success();
}
@@ -224,34 +364,24 @@ struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyO
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
auto status = rewriteSubviewCopyLikeOp(
return rewriteCopyLikeOp(
copyOp,
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getTarget());
return success();
copyOp.getTarget(),
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
return pim::PimMemCopyOp::create(rewriter,
loc,
cast<MemRefType>(target.getType()),
targetOffset,
sourceOffset,
target,
source,
rewriter.getI32IntegerAttr(size));
},
rewriter);
}
};
@@ -259,38 +389,24 @@ struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyH
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp(
return rewriteCopyLikeOp(
copyOp,
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
*dstOffset,
*srcOffset,
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
resultType,
dstOffsetValue,
srcOffsetValue,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
return success();
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getDeviceTarget(),
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
return pim::PimMemCopyHostToDevOp::create(rewriter,
loc,
cast<MemRefType>(target.getType()),
targetOffset,
sourceOffset,
target,
source,
rewriter.getI32IntegerAttr(size));
},
rewriter);
}
};
@@ -298,43 +414,43 @@ struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopy
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp(
return rewriteCopyLikeOp(
copyOp,
copyOp.getHostTarget(),
copyOp.getDeviceSource(),
*dstOffset,
*srcOffset,
copyOp.getSize(),
/*allowLoopRewrite=*/false,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(),
resultType,
dstOffsetValue,
srcOffsetValue,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
return success();
copyOp.getHostTargetOffset(),
copyOp.getDeviceSourceOffset(),
copyOp.getHostTarget(),
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
return pim::PimMemCopyDevToHostOp::create(rewriter,
loc,
cast<MemRefType>(target.getType()),
targetOffset,
sourceOffset,
target,
source,
rewriter.getI32IntegerAttr(size));
},
rewriter);
}
};
} // namespace
// Normalization check for PimMemCopyOp: forwards the op's target/source and
// their offsets to isNormalizedCopyLikeOp.
bool isNormalizedCopyOp(pim::PimMemCopyOp op) {
return isNormalizedCopyLikeOp(op, op.getTarget(), op.getSource(), op.getTargetOffset(), op.getSourceOffset());
}
// Normalization check for PimMemCopyHostToDevOp: the device side is the
// target and the host side is the source.
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op) {
return isNormalizedCopyLikeOp(
op, op.getDeviceTarget(), op.getHostSource(), op.getDeviceTargetOffset(), op.getHostSourceOffset());
}
// Normalization check for PimMemCopyDevToHostOp: the host side is the target
// and the device side is the source.
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op) {
return isNormalizedCopyLikeOp(
op, op.getHostTarget(), op.getDeviceSource(), op.getHostTargetOffset(), op.getDeviceSourceOffset());
}
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
patterns.getContext());
@@ -2,8 +2,14 @@
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir::pim {
// Each overload returns true when the copy is already in normalized form:
// both endpoints are memrefs not produced by a view-like op (subview,
// reinterpret_cast, collapse/expand_shape, cast), both offsets are
// index-typed, and the copy size is positive.
bool isNormalizedCopyOp(pim::PimMemCopyOp op);
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op);
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op);
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir::pim
@@ -101,10 +101,10 @@ struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInt
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
memCopyOp,
targetOpt->getType(),
memCopyOp.getTargetOffset(),
memCopyOp.getSourceOffset(),
*targetOpt,
*sourceOpt,
memCopyOp.getTargetOffsetAttr(),
memCopyOp.getSourceOffsetAttr(),
memCopyOp.getSizeAttr());
return success();
}
@@ -1,19 +0,0 @@
#ifndef PIM_BUFFERIZATION
#define PIM_BUFFERIZATION
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def memrefCopyToPimMemCopyOp : Pat<
(CopyOp $src, $dst),
(PimMemCopyOp $dst, $src,
ConstantAttr<I32Attr, "0">,
ConstantAttr<I32Attr, "0">,
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION
@@ -6,7 +6,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Casting.h"
@@ -27,7 +27,34 @@ namespace onnx_mlir {
namespace {
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
// Rewrites a memref.copy nested inside a pim.core / pim.core_batch body into a
// PimMemCopyOp with zero offsets on both sides and a byte size taken from the
// source memref. Copies outside those bodies are left untouched.
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only copies living inside a PIM core body are lowered here.
    const bool insideCoreBody =
        copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>();
    if (!insideCoreBody)
      return failure();

    // Both operands must be statically shaped memrefs.
    auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
    auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
    if (!sourceType || !targetType)
      return failure();
    if (!sourceType.hasStaticShape() || !targetType.hasStaticShape())
      return failure();

    // Element types must agree.
    if (sourceType.getElementType() != targetType.getElementType())
      return failure();

    Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
    IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource());
    pim::PimMemCopyOp::create(rewriter,
                              copyOp.getLoc(),
                              copyOp.getTarget().getType(),
                              zeroOffset,
                              zeroOffset,
                              copyOp.getTarget(),
                              copyOp.getSource(),
                              sizeAttr);

    // memref.copy has no results, so the original op is simply erased.
    rewriter.eraseOp(copyOp);
    return success();
  }
};
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
@@ -44,6 +71,14 @@ private:
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
};
/// Applies `applicator` exactly once to `op`, with `rewriter` positioned at `op`.
///
/// Returns failure without touching the rewriter when `op` is null or has no
/// parent block; otherwise returns the result of
/// `PatternApplicator::matchAndRewrite`.
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
  const bool hasParentBlock = op != nullptr && op->getBlock() != nullptr;
  if (!hasParentBlock)
    return failure();

  // Ops created by the matched pattern are inserted at `op`.
  rewriter.setInsertionPoint(op);
  LogicalResult outcome = applicator.matchAndRewrite(op, rewriter);
  return outcome;
}
} // namespace
void PimBufferizationPass::runOnOperation() {
@@ -63,35 +98,60 @@ void PimBufferizationPass::runOnOperation() {
}
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect>();
RewritePatternSet memrefCopyPatterns(ctx);
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
memrefCopyApplicator.applyDefaultCostModel();
PatternRewriter rewriter(ctx);
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
// Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies.
// Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering.
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
bool hasFailed = false;
moduleOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
return WalkResult::advance();
if (failed(applyPartialConversion(op, target, frozenPatterns)))
hasFailed = true;
return WalkResult::skip();
SmallVector<memref::CopyOp> copyWorklist;
moduleOp.walk([&](memref::CopyOp copyOp) {
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
copyWorklist.push_back(copyOp);
});
bool hasFailed = false;
for (memref::CopyOp copyOp : copyWorklist) {
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
hasFailed = true;
}
}
if (hasFailed) {
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
signalPassFailure();
return;
}
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
// Later PIM codegen passes may verify this invariant but must not repair it.
RewritePatternSet contiguityPatterns(ctx);
populatePimContiguityNormalizationPatterns(contiguityPatterns);
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
FrozenRewritePatternSet frozenContiguityPatterns(std::move(contiguityPatterns));
PatternApplicator contiguityApplicator(frozenContiguityPatterns);
contiguityApplicator.applyDefaultCostModel();
SmallVector<Operation*> contiguityWorklist;
moduleOp.walk([&](Operation* op) {
if (isa<pim::PimMemCopyOp, pim::PimMemCopyHostToDevOp, pim::PimMemCopyDevToHostOp>(op))
contiguityWorklist.push_back(op);
});
hasFailed = false;
for (Operation* op : contiguityWorklist) {
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
continue;
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
continue;
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
continue;
if (failed(applyPatternsOnce(op, contiguityApplicator, rewriter))) {
op->emitOpError("failed to normalize PIM copy contiguity");
hasFailed = true;
}
}
if (hasFailed) {
moduleOp.emitError("failed to normalize PIM copy contiguity during bufferization");
signalPassFailure();
return;
}
@@ -138,16 +198,28 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
};
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
if (!pim::isNormalizedCopyOp(memCopyOp)) {
memCopyOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
hasFailure = true;
}
verifyOperand(memCopyOp.getTarget(), 0);
verifyOperand(memCopyOp.getSource(), 1);
return;
}
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
if (!pim::isNormalizedCopyOp(loadOp)) {
loadOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
hasFailure = true;
}
verifyOperand(loadOp.getDeviceTarget(), 2);
verifyOperand(loadOp.getHostSource(), 3);
return;
}
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
if (!pim::isNormalizedCopyOp(storeOp)) {
storeOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
hasFailure = true;
}
verifyOperand(storeOp.getHostTarget(), 2);
verifyOperand(storeOp.getDeviceSource(), 3);
return;