From b678e55d3c2c87c84bd9265c4884bb867f05c4ac Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Sun, 31 May 2026 18:47:59 +0200 Subject: [PATCH] compact memory contiguity with for loops --- src/PIM/Common/IR/AddressAnalysis.cpp | 28 +- src/PIM/Common/IR/ConstantUtils.cpp | 10 +- src/PIM/Compiler/PimCodeGen.cpp | 10 +- .../SpatialToPim/SpatialToPimPass.cpp | 4 +- src/PIM/Dialect/Pim/Pim.td | 8 +- .../Bufferization/BufferizationUtils.cpp | 5 +- .../Transforms/Bufferization/CMakeLists.txt | 7 - .../Bufferization/ContiguityPatterns.cpp | 626 +++++++++++------- .../Bufferization/ContiguityPatterns.hpp | 6 + .../OpBufferizationInterfaces.cpp | 4 +- .../Bufferization/PimBufferization.td | 19 - .../Bufferization/PimBufferizationPass.cpp | 116 +++- .../HostConstantFolding/Patterns/Constant.cpp | 9 +- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 29 + 14 files changed, 550 insertions(+), 331 deletions(-) delete mode 100644 src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index f477cb5..a46e361 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -69,6 +69,16 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge); llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge); +static llvm::FailureOr> getStaticMemRefStrides(mlir::MemRefType type) { + llvm::SmallVector strides; + int64_t offset = 0; + if (failed(type.getStridesAndOffset(strides, offset))) + return mlir::failure(); + if (llvm::any_of(strides, mlir::ShapedType::isDynamic)) + return mlir::failure(); + return strides; +} + static llvm::FailureOr resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp, const StaticValueKnowledge* knowledge) { auto getGlobalOp = loadOp.getMemRef().getDefiningOp(); @@ -539,8 +549,10 @@ llvm::FailureOr resolveContiguousAddressImpl(mlir::Va if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides)) return mlir::failure(); - auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); - byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); + auto sourceStrides = getStaticMemRefStrides(sourceType); + if (failed(sourceStrides)) + return mlir::failure(); + byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); value = resolveAlias(subviewOp.getSource(), knowledge); continue; } @@ -651,12 +663,16 @@ llvm::FailureOr compileContiguousAddressExprImpl(mlir::Valu if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides)) return mlir::failure(); - auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); + auto sourceStrides = getStaticMemRefStrides(sourceType); + if (failed(sourceStrides)) + return mlir::failure(); constantByteOffset += - linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); + linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); } else { - llvm::SmallVector sourceStrides = computeRowMajorStrides(sourceType.getShape()); + auto sourceStrides = getStaticMemRefStrides(sourceType); + if (failed(sourceStrides)) + return mlir::failure(); CompiledIndexExpr offsetExpr; { CompiledIndexExprNode expr; @@ -665,7 +681,7 @@ llvm::FailureOr compileContiguousAddressExprImpl(mlir::Valu offsetExpr = makeCompiledIndexExpr(std::move(expr)); } - for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) { + for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) { CompiledIndexExpr operandExpr; if (auto attr = mlir::dyn_cast(mixedOffset)) { CompiledIndexExprNode expr; diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp index fde70f3..fe21409 100644 --- a/src/PIM/Common/IR/ConstantUtils.cpp +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -5,8 +5,6 @@ #include "mlir/IR/Dialect.h" #include "ConstantUtils.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -15,16 +13,12 @@ namespace onnx_mlir { Block* getConstantInsertionBlock(Operation* anchorOp) { assert(anchorOp && "expected a valid anchor operation"); - for (Operation* current = anchorOp; current; current = current->getParentOp()) - if (isa(current)) - return current->getBlock(); - if (auto funcOp = dyn_cast(anchorOp)) return &funcOp.getBody().front(); + if (auto funcOp = anchorOp->getParentOfType()) + return &funcOp.getBody().front(); if (auto moduleOp = dyn_cast(anchorOp)) return moduleOp.getBody(); - if (auto funcOp = anchorOp->getParentOfType()) - return &funcOp.getBody().front(); if (auto moduleOp = anchorOp->getParentOfType()) return moduleOp.getBody(); return anchorOp->getBlock(); diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 60f2e8f..89d5c01 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -235,6 +235,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, if (failed(compiledExpr)) { errs() << "Failed to compile contiguous address for value: "; value.print(errs()); + errs() << " : " << value.getType(); errs() << "\n"; llvm_unreachable("Failed to compile contiguous address"); } @@ -245,6 +246,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, if (failed(resolvedAddress)) { errs() << "Failed to evaluate contiguous address for value: "; value.print(errs()); + errs() << " : " << value.getType(); errs() << "\n"; if (auto* definingOp = value.getDefiningOp()) { errs() << "Defining op:\n"; @@ -493,11 +495,15 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const Static } void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const { + auto targetOffset = indexOf(lmvOp.getTargetOffset(), knowledge); + auto sourceOffset = indexOf(lmvOp.getSourceOffset(), knowledge); + assert(succeeded(targetOffset) && succeeded(sourceOffset) + && "pim.memcp offsets must be statically resolvable during codegen"); emitMemCopyOp("lmv", addressOf(lmvOp.getTarget(), knowledge), - lmvOp.getTargetOffset(), + *targetOffset, addressOf(lmvOp.getSource(), knowledge), - lmvOp.getSourceOffset(), + *sourceOffset, lmvOp.getSize(), "len"); } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 4eb11aa..b3a85aa 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -101,9 +101,9 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, auto paddedType = RankedTensorType::get( {shape[0], static_cast(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder); - auto zeroAttr = rewriter.getI32IntegerAttr(0); + Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed.getDefiningOp(), 0); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(vectorType))); - return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput(); + return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, zeroed, vector, sizeAttr).getOutput(); } void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 76a61b8..7db445a 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -176,10 +176,10 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> { let summary = "Copy a memory region within the same memory space"; let arguments = (ins + Index:$targetOffset, + Index:$sourceOffset, PimTensor:$target, PimTensor:$source, - I32Attr:$targetOffset, - I32Attr:$sourceOffset, I32Attr:$size ); @@ -194,7 +194,9 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> { }]; let assemblyFormat = [{ - `(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output) + `[` $targetOffset `,` $sourceOffset `]` + `(` $target `,` $source `)` attr-dict + `:` type($target) `,` type($source) `->` type($output) }]; } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index baef6c2..cb14469 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -19,14 +19,15 @@ Value materializeContiguousInputMemRef(Value memrefValue, Location loc, Rewriter auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); auto sizeInBytes = getShapedTypeSizeInBytes(shapedType); + Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0); return PimMemCopyOp::create(rewriter, loc, contiguousType, + zeroOffset, + zeroOffset, contiguousBuffer, memrefValue, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(sizeInBytes)) .getOutput(); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt index 62c90c8..73760cc 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt @@ -1,7 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS PimBufferization.td) -mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") -add_public_tablegen_target(PimBufferizationIncGen) - add_pim_library(OMPimBufferization PimBufferizationPass.cpp BufferizationUtils.hpp @@ -15,9 +11,6 @@ add_pim_library(OMPimBufferization EXCLUDE_FROM_OM_LIBS - DEPENDS - PimBufferizationIncGen - LINK_LIBS PUBLIC OMPimCommon PimOps diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp index e2937a1..9e00bde 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp @@ -3,220 +3,360 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "ContiguityPatterns.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; namespace onnx_mlir::pim { namespace { -static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) { - if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; })) - return false; +struct ByteOffsetTerm { + Value value; + int64_t scale = 0; +}; - return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides); +struct ByteOffsetExpr { + int64_t constant = 0; + SmallVector terms; +}; + +struct CopyEndpointPlan { + Value base; + MemRefType originalType; + MemRefType baseType; + ByteOffsetExpr offset; +}; + +struct CopyLoopPlan { + SmallVector outerShape; + int64_t chunkBytes = 0; + ByteOffsetExpr targetBaseOffset; + ByteOffsetExpr sourceBaseOffset; + SmallVector targetOuterByteStrides; + SmallVector sourceOuterByteStrides; +}; + +struct CopyRewritePlan { + enum class Kind { + Direct, + Loop + } kind = Kind::Direct; + CopyEndpointPlan target; + CopyEndpointPlan source; + int64_t directBytes = 0; + CopyLoopPlan loop; +}; + +static bool isViewLike(Value value) { + Operation* defOp = value.getDefiningOp(); + return defOp + && isa(defOp); } -static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) { - if (extraOffset == 0) - return baseOffset; +template +static bool isNormalizedCopyLikeOp(CopyOp copyOp, Value target, Value source, Value targetOffset, Value sourceOffset) { + auto targetType = dyn_cast(target.getType()); + auto sourceType = dyn_cast(source.getType()); + return targetType && sourceType && !isViewLike(target) && !isViewLike(source) && targetOffset.getType().isIndex() + && sourceOffset.getType().isIndex() && copyOp.getSize() > 0; +} - if (auto attr = dyn_cast(baseOffset)) { - auto integerAttr = dyn_cast(attr); - assert(integerAttr && "expected integer offset attribute"); - return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset); +static void appendTerm(ByteOffsetExpr& expr, Value value, int64_t scale) { + if (scale != 0) + expr.terms.push_back(ByteOffsetTerm {value, scale}); +} + +static FailureOr> getStaticMemRefStrides(MemRefType type) { + SmallVector strides; + int64_t offset = 0; + if (failed(type.getStridesAndOffset(strides, offset))) + return failure(); + if (llvm::any_of(strides, ShapedType::isDynamic)) + return failure(); + return strides; +} + +static FailureOr getShapedByteSize(MemRefType type) { + if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType())) + return failure(); + return static_cast(getShapedTypeSizeInBytes(type)); +} + +static FailureOr> +inferLogicalCopyShape(MemRefType targetType, MemRefType sourceType, int64_t size) { + if (!targetType.hasStaticShape() || !sourceType.hasStaticShape()) + return failure(); + if (targetType.getElementType() != sourceType.getElementType() || targetType.getRank() != sourceType.getRank()) + return failure(); + + auto targetBytes = getShapedByteSize(targetType); + auto sourceBytes = getShapedByteSize(sourceType); + if (failed(targetBytes) || failed(sourceBytes)) + return failure(); + + bool targetMatches = *targetBytes == size; + bool sourceMatches = *sourceBytes == size; + if (targetMatches && sourceMatches && targetType.getShape() != sourceType.getShape()) + return failure(); + if (targetMatches) + return SmallVector(targetType.getShape().begin(), targetType.getShape().end()); + if (sourceMatches) + return SmallVector(sourceType.getShape().begin(), sourceType.getShape().end()); + return failure(); +} + +static FailureOr getContiguousSuffixRank(MemRefType type, ArrayRef copyShape) { + if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType()) + || type.getRank() != static_cast(copyShape.size())) + return failure(); + + auto strides = getStaticMemRefStrides(type); + if (failed(strides)) + return failure(); + + int64_t expectedStride = 1; + int64_t contiguousSuffixRank = 0; + for (int64_t dim = type.getRank() - 1; dim >= 0; --dim) { + if ((*strides)[dim] != expectedStride) + break; + ++contiguousSuffixRank; + expectedStride *= copyShape[dim]; + } + return contiguousSuffixRank; +} + +static FailureOr analyzeCopyEndpoint(Value value, Value initialByteOffset, MemRefType logicalType) { + if (!logicalType.hasStaticShape() || !hasByteSizedElementType(logicalType.getElementType())) + return failure(); + + CopyEndpointPlan endpoint; + endpoint.base = value; + endpoint.originalType = logicalType; + appendTerm(endpoint.offset, initialByteOffset, 1); + + while (true) { + if (auto castOp = endpoint.base.getDefiningOp()) { + endpoint.base = castOp.getSource(); + continue; + } + if (auto collapseOp = endpoint.base.getDefiningOp()) { + endpoint.base = collapseOp.getSrc(); + continue; + } + if (auto expandOp = endpoint.base.getDefiningOp()) { + endpoint.base = expandOp.getSrc(); + continue; + } + if (auto reinterpretOp = endpoint.base.getDefiningOp()) { + endpoint.base = reinterpretOp.getSource(); + continue; + } + + auto subviewOp = endpoint.base.getDefiningOp(); + if (!subviewOp) + break; + + auto sourceType = dyn_cast(subviewOp.getSource().getType()); + if (!sourceType || !sourceType.hasStaticShape() || !hasByteSizedElementType(sourceType.getElementType())) + return failure(); + + auto sourceStrides = getStaticMemRefStrides(sourceType); + if (failed(sourceStrides)) + return failure(); + + int64_t elementByteWidth = static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); + for (auto [offset, stride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) { + int64_t byteScale = stride * elementByteWidth; + if (auto attr = dyn_cast(offset)) { + endpoint.offset.constant += cast(attr).getInt() * byteScale; + continue; + } + appendTerm(endpoint.offset, cast(offset), byteScale); + } + + endpoint.base = subviewOp.getSource(); } - auto value = cast(baseOffset); - auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset); - return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult(); + endpoint.baseType = dyn_cast(endpoint.base.getType()); + if (!endpoint.baseType) + return failure(); + return endpoint; } -static Value buildSubviewChunk(const StaticSubviewInfo& info, - ArrayRef outerIndices, - Location loc, - PatternRewriter& rewriter) { - SmallVector chunkOffsets; - SmallVector chunkSizes; - SmallVector chunkStrides; - chunkOffsets.reserve(info.offsets.size()); - chunkSizes.reserve(info.sizes.size()); - chunkStrides.reserve(info.strides.size()); +static FailureOr +analyzeCopyRewrite(Value target, Value source, Value targetOffset, Value sourceOffset, int64_t size) { + auto targetType = dyn_cast(target.getType()); + auto sourceType = dyn_cast(source.getType()); + if (!targetType || !sourceType || size <= 0) + return failure(); - for (size_t dim = 0; dim < info.sizes.size(); ++dim) { - int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0; - chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter)); - chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back())); - chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim])); + auto logicalCopyShape = inferLogicalCopyShape(targetType, sourceType, size); + if (failed(logicalCopyShape)) + return failure(); + + auto targetPlan = analyzeCopyEndpoint(target, targetOffset, targetType); + auto sourcePlan = analyzeCopyEndpoint(source, sourceOffset, sourceType); + if (failed(targetPlan) || failed(sourcePlan)) + return failure(); + + auto targetSuffixRank = getContiguousSuffixRank(targetType, *logicalCopyShape); + auto sourceSuffixRank = getContiguousSuffixRank(sourceType, *logicalCopyShape); + if (failed(targetSuffixRank) || failed(sourceSuffixRank)) + return failure(); + + CopyRewritePlan plan; + plan.target = *targetPlan; + plan.source = *sourcePlan; + + int64_t contiguousSuffixRank = std::min(*targetSuffixRank, *sourceSuffixRank); + if (contiguousSuffixRank == static_cast(logicalCopyShape->size())) { + plan.kind = CopyRewritePlan::Kind::Direct; + plan.directBytes = size; + return plan; } - return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); + auto targetStrides = getStaticMemRefStrides(targetType); + auto sourceStrides = getStaticMemRefStrides(sourceType); + if (failed(targetStrides) || failed(sourceStrides)) + return failure(); + + int64_t elementByteWidth = static_cast(getElementTypeSizeInBytes(targetType.getElementType())); + plan.kind = CopyRewritePlan::Kind::Loop; + plan.loop.targetBaseOffset = plan.target.offset; + plan.loop.sourceBaseOffset = plan.source.offset; + plan.loop.outerShape.assign(logicalCopyShape->begin(), logicalCopyShape->end() - contiguousSuffixRank); + SmallVector chunkShape(logicalCopyShape->end() - contiguousSuffixRank, logicalCopyShape->end()); + plan.loop.chunkBytes = getNumElements(chunkShape) * elementByteWidth; + for (int64_t stride : ArrayRef(*targetStrides).take_front(plan.loop.outerShape.size())) + plan.loop.targetOuterByteStrides.push_back(stride * elementByteWidth); + for (int64_t stride : ArrayRef(*sourceStrides).take_front(plan.loop.outerShape.size())) + plan.loop.sourceOuterByteStrides.push_back(stride * elementByteWidth); + if (plan.loop.chunkBytes <= 0) + return failure(); + return plan; } -static SmallVector delinearizeIndexValue(Value linearIndex, - ArrayRef shape, - ArrayRef strides, - PatternRewriter& rewriter) { +static Value createIndexConstant(PatternRewriter& rewriter, Operation* anchorOp, int64_t value) { + return getOrCreateIndexConstant(rewriter, anchorOp, value); +} + +static Value addIndexValues(PatternRewriter& rewriter, Location loc, Value lhs, Value rhs) { + if (auto constant = getConstantIntValue(lhs); constant && *constant == 0) + return rhs; + if (auto constant = getConstantIntValue(rhs); constant && *constant == 0) + return lhs; + return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult(); +} + +static Value mulIndexValue(PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value value, int64_t scale) { + if (scale == 0) + return createIndexConstant(rewriter, anchorOp, 0); + if (scale == 1) + return value; + Value scaleValue = createIndexConstant(rewriter, anchorOp, scale); + return arith::MulIOp::create(rewriter, loc, value, scaleValue).getResult(); +} + +static Value +materializeByteOffset(PatternRewriter& rewriter, Location loc, Operation* anchorOp, const ByteOffsetExpr& expr) { + Value result = createIndexConstant(rewriter, anchorOp, expr.constant); + for (const ByteOffsetTerm& term : expr.terms) + result = addIndexValues(rewriter, loc, result, mulIndexValue(rewriter, loc, anchorOp, term.value, term.scale)); + return result; +} + +static SmallVector materializeDelinearizedIndices( + PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value linearIndex, ArrayRef shape) { SmallVector indices; - indices.reserve(shape.size()); + if (shape.empty()) + return indices; + auto rowMajorStrides = computeRowMajorStrides(shape); Value remaining = linearIndex; - for (auto [_dim, stride] : llvm::enumerate(strides)) { - auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride); - Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); + for (auto [dim, stride] : llvm::enumerate(rowMajorStrides)) { + if (dim + 1 == rowMajorStrides.size()) { + indices.push_back(remaining); + break; + } + Value strideValue = createIndexConstant(rewriter, anchorOp, stride); + Value index = arith::DivUIOp::create(rewriter, loc, remaining, strideValue); indices.push_back(index); - remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); + remaining = arith::RemUIOp::create(rewriter, loc, remaining, strideValue); } - return indices; } -static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) { - if (auto attr = dyn_cast(baseOffset)) { - auto integerAttr = cast(attr); - if (integerAttr.getInt() == 0) - return extraOffset; - - auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt()); - return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult(); - } - - auto value = cast(baseOffset); - return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult(); -} - -static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info, - ArrayRef outerIndices, - Location loc, - PatternRewriter& rewriter) { - SmallVector chunkOffsets; - SmallVector chunkSizes; - SmallVector chunkStrides; - chunkOffsets.reserve(info.offsets.size()); - chunkSizes.reserve(info.sizes.size()); - chunkStrides.reserve(info.strides.size()); - - for (size_t dim = 0; dim < info.sizes.size(); ++dim) { - if (dim + 1 < info.sizes.size()) { - assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides"); - chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter)); - chunkSizes.push_back(rewriter.getIndexAttr(1)); - } - else { - chunkOffsets.push_back(info.offsets[dim]); - chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back())); - } - chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim])); - } - - return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); -} - -static Value buildContiguousChunk( - Value source, ArrayRef copyShape, ArrayRef outerIndices, Location loc, PatternRewriter& rewriter) { - SmallVector chunkOffsets; - SmallVector chunkSizes; - SmallVector chunkStrides; - chunkOffsets.reserve(copyShape.size()); - chunkSizes.reserve(copyShape.size()); - chunkStrides.reserve(copyShape.size()); - - for (size_t dim = 0; dim < copyShape.size(); ++dim) { - chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0)); - chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back())); - chunkStrides.push_back(rewriter.getIndexAttr(1)); - } - - return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides); +static Value materializeOuterByteOffset(PatternRewriter& rewriter, + Location loc, + Operation* anchorOp, + const ByteOffsetExpr& baseOffset, + ArrayRef outerIndices, + ArrayRef outerByteStrides) { + Value result = materializeByteOffset(rewriter, loc, anchorOp, baseOffset); + for (auto [index, stride] : llvm::zip_equal(outerIndices, outerByteStrides)) + result = addIndexValues(rewriter, loc, result, mulIndexValue(rewriter, loc, anchorOp, index, stride)); + return result; } template -static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, - Value dst, - Value src, - int64_t dstOffset, - int64_t srcOffset, - int64_t size, - bool allowLoopRewrite, - PatternRewriter& rewriter, - CreateCopyOp createCopyOp) { - auto srcSubview = getStaticSubviewInfo(src); - auto dstSubview = getStaticSubviewInfo(dst); - const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview); - const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview); - if (!splitSrc && !splitDst) +static LogicalResult rewriteCopyLikeOp(CopyOp copyOp, + Value target, + Value source, + Value targetOffset, + Value sourceOffset, + Value replacementValue, + CreateCopyOp createCopyOp, + PatternRewriter& rewriter) { + if (isNormalizedCopyLikeOp(copyOp, target, source, targetOffset, sourceOffset)) return failure(); - auto sourceType = dyn_cast(src.getType()); - auto dstType = dyn_cast(dst.getType()); - if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape()) - return failure(); - if (sourceType.getElementType() != dstType.getElementType()) + auto plan = analyzeCopyRewrite(target, source, targetOffset, sourceOffset, copyOp.getSize()); + if (failed(plan)) return failure(); - if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))) - return failure(); - if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))) - return failure(); - - ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); - if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) - return failure(); - - if (!hasByteSizedElementType(sourceType.getElementType())) - return failure(); - const int64_t elementByteWidth = static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); - - const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; - if (size != totalBytes) - return failure(); - - const int64_t sliceBytes = copyShape.back() * elementByteWidth; - if (sliceBytes <= 0) - return failure(); - - SmallVector outerShape(copyShape.begin(), copyShape.end() - 1); - auto outerStrides = computeRowMajorStrides(outerShape); - const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); - const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape); - const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape); - - if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0 - && sourceType.getRank() == static_cast(copyShape.size()) - && dstType.getRank() == static_cast(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape) - && (splitDst || dstShapeMatchesCopyShape)) { - auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0); - auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices); - auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1); - - auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {}); - rewriter.setInsertionPointToStart(loop.getBody()); - - SmallVector outerIndices = - outerShape.empty() ? SmallVector {} - : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter); - Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) - : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter); - Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) - : buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter); - createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes); + Location loc = copyOp.getLoc(); + Operation* anchorOp = copyOp.getOperation(); + if (plan->kind == CopyRewritePlan::Kind::Direct) { + Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset); + Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset); + auto newCopyOp = createCopyOp(loc, + plan->target.base, + plan->source.base, + newTargetOffset, + newSourceOffset, + static_cast(plan->directBytes)); + assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy"); + rewriter.replaceOp(copyOp, replacementValue); return success(); } - rewriter.setInsertionPoint(copyOp); - for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { - SmallVector outerIndices = - outerShape.empty() ? SmallVector {} : delinearizeIndex(linearIndex, outerShape, outerStrides); - Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst; - Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src; - const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes; - const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes; - createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes); - } + Value c0 = createIndexConstant(rewriter, anchorOp, 0); + Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape)); + Value cStep = createIndexConstant(rewriter, anchorOp, 1); + auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {}); + rewriter.setInsertionPointToStart(loop.getBody()); + SmallVector outerIndices = + materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape); + Value loopTargetOffset = materializeOuterByteOffset( + rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides); + Value loopSourceOffset = materializeOuterByteOffset( + rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides); + auto newCopyOp = createCopyOp(loc, + plan->target.base, + plan->source.base, + loopTargetOffset, + loopSourceOffset, + static_cast(plan->loop.chunkBytes)); + assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy"); + rewriter.setInsertionPointAfter(loop); + rewriter.replaceOp(copyOp, replacementValue); return success(); } @@ -224,34 +364,24 @@ struct NormalizeCoreSubviewCopyPattern final : OpRewritePatterngetParentOfType() && !copyOp->getParentOfType()) - return failure(); - - auto status = rewriteSubviewCopyLikeOp( + return rewriteCopyLikeOp( copyOp, copyOp.getTarget(), copyOp.getSource(), copyOp.getTargetOffset(), copyOp.getSourceOffset(), - copyOp.getSize(), - /*allowLoopRewrite=*/true, - rewriter, - [&]( - MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { - pim::PimMemCopyOp::create(rewriter, - copyOp.getLoc(), - resultType, - dst, - src, - rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - }); - if (failed(status)) - return failure(); - - rewriter.replaceOp(copyOp, copyOp.getTarget()); - return success(); + copyOp.getTarget(), + [&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) { + return pim::PimMemCopyOp::create(rewriter, + loc, + cast(target.getType()), + targetOffset, + sourceOffset, + target, + source, + rewriter.getI32IntegerAttr(size)); + }, + rewriter); } }; @@ -259,38 +389,24 @@ struct NormalizeHostSubviewLoadPattern final : OpRewritePattern(sliceBytes))); - }); - if (failed(status)) - return failure(); - - rewriter.replaceOp(copyOp, copyOp.getDeviceTarget()); - return success(); + copyOp.getDeviceTargetOffset(), + copyOp.getHostSourceOffset(), + copyOp.getDeviceTarget(), + [&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) { + return pim::PimMemCopyHostToDevOp::create(rewriter, + loc, + cast(target.getType()), + targetOffset, + sourceOffset, + target, + source, + rewriter.getI32IntegerAttr(size)); + }, + rewriter); } }; @@ -298,43 +414,43 @@ struct NormalizeHostSubviewStorePattern final : OpRewritePattern(sliceBytes))); - }); - if (failed(status)) - return failure(); - - rewriter.replaceOp(copyOp, copyOp.getHostTarget()); - return success(); + copyOp.getHostTargetOffset(), + copyOp.getDeviceSourceOffset(), + copyOp.getHostTarget(), + [&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) { + return pim::PimMemCopyDevToHostOp::create(rewriter, + loc, + cast(target.getType()), + targetOffset, + sourceOffset, + target, + source, + rewriter.getI32IntegerAttr(size)); + }, + rewriter); } }; } // namespace +bool isNormalizedCopyOp(pim::PimMemCopyOp op) { + return isNormalizedCopyLikeOp(op, op.getTarget(), op.getSource(), op.getTargetOffset(), op.getSourceOffset()); +} + +bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op) { + return isNormalizedCopyLikeOp( + op, op.getDeviceTarget(), op.getHostSource(), op.getDeviceTargetOffset(), op.getHostSourceOffset()); +} + +bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op) { + return isNormalizedCopyLikeOp( + op, op.getHostTarget(), op.getDeviceSource(), op.getHostTargetOffset(), op.getDeviceSourceOffset()); +} + void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) { patterns.add( patterns.getContext()); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp index c47c56c..56f1b01 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp @@ -2,8 +2,14 @@ #include "mlir/IR/PatternMatch.h" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + namespace onnx_mlir::pim { +bool isNormalizedCopyOp(pim::PimMemCopyOp op); +bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op); +bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op); + void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns); } // namespace onnx_mlir::pim diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index a489b08..0777fb2 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -101,10 +101,10 @@ struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel(rewriter, memCopyOp, targetOpt->getType(), + memCopyOp.getTargetOffset(), + memCopyOp.getSourceOffset(), *targetOpt, *sourceOpt, - memCopyOp.getTargetOffsetAttr(), - memCopyOp.getSourceOffsetAttr(), memCopyOp.getSizeAttr()); return success(); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td deleted file mode 100644 index bc920e3..0000000 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef PIM_BUFFERIZATION -#define PIM_BUFFERIZATION - -#ifndef OP_BASE -include "mlir/IR/PatternBase.td" -include "mlir/Dialect/MemRef/IR/MemRefOps.td" -include "src/Accelerators/PIM/Dialect/Pim/Pim.td" -#endif // OP_BASE - -def memrefCopyToPimMemCopyOp : Pat< - (CopyOp $src, $dst), - (PimMemCopyOp $dst, $src, - ConstantAttr, - ConstantAttr, - (NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src), - (returnType $dst)) ->; - -#endif // PIM_BUFFERIZATION diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 4d3379a..aa53879 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -6,7 +6,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Rewrite/PatternApplicator.h" #include "llvm/Support/Casting.h" @@ -27,7 +27,34 @@ namespace onnx_mlir { namespace { -#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc" +struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override { + if (!copyOp->getParentOfType() && !copyOp->getParentOfType()) + return failure(); + + auto sourceType = dyn_cast(copyOp.getSource().getType()); + auto targetType = dyn_cast(copyOp.getTarget().getType()); + if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape()) + return failure(); + if (sourceType.getElementType() != targetType.getElementType()) + return failure(); + + Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0); + IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource()); + pim::PimMemCopyOp::create(rewriter, + copyOp.getLoc(), + copyOp.getTarget().getType(), + zeroOffset, + zeroOffset, + copyOp.getTarget(), + copyOp.getSource(), + sizeAttr); + rewriter.eraseOp(copyOp); + return success(); + } +}; struct PimBufferizationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) @@ -44,6 +71,14 @@ private: LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const; }; +static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) { + if (!op || !op->getBlock()) + return failure(); + + rewriter.setInsertionPoint(op); + return applicator.matchAndRewrite(op, rewriter); +} + } // namespace void PimBufferizationPass::runOnOperation() { @@ -63,35 +98,60 @@ void PimBufferizationPass::runOnOperation() { } MLIRContext* ctx = moduleOp.getContext(); - ConversionTarget target(*ctx); - target.addLegalDialect(); + RewritePatternSet memrefCopyPatterns(ctx); + memrefCopyPatterns.add(ctx); + FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns)); + PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns); + memrefCopyApplicator.applyDefaultCostModel(); + PatternRewriter rewriter(ctx); - RewritePatternSet patterns(ctx); - populateWithGenerated(patterns); - - // Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies. - // Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering. - FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - bool hasFailed = false; - moduleOp.walk([&](Operation* op) { - if (!isa(op)) - return WalkResult::advance(); - if (failed(applyPartialConversion(op, target, frozenPatterns))) - hasFailed = true; - return WalkResult::skip(); + SmallVector copyWorklist; + moduleOp.walk([&](memref::CopyOp copyOp) { + if (copyOp->getParentOfType() || copyOp->getParentOfType()) + copyWorklist.push_back(copyOp); }); + + bool hasFailed = false; + for (memref::CopyOp copyOp : copyWorklist) { + if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) { + copyOp.emitOpError("failed to lower memref.copy inside PIM core body"); + hasFailed = true; + } + } if (hasFailed) { - moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization"); signalPassFailure(); return; } - // After this pass, executable PIM ops must only use contiguous/addressable memrefs. - // Later PIM codegen passes may verify this invariant but must not repair it. RewritePatternSet contiguityPatterns(ctx); populatePimContiguityNormalizationPatterns(contiguityPatterns); - if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) { - moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization"); + FrozenRewritePatternSet frozenContiguityPatterns(std::move(contiguityPatterns)); + PatternApplicator contiguityApplicator(frozenContiguityPatterns); + contiguityApplicator.applyDefaultCostModel(); + + SmallVector contiguityWorklist; + moduleOp.walk([&](Operation* op) { + if (isa(op)) + contiguityWorklist.push_back(op); + }); + + hasFailed = false; + for (Operation* op : contiguityWorklist) { + if (auto copyOp = dyn_cast(op); copyOp && pim::isNormalizedCopyOp(copyOp)) + continue; + if (auto copyOp = dyn_cast(op); copyOp && pim::isNormalizedCopyOp(copyOp)) + continue; + if (auto copyOp = dyn_cast(op); copyOp && pim::isNormalizedCopyOp(copyOp)) + continue; + + if (failed(applyPatternsOnce(op, contiguityApplicator, rewriter))) { + op->emitOpError("failed to normalize PIM copy contiguity"); + hasFailed = true; + } + } + + if (hasFailed) { + moduleOp.emitError("failed to normalize PIM copy contiguity during bufferization"); signalPassFailure(); return; } @@ -138,16 +198,28 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod }; if (auto memCopyOp = dyn_cast(op)) { + if (!pim::isNormalizedCopyOp(memCopyOp)) { + memCopyOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization"); + hasFailure = true; + } verifyOperand(memCopyOp.getTarget(), 0); verifyOperand(memCopyOp.getSource(), 1); return; } if (auto loadOp = dyn_cast(op)) { + if (!pim::isNormalizedCopyOp(loadOp)) { + loadOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization"); + hasFailure = true; + } verifyOperand(loadOp.getDeviceTarget(), 2); verifyOperand(loadOp.getHostSource(), 3); return; } if (auto storeOp = dyn_cast(op)) { + if (!pim::isNormalizedCopyOp(storeOp)) { + storeOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization"); + hasFailure = true; + } verifyOperand(storeOp.getHostTarget(), 2); verifyOperand(storeOp.getDeviceSource(), 3); return; diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp index 4e46707..a72db50 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp @@ -121,13 +121,14 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { rewriter.setInsertionPoint(mapOp); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); auto sizeInBytes = getShapedTypeSizeInBytes(initType); + Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0); pim::PimMemCopyOp::create(rewriter, mapOp.getLoc(), initType, + zeroOffset, + zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(sizeInBytes)); rewriter.eraseOp(mapOp); return success(); @@ -487,7 +488,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { if (!allocType || !allocType.hasStaticShape()) return failure(); - if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0) + auto targetOffset = resolveIndexValue(copyOp.getTargetOffset()); + auto sourceOffset = resolveIndexValue(copyOp.getSourceOffset()); + if (failed(targetOffset) || failed(sourceOffset) || *targetOffset != 0 || *sourceOffset != 0) return failure(); auto moduleOp = copyOp->getParentOfType(); diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 7fa8ac9..68969d6 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -13,6 +13,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -333,6 +334,12 @@ private: } if (auto storeOp = dyn_cast(op)) { + if (!pim::isNormalizedCopyOp(storeOp)) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization"); + }); + hasFailure = true; + } if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge)) || failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) { diagnostics.report(&op, [](Operation* illegalOp) { @@ -343,6 +350,12 @@ private: } if (auto loadOp = dyn_cast(op)) { + if (!pim::isNormalizedCopyOp(loadOp)) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization"); + }); + hasFailure = true; + } if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge)) || failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) { diagnostics.report(&op, [](Operation* illegalOp) { @@ -351,6 +364,22 @@ private: hasFailure = true; } } + + if (auto copyOp = dyn_cast(op)) { + if (!pim::isNormalizedCopyOp(copyOp)) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization"); + }); + hasFailure = true; + } + if (failed(resolveIndexValue(copyOp.getTargetOffset(), knowledge)) + || failed(resolveIndexValue(copyOp.getSourceOffset(), knowledge))) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen"); + }); + hasFailure = true; + } + } return success(!hasFailure); }); }