compact memory contiguity with for loops
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-31 18:47:59 +02:00
parent ab63498f3f
commit b678e55d3c
14 changed files with 550 additions and 331 deletions
+22 -6
View File
@@ -69,6 +69,16 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
/// Returns the per-dimension strides of `type`.
///
/// Fails when the layout of `type` cannot be decomposed by
/// `getStridesAndOffset`, or when any of the resulting strides is dynamic.
/// The layout offset is queried alongside the strides but is not inspected.
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
  int64_t layoutOffset = 0;
  llvm::SmallVector<int64_t> layoutStrides;
  if (failed(type.getStridesAndOffset(layoutStrides, layoutOffset)))
    return mlir::failure();

  // Reject the layout as soon as one stride is not a compile-time constant.
  for (int64_t stride : layoutStrides)
    if (mlir::ShapedType::isDynamic(stride))
      return mlir::failure();

  return layoutStrides;
}
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
@@ -539,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
@@ -651,12 +663,16 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
constantByteOffset +=
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
@@ -665,7 +681,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;
+2 -8
View File
@@ -5,8 +5,6 @@
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -15,16 +13,12 @@ namespace onnx_mlir {
Block* getConstantInsertionBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
for (Operation* current = anchorOp; current; current = current->getParentOp())
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
return current->getBlock();
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
return &funcOp.getBody().front();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
return moduleOp.getBody();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
+8 -2
View File
@@ -235,6 +235,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
if (failed(compiledExpr)) {
errs() << "Failed to compile contiguous address for value: ";
value.print(errs());
errs() << " : " << value.getType();
errs() << "\n";
llvm_unreachable("Failed to compile contiguous address");
}
@@ -245,6 +246,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
if (failed(resolvedAddress)) {
errs() << "Failed to evaluate contiguous address for value: ";
value.print(errs());
errs() << " : " << value.getType();
errs() << "\n";
if (auto* definingOp = value.getDefiningOp()) {
errs() << "Defining op:\n";
@@ -493,11 +495,15 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const Static
}
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
auto targetOffset = indexOf(lmvOp.getTargetOffset(), knowledge);
auto sourceOffset = indexOf(lmvOp.getSourceOffset(), knowledge);
assert(succeeded(targetOffset) && succeeded(sourceOffset)
&& "pim.memcp offsets must be statically resolvable during codegen");
emitMemCopyOp("lmv",
addressOf(lmvOp.getTarget(), knowledge),
lmvOp.getTargetOffset(),
*targetOffset,
addressOf(lmvOp.getSource(), knowledge),
lmvOp.getSourceOffset(),
*sourceOffset,
lmvOp.getSize(),
"len");
}
@@ -101,9 +101,9 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
auto paddedType = RankedTensorType::get(
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
auto zeroAttr = rewriter.getI32IntegerAttr(0);
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed.getDefiningOp(), 0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, zeroed, vector, sizeAttr).getOutput();
}
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
+5 -3
View File
@@ -176,10 +176,10 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region within the same memory space";
let arguments = (ins
Index:$targetOffset,
Index:$sourceOffset,
PimTensor:$target,
PimTensor:$source,
I32Attr:$targetOffset,
I32Attr:$sourceOffset,
I32Attr:$size
);
@@ -194,7 +194,9 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
`[` $targetOffset `,` $sourceOffset `]`
`(` $target `,` $source `)` attr-dict
`:` type($target) `,` type($source) `->` type($output)
}];
}
@@ -19,14 +19,15 @@ Value materializeContiguousInputMemRef(Value memrefValue, Location loc, Rewriter
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0);
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
zeroOffset,
zeroOffset,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
@@ -1,7 +1,3 @@
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen)
add_pim_library(OMPimBufferization
PimBufferizationPass.cpp
BufferizationUtils.hpp
@@ -15,9 +11,6 @@ add_pim_library(OMPimBufferization
EXCLUDE_FROM_OM_LIBS
DEPENDS
PimBufferizationIncGen
LINK_LIBS PUBLIC
OMPimCommon
PimOps
@@ -3,220 +3,360 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
return false;
/// One symbolic term of a byte offset: contributes `value * scale` to the sum
/// built by materializeByteOffset.
struct ByteOffsetTerm {
// Index-typed SSA value that is multiplied by `scale`.
Value value;
// Constant multiplier; appendTerm never stores a term whose scale is 0.
int64_t scale = 0;
};
return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
/// A byte offset written as `constant + sum(term.value * term.scale)`.
struct ByteOffsetExpr {
// Statically known part of the offset.
int64_t constant = 0;
// Dynamic contributions, each an SSA value times a constant scale.
SmallVector<ByteOffsetTerm> terms;
};
/// Describes one side (target or source) of a copy after analyzeCopyEndpoint
/// has walked back through the view-like ops that define it.
struct CopyEndpointPlan {
// Value reached after peeling cast / collapse_shape / expand_shape /
// reinterpret_cast / subview producers.
Value base;
// Type the endpoint had before peeling (the `logicalType` passed to the analysis).
MemRefType originalType;
// MemRef type of `base`.
MemRefType baseType;
// Byte offset relative to `base`: the caller-provided offset plus the
// contributions of every subview that was peeled.
ByteOffsetExpr offset;
};
/// Parameters for rewriting a copy as a loop of contiguous chunk copies.
/// Populated by analyzeCopyRewrite when the plan kind is Loop.
struct CopyLoopPlan {
// Leading dimensions of the logical copy shape that are not part of the
// contiguous suffix; the loop trip count is the product of these extents.
SmallVector<int64_t> outerShape;
// Size in bytes of one contiguous chunk (suffix element count * element byte width).
int64_t chunkBytes = 0;
// Byte offsets of the first chunk on each side (copied from the endpoint plans).
ByteOffsetExpr targetBaseOffset;
ByteOffsetExpr sourceBaseOffset;
// Per outer dimension: memref stride multiplied by the element byte width.
SmallVector<int64_t> targetOuterByteStrides;
SmallVector<int64_t> sourceOuterByteStrides;
};
/// Result of analyzeCopyRewrite: how a copy-like op should be re-emitted
/// against the base buffers of its operands.
struct CopyRewritePlan {
// Direct: the whole copy is contiguous on both sides and becomes one copy of
// `directBytes`. Loop: the copy is split into chunks described by `loop`.
enum class Kind {
Direct,
Loop
} kind = Kind::Direct;
// Analyzed target and source endpoints (base value plus byte offset).
CopyEndpointPlan target;
CopyEndpointPlan source;
// Number of bytes to copy; set by analyzeCopyRewrite only for Kind::Direct.
int64_t directBytes = 0;
// Loop parameters; filled in by analyzeCopyRewrite only for Kind::Loop.
CopyLoopPlan loop;
};
/// Returns true when `value` is the result of one of the memref ops this file
/// treats as a view: subview, reinterpret_cast, collapse_shape, expand_shape
/// or cast. Values without a defining op are never view-like.
static bool isViewLike(Value value) {
  if (Operation* producer = value.getDefiningOp())
    return isa<memref::SubViewOp,
               memref::ReinterpretCastOp,
               memref::CollapseShapeOp,
               memref::ExpandShapeOp,
               memref::CastOp>(producer);
  return false;
}
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
if (extraOffset == 0)
return baseOffset;
template <typename CopyOp>
static bool isNormalizedCopyLikeOp(CopyOp copyOp, Value target, Value source, Value targetOffset, Value sourceOffset) {
auto targetType = dyn_cast<MemRefType>(target.getType());
auto sourceType = dyn_cast<MemRefType>(source.getType());
return targetType && sourceType && !isViewLike(target) && !isViewLike(source) && targetOffset.getType().isIndex()
&& sourceOffset.getType().isIndex() && copyOp.getSize() > 0;
}
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
assert(integerAttr && "expected integer offset attribute");
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
/// Adds the term `value * scale` to `expr`. Terms with a zero scale contribute
/// nothing and are dropped instead of being stored.
static void appendTerm(ByteOffsetExpr& expr, Value value, int64_t scale) {
  if (scale == 0)
    return;
  expr.terms.push_back(ByteOffsetTerm {value, scale});
}
/// Returns the per-dimension strides of `type`.
///
/// Fails when the layout of `type` cannot be decomposed by
/// `getStridesAndOffset`, or when any of the resulting strides is dynamic.
/// The layout offset is queried alongside the strides but is not inspected.
static FailureOr<SmallVector<int64_t>> getStaticMemRefStrides(MemRefType type) {
  int64_t layoutOffset = 0;
  SmallVector<int64_t> layoutStrides;
  if (failed(type.getStridesAndOffset(layoutStrides, layoutOffset)))
    return failure();

  // Reject the layout as soon as one stride is not a compile-time constant.
  for (int64_t stride : layoutStrides)
    if (ShapedType::isDynamic(stride))
      return failure();

  return layoutStrides;
}
/// Returns the total size of `type` in bytes, or failure when the shape is not
/// fully static or the element type is not byte-sized.
static FailureOr<int64_t> getShapedByteSize(MemRefType type) {
  const bool sizeIsKnown = type.hasStaticShape() && hasByteSizedElementType(type.getElementType());
  if (!sizeIsKnown)
    return failure();
  return static_cast<int64_t>(getShapedTypeSizeInBytes(type));
}
/// Picks the shape that a copy of `size` bytes between `targetType` and
/// `sourceType` logically covers.
///
/// Both types must have static shapes, the same element type and the same
/// rank. The returned shape is that of whichever operand's total byte size
/// equals `size`, with the target preferred. Fails when neither operand
/// matches, or when both match but their shapes differ.
static FailureOr<SmallVector<int64_t>>
inferLogicalCopyShape(MemRefType targetType, MemRefType sourceType, int64_t size) {
  if (!(targetType.hasStaticShape() && sourceType.hasStaticShape()))
    return failure();
  if (targetType.getElementType() != sourceType.getElementType())
    return failure();
  if (targetType.getRank() != sourceType.getRank())
    return failure();

  FailureOr<int64_t> targetBytes = getShapedByteSize(targetType);
  FailureOr<int64_t> sourceBytes = getShapedByteSize(sourceType);
  if (failed(targetBytes) || failed(sourceBytes))
    return failure();

  const bool targetCoversCopy = *targetBytes == size;
  const bool sourceCoversCopy = *sourceBytes == size;
  if (!targetCoversCopy && !sourceCoversCopy)
    return failure();

  ArrayRef<int64_t> targetShape = targetType.getShape();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  // Same byte size on both sides but different shapes: the logical shape is ambiguous.
  if (targetCoversCopy && sourceCoversCopy && targetShape != sourceShape)
    return failure();

  ArrayRef<int64_t> chosenShape = targetCoversCopy ? targetShape : sourceShape;
  return SmallVector<int64_t>(chosenShape.begin(), chosenShape.end());
}
/// Counts how many trailing dimensions of `type` are laid out densely for a
/// copy of shape `copyShape`.
///
/// Walking from the innermost dimension outwards, a dimension belongs to the
/// suffix while its stride equals the product of the `copyShape` extents of
/// all dimensions inside it. Fails when `type` has a dynamic shape, a
/// non-byte-sized element type, a rank different from `copyShape`, or strides
/// that are not statically known.
static FailureOr<int64_t> getContiguousSuffixRank(MemRefType type, ArrayRef<int64_t> copyShape) {
  if (!type.hasStaticShape())
    return failure();
  if (!hasByteSizedElementType(type.getElementType()))
    return failure();
  const int64_t rank = type.getRank();
  if (rank != static_cast<int64_t>(copyShape.size()))
    return failure();

  auto strides = getStaticMemRefStrides(type);
  if (failed(strides))
    return failure();

  int64_t suffixRank = 0;
  int64_t denseStride = 1;
  while (suffixRank < rank) {
    const int64_t dim = rank - 1 - suffixRank;
    if ((*strides)[dim] != denseStride)
      break;
    denseStride *= copyShape[dim];
    ++suffixRank;
  }
  return suffixRank;
}
static FailureOr<CopyEndpointPlan> analyzeCopyEndpoint(Value value, Value initialByteOffset, MemRefType logicalType) {
if (!logicalType.hasStaticShape() || !hasByteSizedElementType(logicalType.getElementType()))
return failure();
CopyEndpointPlan endpoint;
endpoint.base = value;
endpoint.originalType = logicalType;
appendTerm(endpoint.offset, initialByteOffset, 1);
while (true) {
if (auto castOp = endpoint.base.getDefiningOp<memref::CastOp>()) {
endpoint.base = castOp.getSource();
continue;
}
if (auto collapseOp = endpoint.base.getDefiningOp<memref::CollapseShapeOp>()) {
endpoint.base = collapseOp.getSrc();
continue;
}
if (auto expandOp = endpoint.base.getDefiningOp<memref::ExpandShapeOp>()) {
endpoint.base = expandOp.getSrc();
continue;
}
if (auto reinterpretOp = endpoint.base.getDefiningOp<memref::ReinterpretCastOp>()) {
endpoint.base = reinterpretOp.getSource();
continue;
}
auto value = cast<Value>(baseOffset);
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
}
auto subviewOp = endpoint.base.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
break;
static Value buildSubviewChunk(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(info.offsets.size());
chunkSizes.reserve(info.sizes.size());
chunkStrides.reserve(info.strides.size());
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
if (!sourceType || !sourceType.hasStaticShape() || !hasByteSizedElementType(sourceType.getElementType()))
return failure();
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return failure();
int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
for (auto [offset, stride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
int64_t byteScale = stride * elementByteWidth;
if (auto attr = dyn_cast<Attribute>(offset)) {
endpoint.offset.constant += cast<IntegerAttr>(attr).getInt() * byteScale;
continue;
}
appendTerm(endpoint.offset, cast<Value>(offset), byteScale);
}
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
endpoint.base = subviewOp.getSource();
}
endpoint.baseType = dyn_cast<MemRefType>(endpoint.base.getType());
if (!endpoint.baseType)
return failure();
return endpoint;
}
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
ArrayRef<int64_t> shape,
ArrayRef<int64_t> strides,
PatternRewriter& rewriter) {
/// Decides how a copy of `size` bytes from `source` to `target` can be
/// re-expressed on the base buffers of both operands.
///
/// Returns a Direct plan when the whole logical copy shape is contiguous on
/// both sides, otherwise a Loop plan that copies one contiguous chunk per
/// iteration. Fails when either operand is not a memref, `size` is not
/// positive, or any of the helper analyses below fails.
static FailureOr<CopyRewritePlan>
analyzeCopyRewrite(Value target, Value source, Value targetOffset, Value sourceOffset, int64_t size) {
auto targetType = dyn_cast<MemRefType>(target.getType());
auto sourceType = dyn_cast<MemRefType>(source.getType());
if (!targetType || !sourceType || size <= 0)
return failure();
// Shape of the region being copied, taken from whichever operand's byte size equals `size`.
auto logicalCopyShape = inferLogicalCopyShape(targetType, sourceType, size);
if (failed(logicalCopyShape))
return failure();
// Peel view-like producers and accumulate their byte offsets.
auto targetPlan = analyzeCopyEndpoint(target, targetOffset, targetType);
auto sourcePlan = analyzeCopyEndpoint(source, sourceOffset, sourceType);
if (failed(targetPlan) || failed(sourcePlan))
return failure();
// NOTE(review): the suffix rank is computed from the strides of the operand
// types as passed in, paired with the logical copy shape; when only one
// operand's shape equals that shape, confirm this pairing is intended.
auto targetSuffixRank = getContiguousSuffixRank(targetType, *logicalCopyShape);
auto sourceSuffixRank = getContiguousSuffixRank(sourceType, *logicalCopyShape);
if (failed(targetSuffixRank) || failed(sourceSuffixRank))
return failure();
CopyRewritePlan plan;
plan.target = *targetPlan;
plan.source = *sourcePlan;
// A chunk must be contiguous on both sides, so take the shorter suffix.
int64_t contiguousSuffixRank = std::min(*targetSuffixRank, *sourceSuffixRank);
if (contiguousSuffixRank == static_cast<int64_t>(logicalCopyShape->size())) {
// Every dimension is contiguous: a single copy of `size` bytes suffices.
plan.kind = CopyRewritePlan::Kind::Direct;
plan.directBytes = size;
return plan;
}
auto targetStrides = getStaticMemRefStrides(targetType);
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(targetStrides) || failed(sourceStrides))
return failure();
// Element types are equal on both sides (checked by inferLogicalCopyShape).
int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(targetType.getElementType()));
plan.kind = CopyRewritePlan::Kind::Loop;
plan.loop.targetBaseOffset = plan.target.offset;
plan.loop.sourceBaseOffset = plan.source.offset;
// Split the logical shape into the looped outer dimensions and the contiguous chunk.
plan.loop.outerShape.assign(logicalCopyShape->begin(), logicalCopyShape->end() - contiguousSuffixRank);
SmallVector<int64_t> chunkShape(logicalCopyShape->end() - contiguousSuffixRank, logicalCopyShape->end());
plan.loop.chunkBytes = getNumElements(chunkShape) * elementByteWidth;
// Convert the element strides of the outer dimensions into byte strides.
for (int64_t stride : ArrayRef<int64_t>(*targetStrides).take_front(plan.loop.outerShape.size()))
plan.loop.targetOuterByteStrides.push_back(stride * elementByteWidth);
for (int64_t stride : ArrayRef<int64_t>(*sourceStrides).take_front(plan.loop.outerShape.size()))
plan.loop.sourceOuterByteStrides.push_back(stride * elementByteWidth);
// Reject plans whose computed chunk size is not positive.
if (plan.loop.chunkBytes <= 0)
return failure();
return plan;
}
/// Thin wrapper that forwards to getOrCreateIndexConstant with the same
/// rewriter, anchor operation and constant value.
static Value createIndexConstant(PatternRewriter& rewriter, Operation* anchorOp, int64_t value) {
return getOrCreateIndexConstant(rewriter, anchorOp, value);
}
/// Returns `lhs + rhs`. When one operand is a known constant zero the other
/// operand is returned unchanged and no arith.addi is created.
static Value addIndexValues(PatternRewriter& rewriter, Location loc, Value lhs, Value rhs) {
  const auto isConstantZero = [](Value operand) {
    auto folded = getConstantIntValue(operand);
    return folded && *folded == 0;
  };

  if (isConstantZero(lhs))
    return rhs;
  if (isConstantZero(rhs))
    return lhs;
  return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
}
/// Returns `value * scale`. A scale of 0 yields the index constant 0 and a
/// scale of 1 yields `value` itself; only other scales create an arith.muli.
static Value mulIndexValue(PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value value, int64_t scale) {
  switch (scale) {
  case 0:
    return createIndexConstant(rewriter, anchorOp, 0);
  case 1:
    return value;
  default:
    break;
  }

  Value factor = createIndexConstant(rewriter, anchorOp, scale);
  return arith::MulIOp::create(rewriter, loc, value, factor).getResult();
}
/// Emits IR computing the value of `expr`: starts from its constant part and
/// adds every `term.value * term.scale` in the order the terms are stored.
static Value
materializeByteOffset(PatternRewriter& rewriter, Location loc, Operation* anchorOp, const ByteOffsetExpr& expr) {
  Value total = createIndexConstant(rewriter, anchorOp, expr.constant);
  for (const ByteOffsetTerm& term : expr.terms) {
    Value scaledTerm = mulIndexValue(rewriter, loc, anchorOp, term.value, term.scale);
    total = addIndexValues(rewriter, loc, total, scaledTerm);
  }
  return total;
}
static SmallVector<Value> materializeDelinearizedIndices(
PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value linearIndex, ArrayRef<int64_t> shape) {
SmallVector<Value> indices;
indices.reserve(shape.size());
if (shape.empty())
return indices;
auto rowMajorStrides = computeRowMajorStrides(shape);
Value remaining = linearIndex;
for (auto [_dim, stride] : llvm::enumerate(strides)) {
auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
indices.push_back(index);
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
for (auto [dim, stride] : llvm::enumerate(rowMajorStrides)) {
if (dim + 1 == rowMajorStrides.size()) {
indices.push_back(remaining);
break;
}
Value strideValue = createIndexConstant(rewriter, anchorOp, stride);
Value index = arith::DivUIOp::create(rewriter, loc, remaining, strideValue);
indices.push_back(index);
remaining = arith::RemUIOp::create(rewriter, loc, remaining, strideValue);
}
return indices;
}
/// Returns `baseOffset + extraOffset`, where `baseOffset` is either a static
/// integer attribute or an SSA value. A static base of 0 returns `extraOffset`
/// directly; every other case emits an arith.addi.
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
  auto staticAttr = dyn_cast<Attribute>(baseOffset);
  if (!staticAttr) {
    // Dynamic base: add the two SSA values at the base value's location.
    auto dynamicBase = cast<Value>(baseOffset);
    return arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraOffset).getResult();
  }

  auto staticBase = cast<IntegerAttr>(staticAttr);
  if (staticBase.getInt() == 0)
    return extraOffset;

  auto baseConstant =
      getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), staticBase.getInt());
  return arith::AddIOp::create(rewriter, extraOffset.getLoc(), baseConstant, extraOffset).getResult();
}
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
ArrayRef<Value> outerIndices,
static Value materializeOuterByteOffset(PatternRewriter& rewriter,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(info.offsets.size());
chunkSizes.reserve(info.sizes.size());
chunkStrides.reserve(info.strides.size());
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
if (dim + 1 < info.sizes.size()) {
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(1));
}
else {
chunkOffsets.push_back(info.offsets[dim]);
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
}
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
}
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
}
static Value buildContiguousChunk(
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides;
chunkOffsets.reserve(copyShape.size());
chunkSizes.reserve(copyShape.size());
chunkStrides.reserve(copyShape.size());
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
chunkStrides.push_back(rewriter.getIndexAttr(1));
}
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
Operation* anchorOp,
const ByteOffsetExpr& baseOffset,
ArrayRef<Value> outerIndices,
ArrayRef<int64_t> outerByteStrides) {
Value result = materializeByteOffset(rewriter, loc, anchorOp, baseOffset);
for (auto [index, stride] : llvm::zip_equal(outerIndices, outerByteStrides))
result = addIndexValues(rewriter, loc, result, mulIndexValue(rewriter, loc, anchorOp, index, stride));
return result;
}
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
Value src,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
bool allowLoopRewrite,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
if (!splitSrc && !splitDst)
static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
Value target,
Value source,
Value targetOffset,
Value sourceOffset,
Value replacementValue,
CreateCopyOp createCopyOp,
PatternRewriter& rewriter) {
if (isNormalizedCopyLikeOp(copyOp, target, source, targetOffset, sourceOffset))
return failure();
auto sourceType = dyn_cast<MemRefType>(src.getType());
auto dstType = dyn_cast<MemRefType>(dst.getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
auto plan = analyzeCopyRewrite(target, source, targetOffset, sourceOffset, copyOp.getSize());
if (failed(plan))
return failure();
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
if (!hasByteSizedElementType(sourceType.getElementType()))
return failure();
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
&& dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
&& (splitDst || dstShapeMatchesCopyShape)) {
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices =
outerShape.empty() ? SmallVector<Value> {}
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
Location loc = copyOp.getLoc();
Operation* anchorOp = copyOp.getOperation();
if (plan->kind == CopyRewritePlan::Kind::Direct) {
Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset);
Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset);
auto newCopyOp = createCopyOp(loc,
plan->target.base,
plan->source.base,
newTargetOffset,
newSourceOffset,
static_cast<int32_t>(plan->directBytes));
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.replaceOp(copyOp, replacementValue);
return success();
}
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
}
Value c0 = createIndexConstant(rewriter, anchorOp, 0);
Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape));
Value cStep = createIndexConstant(rewriter, anchorOp, 1);
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices =
materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape);
Value loopTargetOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
Value loopSourceOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
auto newCopyOp = createCopyOp(loc,
plan->target.base,
plan->source.base,
loopTargetOffset,
loopSourceOffset,
static_cast<int32_t>(plan->loop.chunkBytes));
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.setInsertionPointAfter(loop);
rewriter.replaceOp(copyOp, replacementValue);
return success();
}
@@ -224,34 +364,24 @@ struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyO
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
auto status = rewriteSubviewCopyLikeOp(
return rewriteCopyLikeOp(
copyOp,
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getTarget());
return success();
copyOp.getTarget(),
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
return pim::PimMemCopyOp::create(rewriter,
loc,
cast<MemRefType>(target.getType()),
targetOffset,
sourceOffset,
target,
source,
rewriter.getI32IntegerAttr(size));
},
rewriter);
}
};
@@ -259,38 +389,24 @@ struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyH
using OpRewritePattern::OpRewritePattern;
// Normalizes a PimMemCopyHostToDevOp by handing its device target, host source,
// offsets and size to a shared copy-rewrite helper, together with a callback
// that builds the replacement PimMemCopyHostToDevOp.
// NOTE(review): this span holds two call sequences back to back - one that first
// resolves both offsets with resolveIndexValue and goes through
// rewriteSubviewCopyLikeOp, and one that passes the offset operands straight to
// rewriteCopyLikeOp. It looks like diff residue; confirm which sequence is live
// before editing.
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp(
return rewriteCopyLikeOp(
copyOp,
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
*dstOffset,
*srcOffset,
copyOp.getSize(),
/*allowLoopRewrite=*/true,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
resultType,
dstOffsetValue,
srcOffsetValue,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
return success();
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getDeviceTarget(),
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
return pim::PimMemCopyHostToDevOp::create(rewriter,
loc,
cast<MemRefType>(target.getType()),
targetOffset,
sourceOffset,
target,
source,
rewriter.getI32IntegerAttr(size));
},
rewriter);
}
};
@@ -298,43 +414,43 @@ struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopy
using OpRewritePattern::OpRewritePattern;
// Normalizes a PimMemCopyDevToHostOp by handing its host target, device source,
// offsets and size to a shared copy-rewrite helper, together with a callback
// that builds the replacement PimMemCopyDevToHostOp.
// NOTE(review): this span holds two call sequences back to back - one that first
// resolves both offsets with resolveIndexValue and goes through
// rewriteSubviewCopyLikeOp with allowLoopRewrite=false, and one that passes the
// offset operands straight to rewriteCopyLikeOp. It looks like diff residue;
// confirm which sequence is live before editing.
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp(
return rewriteCopyLikeOp(
copyOp,
copyOp.getHostTarget(),
copyOp.getDeviceSource(),
*dstOffset,
*srcOffset,
copyOp.getSize(),
/*allowLoopRewrite=*/false,
rewriter,
[&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(),
resultType,
dstOffsetValue,
srcOffsetValue,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
return success();
copyOp.getHostTargetOffset(),
copyOp.getDeviceSourceOffset(),
copyOp.getHostTarget(),
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
return pim::PimMemCopyDevToHostOp::create(rewriter,
loc,
cast<MemRefType>(target.getType()),
targetOffset,
sourceOffset,
target,
source,
rewriter.getI32IntegerAttr(size));
},
rewriter);
}
};
} // namespace
/// Normalization check for a PimMemCopyOp: forwards the op together with its
/// target/source operands and their offset operands to isNormalizedCopyLikeOp
/// and returns that verdict unchanged.
bool isNormalizedCopyOp(pim::PimMemCopyOp op) {
  auto target = op.getTarget();
  auto source = op.getSource();
  auto targetOffset = op.getTargetOffset();
  auto sourceOffset = op.getSourceOffset();
  return isNormalizedCopyLikeOp(op, target, source, targetOffset, sourceOffset);
}
/// Normalization check for a PimMemCopyHostToDevOp: the device side is passed
/// as the target and the host side as the source when delegating to
/// isNormalizedCopyLikeOp.
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op) {
  auto deviceTarget = op.getDeviceTarget();
  auto hostSource = op.getHostSource();
  auto deviceTargetOffset = op.getDeviceTargetOffset();
  auto hostSourceOffset = op.getHostSourceOffset();
  return isNormalizedCopyLikeOp(op, deviceTarget, hostSource, deviceTargetOffset, hostSourceOffset);
}
/// Normalization check for a PimMemCopyDevToHostOp: the host side is passed as
/// the target and the device side as the source when delegating to
/// isNormalizedCopyLikeOp.
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op) {
  auto hostTarget = op.getHostTarget();
  auto deviceSource = op.getDeviceSource();
  auto hostTargetOffset = op.getHostTargetOffset();
  auto deviceSourceOffset = op.getDeviceSourceOffset();
  return isNormalizedCopyLikeOp(op, hostTarget, deviceSource, hostTargetOffset, deviceSourceOffset);
}
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
patterns.getContext());
@@ -2,8 +2,14 @@
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir::pim {
/// Normalization predicates, one overload per PIM copy op kind.
/// NOTE(review): callers report "must use base memref operands plus explicit
/// byte offsets after bufferization" when one of these returns false, so
/// "normalized" presumably means exactly that - confirm against the
/// implementation of the shared check they forward to.
bool isNormalizedCopyOp(pim::PimMemCopyOp op);
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op);
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op);
/// Registers the PIM copy contiguity-normalization rewrite patterns into
/// `patterns`.
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir::pim
@@ -101,10 +101,10 @@ struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInt
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
memCopyOp,
targetOpt->getType(),
memCopyOp.getTargetOffset(),
memCopyOp.getSourceOffset(),
*targetOpt,
*sourceOpt,
memCopyOp.getTargetOffsetAttr(),
memCopyOp.getSourceOffsetAttr(),
memCopyOp.getSizeAttr());
return success();
}
@@ -1,19 +0,0 @@
#ifndef PIM_BUFFERIZATION
#define PIM_BUFFERIZATION
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
// Declarative rewrite: CopyOp(src, dst) becomes PimMemCopyOp(dst, src, ...).
// Both I32 offset attributes are the constant 0, the size attribute is computed
// from $src by pim::getMemRefSizeInBytesAttr, and the result type is taken
// from $dst. Note the operand order flips from (src, dst) to (dst, src).
def memrefCopyToPimMemCopyOp : Pat<
(CopyOp $src, $dst),
(PimMemCopyOp $dst, $src,
ConstantAttr<I32Attr, "0">,
ConstantAttr<I32Attr, "0">,
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION
@@ -6,7 +6,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Casting.h"
@@ -27,7 +27,34 @@ namespace onnx_mlir {
namespace {
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
/// Rewrites a memref.copy nested in a PIM core body into a PimMemCopyOp that
/// copies the whole source buffer: both offsets are the index constant 0 and
/// the size is the byte size of the source memref.
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only copies enclosed by a PimCoreOp or PimCoreBatchOp are candidates.
    const bool insideCoreBody =
        copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>();
    if (!insideCoreBody)
      return failure();

    auto source = copyOp.getSource();
    auto target = copyOp.getTarget();

    // Both operands must be ranked memrefs with fully static shapes ...
    auto sourceMemRefType = dyn_cast<MemRefType>(source.getType());
    auto targetMemRefType = dyn_cast<MemRefType>(target.getType());
    if (!sourceMemRefType || !targetMemRefType)
      return failure();
    if (!sourceMemRefType.hasStaticShape() || !targetMemRefType.hasStaticShape())
      return failure();

    // ... and agree on the element type.
    if (sourceMemRefType.getElementType() != targetMemRefType.getElementType())
      return failure();

    Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
    IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, source);
    pim::PimMemCopyOp::create(
        rewriter, copyOp.getLoc(), target.getType(), zeroOffset, zeroOffset, target, source, sizeAttr);

    // memref.copy has no results, so the old op is erased rather than replaced.
    rewriter.eraseOp(copyOp);
    return success();
  }
};
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
@@ -44,6 +71,14 @@ private:
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
};
/// Makes a single match-and-rewrite attempt on `op` through `applicator`, with
/// the rewriter's insertion point set to `op` first. Fails up front, leaving
/// the rewriter untouched, when `op` is null or is not contained in a block.
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
  const bool isAttached = op != nullptr && op->getBlock() != nullptr;
  if (!isAttached)
    return failure();

  rewriter.setInsertionPoint(op);
  LogicalResult outcome = applicator.matchAndRewrite(op, rewriter);
  return outcome;
}
} // namespace
void PimBufferizationPass::runOnOperation() {
@@ -63,35 +98,60 @@ void PimBufferizationPass::runOnOperation() {
}
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect>();
RewritePatternSet memrefCopyPatterns(ctx);
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
memrefCopyApplicator.applyDefaultCostModel();
PatternRewriter rewriter(ctx);
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
// Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies.
// Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering.
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
bool hasFailed = false;
moduleOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
return WalkResult::advance();
if (failed(applyPartialConversion(op, target, frozenPatterns)))
hasFailed = true;
return WalkResult::skip();
SmallVector<memref::CopyOp> copyWorklist;
moduleOp.walk([&](memref::CopyOp copyOp) {
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
copyWorklist.push_back(copyOp);
});
bool hasFailed = false;
for (memref::CopyOp copyOp : copyWorklist) {
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
hasFailed = true;
}
}
if (hasFailed) {
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
signalPassFailure();
return;
}
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
// Later PIM codegen passes may verify this invariant but must not repair it.
RewritePatternSet contiguityPatterns(ctx);
populatePimContiguityNormalizationPatterns(contiguityPatterns);
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
FrozenRewritePatternSet frozenContiguityPatterns(std::move(contiguityPatterns));
PatternApplicator contiguityApplicator(frozenContiguityPatterns);
contiguityApplicator.applyDefaultCostModel();
SmallVector<Operation*> contiguityWorklist;
moduleOp.walk([&](Operation* op) {
if (isa<pim::PimMemCopyOp, pim::PimMemCopyHostToDevOp, pim::PimMemCopyDevToHostOp>(op))
contiguityWorklist.push_back(op);
});
hasFailed = false;
for (Operation* op : contiguityWorklist) {
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
continue;
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
continue;
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
continue;
if (failed(applyPatternsOnce(op, contiguityApplicator, rewriter))) {
op->emitOpError("failed to normalize PIM copy contiguity");
hasFailed = true;
}
}
if (hasFailed) {
moduleOp.emitError("failed to normalize PIM copy contiguity during bufferization");
signalPassFailure();
return;
}
@@ -138,16 +198,28 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
};
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
if (!pim::isNormalizedCopyOp(memCopyOp)) {
memCopyOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
hasFailure = true;
}
verifyOperand(memCopyOp.getTarget(), 0);
verifyOperand(memCopyOp.getSource(), 1);
return;
}
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
if (!pim::isNormalizedCopyOp(loadOp)) {
loadOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
hasFailure = true;
}
verifyOperand(loadOp.getDeviceTarget(), 2);
verifyOperand(loadOp.getHostSource(), 3);
return;
}
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
if (!pim::isNormalizedCopyOp(storeOp)) {
storeOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
hasFailure = true;
}
verifyOperand(storeOp.getHostTarget(), 2);
verifyOperand(storeOp.getDeviceSource(), 3);
return;
@@ -121,13 +121,14 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(mapOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
auto sizeInBytes = getShapedTypeSizeInBytes(initType);
Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0);
pim::PimMemCopyOp::create(rewriter,
mapOp.getLoc(),
initType,
zeroOffset,
zeroOffset,
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes));
rewriter.eraseOp(mapOp);
return success();
@@ -487,7 +488,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
if (!allocType || !allocType.hasStaticShape())
return failure();
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
auto targetOffset = resolveIndexValue(copyOp.getTargetOffset());
auto sourceOffset = resolveIndexValue(copyOp.getSourceOffset());
if (failed(targetOffset) || failed(sourceOffset) || *targetOffset != 0 || *sourceOffset != 0)
return failure();
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
@@ -13,6 +13,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -333,6 +334,12 @@ private:
}
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
if (!pim::isNormalizedCopyOp(storeOp)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
});
hasFailure = true;
}
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
diagnostics.report(&op, [](Operation* illegalOp) {
@@ -343,6 +350,12 @@ private:
}
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
if (!pim::isNormalizedCopyOp(loadOp)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
});
hasFailure = true;
}
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
diagnostics.report(&op, [](Operation* illegalOp) {
@@ -351,6 +364,22 @@ private:
hasFailure = true;
}
}
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op)) {
if (!pim::isNormalizedCopyOp(copyOp)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
});
hasFailure = true;
}
if (failed(resolveIndexValue(copyOp.getTargetOffset(), knowledge))
|| failed(resolveIndexValue(copyOp.getSourceOffset(), knowledge))) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
});
hasFailure = true;
}
}
return success(!hasFailure);
});
}