From cf93caecd5c3fb0fcd2beba7f4ca3a8ecb78a26b Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Sat, 30 May 2026 15:54:24 +0200 Subject: [PATCH] centralize logic for materializing contiguous memory into bufferization fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op --- src/PIM/Common/IR/AddressAnalysis.cpp | 37 +- src/PIM/Common/IR/BatchCoreUtils.cpp | 2 - src/PIM/Common/IR/ShapeUtils.cpp | 52 +++ src/PIM/Common/IR/ShapeUtils.hpp | 6 + src/PIM/Common/IR/SubviewUtils.cpp | 22 ++ src/PIM/Common/IR/SubviewUtils.hpp | 4 + src/PIM/Common/IR/WeightUtils.cpp | 24 +- src/PIM/Compiler/PimCodeGen.cpp | 27 +- src/PIM/Compiler/PimCodeGen.hpp | 1 - .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 149 -------- .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 7 - .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 4 +- .../ONNXToSpatial/Patterns/NN/Pool.cpp | 6 +- .../BatchCoreLoweringPatterns.cpp | 39 +- src/PIM/Conversion/SpatialToPim/Common.cpp | 44 --- src/PIM/Conversion/SpatialToPim/Common.hpp | 14 - .../SpatialToPim/SpatialToPimPass.cpp | 12 - src/PIM/Dialect/Pim/Pim.td | 26 -- .../Bufferization/BufferizationUtils.cpp | 7 +- .../Bufferization/BufferizationUtils.hpp | 5 +- .../Transforms/Bufferization/CMakeLists.txt | 2 + .../Bufferization/ContiguityPatterns.cpp | 343 ++++++++++++++++++ .../Bufferization/ContiguityPatterns.hpp | 9 + .../OpBufferizationInterfaces.cpp | 61 +--- .../Bufferization/PimBufferizationPass.cpp | 88 ++++- .../StaticMemoryCoalescingPass.cpp | 4 + .../HostConstantFolding/Patterns/Subview.cpp | 342 +---------------- .../MaterializeHostConstantsPass.cpp | 40 +- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 88 ++--- 29 files changed, 642 insertions(+), 823 deletions(-) create mode 100644 src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp create mode 100644 src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index f49d26a..f477cb5 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -616,31 +616,38 @@ llvm::FailureOr compileContiguousAddressExprImpl(mlir::Valu if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) return mlir::failure(); - llvm::SmallVector staticOffsets; - staticOffsets.reserve(subviewOp.getMixedOffsets().size()); llvm::SmallVector staticSizes; staticSizes.reserve(subviewOp.getMixedSizes().size()); llvm::SmallVector staticStrides; staticStrides.reserve(subviewOp.getMixedStrides().size()); - bool allStatic = true; + llvm::SmallVector staticOffsets; + staticOffsets.reserve(subviewOp.getMixedOffsets().size()); + bool hasOnlyStaticOffsets = true; for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) if (auto attr = mlir::dyn_cast(offset)) staticOffsets.push_back(mlir::cast(attr).getInt()); else - allStatic = false; - for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) - if (auto attr = mlir::dyn_cast(size)) - staticSizes.push_back(mlir::cast(attr).getInt()); - else - allStatic = false; - for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) - if (auto attr = mlir::dyn_cast(stride)) - staticStrides.push_back(mlir::cast(attr).getInt()); - else - allStatic = false; + hasOnlyStaticOffsets = false; + for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) { + auto attr = mlir::dyn_cast(size); + if (!attr) + return mlir::failure(); + staticSizes.push_back(mlir::cast(attr).getInt()); + } + for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) { + auto attr = mlir::dyn_cast(stride); + if (!attr) + return mlir::failure(); + staticStrides.push_back(mlir::cast(attr).getInt()); + } - if (allStatic) { + if (!isContiguousSubviewWithDynamicOffsets( + sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) { + return mlir::failure(); + } + + if (hasOnlyStaticOffsets) { if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides)) return mlir::failure(); diff --git a/src/PIM/Common/IR/BatchCoreUtils.cpp b/src/PIM/Common/IR/BatchCoreUtils.cpp index 3b7aa07..562b8d3 100644 --- a/src/PIM/Common/IR/BatchCoreUtils.cpp +++ b/src/PIM/Common/IR/BatchCoreUtils.cpp @@ -20,8 +20,6 @@ llvm::SmallVector getLaneChunkCoreIds(llvm::ArrayRef coreIds, bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) { if (mlir::isa(op)) return operandIndex == 3; - if (mlir::isa(op)) - return operandIndex == 1; if (mlir::isa(op)) return operandIndex == 2; return false; diff --git a/src/PIM/Common/IR/ShapeUtils.cpp b/src/PIM/Common/IR/ShapeUtils.cpp index 0e48410..112b8aa 100644 --- a/src/PIM/Common/IR/ShapeUtils.cpp +++ b/src/PIM/Common/IR/ShapeUtils.cpp @@ -111,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef srcShape, return true; } +bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef sourceShape, + llvm::ArrayRef mixedOffsets, + llvm::ArrayRef staticSizes, + llvm::ArrayRef staticStrides) { + if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size() + || sourceShape.size() != staticStrides.size()) { + return false; + } + + if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; })) + return false; + + auto reversedTriples = + llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes)); + + auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) { + auto [_sourceDim, offset, _size] = triple; + if (auto attr = mlir::dyn_cast(offset)) + return mlir::cast(attr).getInt() != 0; + return true; + }); + + if (firstNonZeroOrDynamicOffset != reversedTriples.end()) { + auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset; + if (auto attr = mlir::dyn_cast(offset)) { + int64_t staticOffset = mlir::cast(attr).getInt(); + if (size > sourceDim - staticOffset) + return false; + } + + ++firstNonZeroOrDynamicOffset; + for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it) + if (std::get<2>(*it) != 1) + return false; + } + + auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes)); + auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) { + auto [sourceDim, size] = pair; + return size != sourceDim; + }); + + if (firstDifferentSize != reversedSizes.end()) { + ++firstDifferentSize; + for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it) + if (std::get<1>(*it) != 1) + return false; + } + + return true; +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ShapeUtils.hpp b/src/PIM/Common/IR/ShapeUtils.hpp index 214846f..4aa08be 100644 --- a/src/PIM/Common/IR/ShapeUtils.hpp +++ b/src/PIM/Common/IR/ShapeUtils.hpp @@ -1,6 +1,7 @@ #pragma once #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Value.h" #include "llvm/ADT/ArrayRef.h" @@ -30,4 +31,9 @@ bool isMemoryContiguous(llvm::ArrayRef srcShape, llvm::ArrayRef sizes, llvm::ArrayRef strides); +bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef sourceShape, + llvm::ArrayRef mixedOffsets, + llvm::ArrayRef staticSizes, + llvm::ArrayRef staticStrides); + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/SubviewUtils.cpp b/src/PIM/Common/IR/SubviewUtils.cpp index 32c3f4b..678b772 100644 --- a/src/PIM/Common/IR/SubviewUtils.cpp +++ b/src/PIM/Common/IR/SubviewUtils.cpp @@ -31,6 +31,19 @@ Value stripMemRefViewOps(Value value) { } } +Value stripMemRefAddressingOps(Value value) { + while (true) { + if (auto subviewOp = value.getDefiningOp()) { + value = subviewOp.getSource(); + continue; + } + Value strippedValue = stripMemRefViewOps(value); + if (strippedValue == value) + return value; + value = strippedValue; + } +} + bool hasAllStaticSubviewParts(memref::SubViewOp subview) { return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) @@ -81,4 +94,13 @@ FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& return staticOffsets; } +bool isMemRefBaseAddressableValue(Value value) { + value = stripMemRefAddressingOps(value); + if (isa(value)) + return true; + + Operation* defOp = value.getDefiningOp(); + return defOp && isa(defOp); +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/SubviewUtils.hpp b/src/PIM/Common/IR/SubviewUtils.hpp index de53782..4f233ff 100644 --- a/src/PIM/Common/IR/SubviewUtils.hpp +++ b/src/PIM/Common/IR/SubviewUtils.hpp @@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value); mlir::Value stripMemRefViewOps(mlir::Value value); +mlir::Value stripMemRefAddressingOps(mlir::Value value); + bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview); llvm::FailureOr getStaticSubviewInfo(mlir::Value value); @@ -27,4 +29,6 @@ llvm::FailureOr getStaticSubviewInfo(mlir::Value value); /// Returns the offsets in `info` as int64_t, failing if any offset is dynamic. llvm::FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& info); +bool isMemRefBaseAddressableValue(mlir::Value value); + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index 916dddf..0662b85 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -47,28 +47,6 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) { return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs)); } -mlir::Value stripWeightViewOps(mlir::Value value) { - while (true) { - if (auto subviewOp = value.getDefiningOp()) { - value = subviewOp.getSource(); - continue; - } - if (auto castOp = value.getDefiningOp()) { - value = castOp.getSource(); - continue; - } - if (auto collapseOp = value.getDefiningOp()) { - value = collapseOp.getSrc(); - continue; - } - if (auto expandOp = value.getDefiningOp()) { - value = expandOp.getSrc(); - continue; - } - return value; - } -} - template bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { auto weightArg = parentOp.getWeightArgument(weightIndex); @@ -159,7 +137,7 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) { - weight = stripWeightViewOps(weight); + weight = stripMemRefAddressingOps(weight); if (auto coreOp = mlir::dyn_cast_or_null(weightOwner)) { for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index a758a8f..60f2e8f 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -479,16 +479,6 @@ void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticVa loadOp.getSize()); } -void PimCodeGen::codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, - const StaticValueKnowledge& knowledge) const { - emitMemCopyOp("ld", - addressOf(loadOp.getDeviceTarget(), knowledge), - loadOp.getDeviceTargetOffset(), - addressOf(loadOp.getHostSource(), knowledge), - loadOp.getHostSourceOffset(), - loadOp.getSize()); -} - void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const { auto hostTargetOffset = indexOf(storeOp.getHostTargetOffset(), knowledge); auto deviceSourceOffset = indexOf(storeOp.getDeviceSourceOffset(), knowledge); @@ -848,7 +838,6 @@ public: enum class CompiledCoreOpKind : uint8_t { Load, - LoadBatch, Store, Lmv, Receive, @@ -887,8 +876,6 @@ struct CompiledCoreNode { static FailureOr classifyCompiledCoreOpKind(Operation& op) { if (isa(op)) return CompiledCoreOpKind::Load; - if (isa(op)) - return CompiledCoreOpKind::LoadBatch; if (isa(op)) return CompiledCoreOpKind::Store; if (isa(op)) @@ -1027,9 +1014,6 @@ static LogicalResult executeCompiledCorePlan( case CompiledCoreOpKind::Load: coreCodeGen.codeGenLoadOp(cast(node.op), knowledge); break; - case CompiledCoreOpKind::LoadBatch: - coreCodeGen.codeGenLoadBatchOp(cast(node.op), knowledge); - break; case CompiledCoreOpKind::Store: coreCodeGen.codeGenStoreOp(cast(node.op), knowledge); break; @@ -1213,17 +1197,18 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: auto linkCoreWeights = [&](size_t coreId, ArrayRef weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes { auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); - if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { + if (auto error = sys::fs::create_directory(coreWeightsDirPath); error && error != std::errc::file_exists) { errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; return InvalidOutputFileAccess; } for (auto [slot, fileName] : llvm::enumerate(weightFiles)) { xbarsPerGroup.push_back(static_cast(slot)); - if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, - coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) { - errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " - << (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << "\nError:" << error.message() + std::string sourcePath = outputDirPath + "/weights/" + fileName; + std::string targetPath = coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin"; + sys::fs::remove(targetPath); + if (auto error = sys::fs::create_link(sourcePath, targetPath)) { + errs() << "Error creating link file: " << sourcePath << " to " << targetPath << "\nError:" << error.message() << '\n'; return InvalidOutputFileAccess; } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 981435f..67ab3c2 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -175,7 +175,6 @@ public: } void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; - void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 2b9c6ba..40eeec2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -1,11 +1,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Matchers.h" #include "llvm/ADT/SmallVector.h" -#include #include #include "IndexingUtils.hpp" @@ -20,35 +17,6 @@ using namespace mlir; namespace onnx_mlir { -static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) { - APInt lhsConst; - if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero()) - return rhs; - - APInt rhsConst; - if (matchPattern(rhs, m_ConstantInt(&rhsConst)) && rhsConst.isZero()) - return lhs; - - return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult(); -} - -static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatternRewriter& rewriter, Location loc) { - APInt factorConst; - if (auto attr = dyn_cast(factor)) - factorConst = cast(attr).getValue(); - else if (!matchPattern(cast(factor), m_ConstantInt(&factorConst))) - return arith::MulIOp::create(rewriter, loc, value, cast(factor)).getResult(); - - if (factorConst.isZero()) - return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); - if (factorConst.isOne()) - return value; - - auto factorValue = - getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), factorConst.getSExtValue()); - return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult(); -} - bool hasStaticPositiveShape(ArrayRef shape) { return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); } @@ -124,39 +92,6 @@ SmallVector getStaticSizes(PatternRewriter& rewriter, ArrayRef strides) { - auto sourceType = dyn_cast(source.getType()); - if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() - || sourceType.getRank() != resultType.getRank()) - return false; - - for (OpFoldResult stride : strides) { - APInt strideValue; - if (auto attr = dyn_cast(stride)) { - if (cast(attr).getInt() != 1) - return false; - continue; - } - if (!matchPattern(cast(stride), m_ConstantInt(&strideValue)) || !strideValue.isOne()) - return false; - } - - auto sizesAndShape = llvm::zip_equal(llvm::make_range(resultType.getShape().rbegin(), resultType.getShape().rend()), - llvm::make_range(sourceType.getShape().rbegin(), sourceType.getShape().rend())); - auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool { - auto [size, dimension] = sizeAndShape; - return size != dimension; - }); - if (firstDifferentSize == sizesAndShape.end()) - return true; - - ++firstDifferentSize; - return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) { - auto [size, _dimension] = sizeAndShape; - return size == 1; - }); -} - SmallVector sliceTensor( const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { ArrayRef shape = getTensorShape(tensorToSlice); @@ -222,90 +157,6 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri return slicesPerCore; } -Value materializeContiguousTensorSlice(Value source, - RankedTensorType resultType, - ArrayRef offsets, - ArrayRef strides, - ConversionPatternRewriter& rewriter, - Location loc) { - assert(resultType.hasStaticShape() && "expected static result type"); - size_t rank = static_cast(resultType.getRank()); - assert(offsets.size() == rank && "expected rank-matching offsets"); - assert(strides.size() == rank && "expected rank-matching strides"); - - SmallVector sizes; - sizes.reserve(resultType.getRank()); - for (int64_t size : resultType.getShape()) - sizes.push_back(rewriter.getIndexAttr(size)); - - if (isContiguousTensorSlice(source, resultType, strides)) - return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult(); - - if (resultType.getRank() == 0) - return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult(); - - Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult(); - SmallVector zeroIndices(resultType.getRank()); - for (Value& zeroIndex : zeroIndices) - zeroIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); - - SmallVector resultIndices; - resultIndices.reserve(resultType.getRank()); - - auto buildLoopNest = [&](auto&& self, unsigned dim, Value accumulator) -> Value { - if (dim == resultType.getRank()) { - SmallVector sourceIndices; - sourceIndices.reserve(resultType.getRank()); - for (unsigned idx = 0; idx < resultType.getRank(); ++idx) { - Value offsetValue = getOrMaterializeIndexValue(rewriter, offsets[idx]); - Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc); - sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc)); - } - - SmallVector sourceOffsets; - SmallVector destinationOffsets; - SmallVector unitSizes; - SmallVector unitStrides; - sourceOffsets.reserve(resultType.getRank()); - destinationOffsets.reserve(resultType.getRank()); - unitSizes.reserve(resultType.getRank()); - unitStrides.reserve(resultType.getRank()); - for (Value index : sourceIndices) - sourceOffsets.push_back(index); - for (Value index : resultIndices) - destinationOffsets.push_back(index); - for (int64_t idx = 0; idx < resultType.getRank(); ++idx) { - unitSizes.push_back(rewriter.getIndexAttr(1)); - unitStrides.push_back(rewriter.getIndexAttr(1)); - } - - auto elementTensorType = - RankedTensorType::get(SmallVector(resultType.getRank(), 1), resultType.getElementType()); - Value elementSlice = - tensor::ExtractSliceOp::create(rewriter, loc, elementTensorType, source, sourceOffsets, unitSizes, unitStrides) - .getResult(); - return tensor::InsertSliceOp::create( - rewriter, loc, elementSlice, accumulator, destinationOffsets, unitSizes, unitStrides) - .getResult(); - } - - Value lower = zeroIndices[dim]; - Value upper = - getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim)); - Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); - auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator}); - rewriter.setInsertionPointToStart(loop.getBody()); - resultIndices.push_back(loop.getInductionVar()); - Value updated = self(self, dim + 1, loop.getRegionIterArgs().front()); - resultIndices.pop_back(); - scf::YieldOp::create(rewriter, loc, updated); - rewriter.setInsertionPointAfter(loop); - return loop.getResult(0); - }; - - return buildLoopNest(buildLoopNest, 0, init); -} - Value extractAxisSlice( PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { auto sourceType = cast(source.getType()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index 785c906..84fd5a3 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -108,13 +108,6 @@ llvm::SmallVector sliceVector(const mlir::Value& vectorToSlice, llvm::DenseMap> sliceVectorPerCrossbarPerCore( const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); -mlir::Value materializeContiguousTensorSlice(mlir::Value source, - mlir::RankedTensorType resultType, - llvm::ArrayRef offsets, - llvm::ArrayRef strides, - mlir::ConversionPatternRewriter& rewriter, - mlir::Location loc); - mlir::Value extractAxisSlice( mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 7c6e294..feefba9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -303,9 +303,11 @@ createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewri static Value extractDynamicGemmBColumn( Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { SmallVector offsets {rewriter.getIndexAttr(0), column}; + SmallVector sizes {rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)}; SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType()); - Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc); + Value columnSlice = + tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides).getResult(); SmallVector collapseReassociation { ReassociationIndices {0, 1} }; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index 7396daf..2cab22e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -23,7 +23,7 @@ using namespace mlir; namespace onnx_mlir { namespace { -static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { +static Value materializeTileTensor(ConversionPatternRewriter& rewriter, Location loc, Value tile) { auto tileType = cast(tile.getType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank())); @@ -319,7 +319,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value windowValue = tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides); - windowValue = materializeContiguousTile(rewriter, loc, windowValue); + windowValue = materializeTileTensor(rewriter, loc, windowValue); reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue); } } @@ -335,7 +335,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value scaleSlice = tensor::ExtractSliceOp::create( rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides); - scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice); + scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice); reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice); } diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 87e6c8c..3ce8f29 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -59,10 +59,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter, ShapedType destinationType, IRMapping& mapper) { int64_t elementBytes = static_cast(getElementTypeSizeInBytes(destinationType.getElementType())); - SmallVector strides(destinationType.getRank(), 1); - ArrayRef shape = destinationType.getShape(); - for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim) - strides[dim] = strides[dim + 1] * shape[dim + 1]; + SmallVector strides = computeRowMajorStrides(destinationType.getShape()); Value totalOffset; Location loc = insertSlice.getLoc(); @@ -162,14 +159,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex); auto newArgType = cast(newArg.getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); - auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, - loc, - outputBuffer.getType(), - outputBuffer, - newArg, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - getTensorSizeInBytesAttr(rewriter, newArg)) + Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); + auto copied = pim::PimMemCopyHostToDevOp::create(rewriter, + loc, + outputBuffer.getType(), + zeroOffset, + zeroOffset, + outputBuffer, + newArg, + getTensorSizeInBytesAttr(rewriter, newArg)) .getOutput(); mapper.map(*oldArg, copied); } @@ -233,14 +231,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute } auto clonedType = cast(clonedTensor.getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); - auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, - loc, - outputBuffer.getType(), - outputBuffer, - clonedTensor, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - getTensorSizeInBytesAttr(rewriter, clonedTensor)) + Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); + auto copied = pim::PimMemCopyHostToDevOp::create(rewriter, + loc, + outputBuffer.getType(), + zeroOffset, + zeroOffset, + outputBuffer, + clonedTensor, + getTensorSizeInBytesAttr(rewriter, clonedTensor)) .getOutput(); mapper.map(toTensorOp.getResult(), copied); continue; diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index eea9698..4d304f4 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -1,10 +1,8 @@ #include "mlir/IR/ValueRange.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" #include -#include #include "Common.hpp" @@ -13,48 +11,6 @@ using namespace mlir; namespace onnx_mlir { -size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) { - /* - EXAMPLE RUN: - [1, 10, 3, 4] inputShape - [0, 2, 1, 3] offsets - - acc = 1 - --- - ret = 3 - acc = 4 - --- - ret = 3 + 4 * 1 = 7 - acc = 12 - --- - ret = 7 + 12 * 2 = 31 - acc = 120 - --- - ret = 31 + 120 * 0 = 31 - acc = 120 - */ - - size_t returnValue = 0; - - auto sliceOffsets = sliceOp.getStaticOffsets(); - auto inputDimSizes = inputShape.getShape(); - - assert(sliceOffsets.size() == inputDimSizes.size()); - - size_t accumulatedDimensionSize = 1; - - // Reverse iterate the two vectors - for (auto it : reverse(zip(sliceOffsets, inputDimSizes))) { - auto curSliceOffset = std::get<0>(it); - auto curInputDimSize = std::get<1>(it); - - returnValue += accumulatedDimensionSize * curSliceOffset; - accumulatedDimensionSize *= curInputDimSize; - } - - return returnValue; -} - IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { return builder.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(cast(value.getType())))); } diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index 7377000..db87abe 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -6,20 +6,6 @@ namespace onnx_mlir { -/** - * \brief Get the offset of the ExtractSliceOp based on its static offsets and - * its static tensor input. - * - * The static offsets represent the starting position of the slice in each - * dimension, while the static tensor input gives its dimension size. - * - * \param sliceOp The ExtractSliceOp for which the actual offset needs to be - * calculated. - * \param inputShape The ShapedType of the ExtractSliceOp's input tensor - * \return The actual offset of the ExtractSliceOp. - */ -size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape); - mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); template diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 7e44e45..4eb11aa 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -83,18 +83,6 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter, auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(tensorType))); - - if (outputBuffer->getParentOfType()) - return PimMemCopyHostToDevBatchOp::create(rewriter, - loc, - tensorType, - outputBuffer, - zeroValue, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - sizeAttr) - .getOutput(); - return PimMemCopyHostToDevOp::create( rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr) .getOutput(); diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 1d8aa34..76a61b8 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -144,32 +144,6 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { }]; } -def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> { - let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core"; - - let arguments = (ins - PimTensor:$deviceTarget, - PimTensor:$hostSource, - I32Attr:$deviceTargetOffset, - I32Attr:$hostSourceOffset, - I32Attr:$size - ); - - let results = (outs - PimTensor:$output - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getDeviceTargetMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output) - }]; -} - def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> { let summary = "Copy a memory region from device memory into host memory"; diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index 1635ac3..baef6c2 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" @@ -10,8 +11,8 @@ using namespace bufferization; namespace onnx_mlir::pim { -Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { - if (succeeded(resolveContiguousAddress(memrefValue))) +Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { + if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) return memrefValue; auto shapedType = cast(memrefValue.getType()); @@ -30,7 +31,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& .getOutput(); } -Value allocateContiguousMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) { +Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) { if (succeeded(resolveContiguousAddress(memrefValue))) return memrefValue; diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp index de45384..d9bc034 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp @@ -5,8 +5,9 @@ namespace onnx_mlir::pim { -mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); -mlir::Value allocateContiguousMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); +mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); +mlir::Value +allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); llvm::FailureOr getBufferOrValue(mlir::RewriterBase& rewriter, mlir::Value value, diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt index 5011fa9..62c90c8 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt @@ -6,6 +6,8 @@ add_pim_library(OMPimBufferization PimBufferizationPass.cpp BufferizationUtils.hpp BufferizationUtils.cpp + ContiguityPatterns.hpp + ContiguityPatterns.cpp OpBufferizationInterfaces.hpp OpBufferizationInterfaces.cpp Common.hpp diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp new file mode 100644 index 0000000..e2937a1 --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp @@ -0,0 +1,343 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "ContiguityPatterns.hpp" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir::pim { +namespace { + +static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) { + if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; })) + return false; + + return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides); +} + +static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) { + if (extraOffset == 0) + return baseOffset; + + if (auto attr = dyn_cast(baseOffset)) { + auto integerAttr = dyn_cast(attr); + assert(integerAttr && "expected integer offset attribute"); + return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset); + } + + auto value = cast(baseOffset); + auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset); + return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult(); +} + +static Value buildSubviewChunk(const StaticSubviewInfo& info, + ArrayRef outerIndices, + Location loc, + PatternRewriter& rewriter) { + SmallVector chunkOffsets; + SmallVector chunkSizes; + SmallVector chunkStrides; + chunkOffsets.reserve(info.offsets.size()); + chunkSizes.reserve(info.sizes.size()); + chunkStrides.reserve(info.strides.size()); + + for (size_t dim = 0; dim < info.sizes.size(); ++dim) { + int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0; + chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter)); + chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back())); + chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim])); + } + + return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); +} + +static SmallVector delinearizeIndexValue(Value linearIndex, + ArrayRef shape, + ArrayRef strides, + PatternRewriter& rewriter) { + SmallVector indices; + indices.reserve(shape.size()); + + Value remaining = linearIndex; + for (auto [_dim, stride] : llvm::enumerate(strides)) { + auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride); + Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); + indices.push_back(index); + remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); + } + + return indices; +} + +static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) { + if (auto attr = dyn_cast(baseOffset)) { + auto integerAttr = cast(attr); + if (integerAttr.getInt() == 0) + return extraOffset; + + auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt()); + return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult(); + } + + auto value = cast(baseOffset); + return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult(); +} + +static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info, + ArrayRef outerIndices, + Location loc, + PatternRewriter& rewriter) { + SmallVector chunkOffsets; + SmallVector chunkSizes; + SmallVector chunkStrides; + chunkOffsets.reserve(info.offsets.size()); + chunkSizes.reserve(info.sizes.size()); + chunkStrides.reserve(info.strides.size()); + + for (size_t dim = 0; dim < info.sizes.size(); ++dim) { + if (dim + 1 < info.sizes.size()) { + assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides"); + chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter)); + chunkSizes.push_back(rewriter.getIndexAttr(1)); + } + else { + chunkOffsets.push_back(info.offsets[dim]); + chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back())); + } + chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim])); + } + + return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); +} + +static Value buildContiguousChunk( + Value source, ArrayRef copyShape, ArrayRef outerIndices, Location loc, PatternRewriter& rewriter) { + SmallVector chunkOffsets; + SmallVector chunkSizes; + SmallVector chunkStrides; + chunkOffsets.reserve(copyShape.size()); + chunkSizes.reserve(copyShape.size()); + chunkStrides.reserve(copyShape.size()); + + for (size_t dim = 0; dim < copyShape.size(); ++dim) { + chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0)); + chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back())); + chunkStrides.push_back(rewriter.getIndexAttr(1)); + } + + return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides); +} + +template +static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, + Value dst, + Value src, + int64_t dstOffset, + int64_t srcOffset, + int64_t size, + bool allowLoopRewrite, + PatternRewriter& rewriter, + CreateCopyOp createCopyOp) { + auto srcSubview = getStaticSubviewInfo(src); + auto dstSubview = getStaticSubviewInfo(dst); + const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview); + const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview); + if (!splitSrc && !splitDst) + return failure(); + + auto sourceType = dyn_cast(src.getType()); + auto dstType = dyn_cast(dst.getType()); + if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape()) + return failure(); + if (sourceType.getElementType() != dstType.getElementType()) + return failure(); + + if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))) + return failure(); + if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))) + return failure(); + + ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); + if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) + return failure(); + + if (!hasByteSizedElementType(sourceType.getElementType())) + return failure(); + const int64_t elementByteWidth = static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); + + const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; + if (size != totalBytes) + return failure(); + + const int64_t sliceBytes = copyShape.back() * elementByteWidth; + if (sliceBytes <= 0) + return failure(); + + SmallVector outerShape(copyShape.begin(), copyShape.end() - 1); + auto outerStrides = computeRowMajorStrides(outerShape); + const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); + const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape); + const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape); + + if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0 + && sourceType.getRank() == static_cast(copyShape.size()) + && dstType.getRank() == static_cast(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape) + && (splitDst || dstShapeMatchesCopyShape)) { + auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0); + auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices); + auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1); + + auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {}); + rewriter.setInsertionPointToStart(loop.getBody()); + + SmallVector outerIndices = + outerShape.empty() ? SmallVector {} + : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter); + Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) + : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter); + Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) + : buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter); + createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes); + return success(); + } + + rewriter.setInsertionPoint(copyOp); + for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { + SmallVector outerIndices = + outerShape.empty() ? SmallVector {} : delinearizeIndex(linearIndex, outerShape, outerStrides); + Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst; + Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src; + const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes; + const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes; + createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes); + } + + return success(); +} + +struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override { + if (!copyOp->getParentOfType() && !copyOp->getParentOfType()) + return failure(); + + auto status = rewriteSubviewCopyLikeOp( + copyOp, + copyOp.getTarget(), + copyOp.getSource(), + copyOp.getTargetOffset(), + copyOp.getSourceOffset(), + copyOp.getSize(), + /*allowLoopRewrite=*/true, + rewriter, + [&]( + MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { + pim::PimMemCopyOp::create(rewriter, + copyOp.getLoc(), + resultType, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); + if (failed(status)) + return failure(); + + rewriter.replaceOp(copyOp, copyOp.getTarget()); + return success(); + } +}; + +struct NormalizeHostSubviewLoadPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { + auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset()); + auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset()); + if (failed(dstOffset) || failed(srcOffset)) + return failure(); + + auto status = rewriteSubviewCopyLikeOp( + copyOp, + copyOp.getDeviceTarget(), + copyOp.getHostSource(), + *dstOffset, + *srcOffset, + copyOp.getSize(), + /*allowLoopRewrite=*/true, + rewriter, + [&]( + MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { + Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset); + Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset); + pim::PimMemCopyHostToDevOp::create(rewriter, + copyOp.getLoc(), + resultType, + dstOffsetValue, + srcOffsetValue, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); + if (failed(status)) + return failure(); + + rewriter.replaceOp(copyOp, copyOp.getDeviceTarget()); + return success(); + } +}; + +struct NormalizeHostSubviewStorePattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override { + auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset()); + auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset()); + if (failed(dstOffset) || failed(srcOffset)) + return failure(); + + auto status = rewriteSubviewCopyLikeOp( + copyOp, + copyOp.getHostTarget(), + copyOp.getDeviceSource(), + *dstOffset, + *srcOffset, + copyOp.getSize(), + /*allowLoopRewrite=*/false, + rewriter, + [&]( + MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { + Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset); + Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset); + pim::PimMemCopyDevToHostOp::create(rewriter, + copyOp.getLoc(), + resultType, + dstOffsetValue, + srcOffsetValue, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); + if (failed(status)) + return failure(); + + rewriter.replaceOp(copyOp, copyOp.getHostTarget()); + return success(); + } +}; + +} // namespace + +void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) { + patterns.add( + patterns.getContext()); +} + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp new file mode 100644 index 0000000..c47c56c --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir::pim { + +void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns); + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index ff2183f..a489b08 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -47,32 +47,6 @@ struct MemCopyHostToDevOpInterface } }; -struct MemCopyHostToDevBatchOpInterface -: DstBufferizableOpInterfaceExternalModel { - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto memCopyHostToDevOp = cast(op); - auto deviceTargetOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state); - if (failed(deviceTargetOpt)) - return failure(); - auto hostSourceOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getHostSource(), options, state); - if (failed(hostSourceOpt)) - return failure(); - - replaceOpWithNewBufferizedOp(rewriter, - memCopyHostToDevOp, - deviceTargetOpt->getType(), - *deviceTargetOpt, - *hostSourceOpt, - memCopyHostToDevOp.getDeviceTargetOffsetAttr(), - memCopyHostToDevOp.getHostSourceOffsetAttr(), - memCopyHostToDevOp.getSizeAttr()); - return success(); - } -}; - struct MemCopyDevToHostOpInterface : DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation* op, @@ -151,8 +125,9 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId()); + rewriter, op, contiguousOutput.getType(), contiguousOutput, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId()); return success(); } }; @@ -173,15 +148,16 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter)); + inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter)); } auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); + Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, outputBufferOpt->getType(), concatOp.getAxisAttr(), ValueRange(inputs), *outputBufferOpt); + rewriter, op, contiguousOutput.getType(), concatOp.getAxisAttr(), ValueRange(inputs), contiguousOutput); return success(); } }; @@ -206,7 +182,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel(rewriter, op, - materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), + materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getSizeAttr(), sendOp.getTargetCoreId()); return success(); @@ -431,8 +407,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); - Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); + Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); + Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput); @@ -475,8 +451,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); - Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); + Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); + Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput); @@ -514,9 +490,9 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); - Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter); - Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); + Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); + Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); @@ -547,9 +523,9 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); - Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter); - Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); + Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); + Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); @@ -583,8 +559,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); - Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); + Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); + Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput); return success(); @@ -599,7 +575,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimSendOp::attachInterface(*ctx); PimConcatOp::attachInterface(*ctx); PimMemCopyHostToDevOp::attachInterface(*ctx); - PimMemCopyHostToDevBatchOp::attachInterface(*ctx); PimMemCopyDevToHostOp::attachInterface(*ctx); PimMemCopyOp::attachInterface(*ctx); PimTransposeOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 409b92c..f1f82a9 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -6,14 +6,14 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "Common/PimCommon.hpp" #include "Compiler/PimCodeGen.hpp" #include "Dialect/Pim/PimOps.hpp" -#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" +#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" @@ -40,6 +40,7 @@ struct PimBufferizationPass : PassWrapper(operand.getType())) + return; + if (succeeded(resolveContiguousAddress(operand)) || succeeded(compileContiguousAddressExpr(operand))) + return; + op->emitOpError() << "operand #" << operandIndex + << " is not backed by contiguous addressable storage after PIM bufferization"; + hasFailure = true; + }; + + if (auto memCopyOp = dyn_cast(op)) { + verifyOperand(memCopyOp.getTarget(), 0); + verifyOperand(memCopyOp.getSource(), 1); + return; + } + if (auto loadOp = dyn_cast(op)) { + verifyOperand(loadOp.getDeviceTarget(), 2); + verifyOperand(loadOp.getHostSource(), 3); + return; + } + if (auto storeOp = dyn_cast(op)) { + verifyOperand(storeOp.getHostTarget(), 2); + verifyOperand(storeOp.getDeviceSource(), 3); + return; + } + if (auto sendOp = dyn_cast(op)) { + verifyOperand(sendOp.getInput(), 0); + return; + } + if (auto receiveOp = dyn_cast(op)) { + verifyOperand(receiveOp.getOutputBuffer(), 0); + return; + } + if (auto concatOp = dyn_cast(op)) { + verifyOperand(concatOp.getOutputBuffer(), 0); + for (auto inputAndIndex : llvm::enumerate(concatOp.getInputs())) + verifyOperand(inputAndIndex.value(), inputAndIndex.index() + 1); + return; + } + if (isa(op)) { + for (auto operandAndIndex : llvm::enumerate(op->getOperands())) { + if (auto vmmOp = dyn_cast(op); vmmOp && operandAndIndex.index() == 0) + continue; + verifyOperand(operandAndIndex.value(), operandAndIndex.index()); + } + } + }); + + if (hasFailure) { + moduleOp.emitError("PIM bufferization must fully normalize executable runtime operand contiguity before codegen"); + return failure(); + } + return success(); +} + std::unique_ptr createPimBufferizationPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp index a1bd540..ec7647a 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp @@ -22,6 +22,10 @@ using namespace onnx_mlir::compact_asm; namespace onnx_mlir { namespace { +// This pass assumes bufferization has already normalized executable PIM +// operands. It only reuses compatible local allocations with non-overlapping +// lifetimes; it does not repair memory contiguity. + struct CoalescingReportRow { uint64_t numCandidates = 0; uint64_t numSkipped = 0; diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp index 4864851..887b212 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp @@ -1,349 +1,12 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" - #include "../Common.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; namespace onnx_mlir { namespace { -static bool isSubviewContiguous(const StaticSubviewInfo& info) { - if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; })) - return false; - - auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()), - llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend())); - auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool { - auto [size, dimension] = sizeAndShape; - return size != dimension; - }); - if (firstDifferentSize == sizesAndShape.end()) - return true; - - ++firstDifferentSize; - return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) { - auto [size, _dimension] = sizeAndShape; - return size == 1; - }); -} - -static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) { - if (extraOffset == 0) - return baseOffset; - - if (auto attr = dyn_cast(baseOffset)) { - auto integerAttr = dyn_cast(attr); - assert(integerAttr && "expected integer offset attribute"); - return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset); - } - - auto value = cast(baseOffset); - auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset); - return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult(); -} - -static Value buildSubviewChunk(const StaticSubviewInfo& info, - ArrayRef outerIndices, - Location loc, - PatternRewriter& rewriter) { - SmallVector chunkOffsets; - SmallVector chunkSizes; - SmallVector chunkStrides; - chunkOffsets.reserve(info.offsets.size()); - chunkSizes.reserve(info.sizes.size()); - chunkStrides.reserve(info.strides.size()); - - for (size_t dim = 0; dim < info.sizes.size(); ++dim) { - int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0; - chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter)); - chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back())); - chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim])); - } - - return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); -} - -static SmallVector delinearizeIndexValue(Value linearIndex, - ArrayRef shape, - ArrayRef strides, - PatternRewriter& rewriter) { - SmallVector indices; - indices.reserve(shape.size()); - - Value remaining = linearIndex; - for (auto [_dim, stride] : llvm::enumerate(strides)) { - auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride); - Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); - indices.push_back(index); - remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); - } - - return indices; -} - -static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) { - if (auto attr = dyn_cast(baseOffset)) { - auto integerAttr = cast(attr); - if (integerAttr.getInt() == 0) - return extraOffset; - - auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt()); - return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult(); - } - - auto value = cast(baseOffset); - return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult(); -} - -static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info, - ArrayRef outerIndices, - Location loc, - PatternRewriter& rewriter) { - SmallVector chunkOffsets; - SmallVector chunkSizes; - SmallVector chunkStrides; - chunkOffsets.reserve(info.offsets.size()); - chunkSizes.reserve(info.sizes.size()); - chunkStrides.reserve(info.strides.size()); - - for (size_t dim = 0; dim < info.sizes.size(); ++dim) { - if (dim + 1 < info.sizes.size()) { - assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides"); - chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter)); - chunkSizes.push_back(rewriter.getIndexAttr(1)); - } - else { - chunkOffsets.push_back(info.offsets[dim]); - chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back())); - } - chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim])); - } - - return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); -} - -static Value buildContiguousChunk( - Value source, ArrayRef copyShape, ArrayRef outerIndices, Location loc, PatternRewriter& rewriter) { - SmallVector chunkOffsets; - SmallVector chunkSizes; - SmallVector chunkStrides; - chunkOffsets.reserve(copyShape.size()); - chunkSizes.reserve(copyShape.size()); - chunkStrides.reserve(copyShape.size()); - - for (size_t dim = 0; dim < copyShape.size(); ++dim) { - chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0)); - chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back())); - chunkStrides.push_back(rewriter.getIndexAttr(1)); - } - - return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides); -} - -template -static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, - Value dst, - Value src, - int64_t dstOffset, - int64_t srcOffset, - int64_t size, - bool allowLoopRewrite, - PatternRewriter& rewriter, - CreateCopyOp createCopyOp) { - auto srcSubview = getStaticSubviewInfo(src); - auto dstSubview = getStaticSubviewInfo(dst); - const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview); - const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview); - if (!splitSrc && !splitDst) - return failure(); - - auto sourceType = dyn_cast(src.getType()); - auto dstType = dyn_cast(dst.getType()); - if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape()) - return failure(); - if (sourceType.getElementType() != dstType.getElementType()) - return failure(); - - if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))) - return failure(); - if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))) - return failure(); - - ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); - if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) - return failure(); - - if (!hasByteSizedElementType(sourceType.getElementType())) - return failure(); - const int64_t elementByteWidth = static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); - - const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; - if (size != totalBytes) - return failure(); - - const int64_t sliceBytes = copyShape.back() * elementByteWidth; - if (sliceBytes <= 0) - return failure(); - - SmallVector outerShape(copyShape.begin(), copyShape.end() - 1); - auto outerStrides = computeRowMajorStrides(outerShape); - const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); - - if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0 - && sourceType.getRank() == static_cast(copyShape.size()) - && dstType.getRank() == static_cast(copyShape.size())) { - auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0); - auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices); - auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1); - - auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {}); - rewriter.setInsertionPointToStart(loop.getBody()); - - SmallVector outerIndices = - outerShape.empty() ? SmallVector {} - : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter); - Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) - : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter); - Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) - : buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter); - createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes); - return success(); - } - - rewriter.setInsertionPoint(copyOp); - for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { - SmallVector outerIndices = - outerShape.empty() ? SmallVector {} : delinearizeIndex(linearIndex, outerShape, outerStrides); - Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst; - Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src; - const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes; - const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes; - createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes); - } - - return success(); -} - -// Splits core copies through subviews into contiguous copy chunks for codegen. -struct RewriteCoreSubviewCopyPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override { - if (!copyOp->getParentOfType()) - return failure(); - - auto status = rewriteSubviewCopyLikeOp( - copyOp, - copyOp.getTarget(), - copyOp.getSource(), - copyOp.getTargetOffset(), - copyOp.getSourceOffset(), - copyOp.getSize(), - /*allowLoopRewrite=*/true, - rewriter, - [&]( - MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { - pim::PimMemCopyOp::create(rewriter, - copyOp.getLoc(), - resultType, - dst, - src, - rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - }); - if (failed(status)) - return failure(); - - rewriter.replaceOp(copyOp, copyOp.getTarget()); - return success(); - } -}; - -// Splits host-to-device subview loads into contiguous copy chunks. -struct RewriteHostSubviewLoadPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { - auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset()); - auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset()); - if (failed(dstOffset) || failed(srcOffset)) - return failure(); - - auto status = rewriteSubviewCopyLikeOp( - copyOp, - copyOp.getDeviceTarget(), - copyOp.getHostSource(), - *dstOffset, - *srcOffset, - copyOp.getSize(), - /*allowLoopRewrite=*/true, - rewriter, - [&]( - MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { - Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset); - Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset); - pim::PimMemCopyHostToDevOp::create(rewriter, - copyOp.getLoc(), - resultType, - dstOffsetValue, - srcOffsetValue, - dst, - src, - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - }); - if (failed(status)) - return failure(); - - rewriter.replaceOp(copyOp, copyOp.getDeviceTarget()); - return success(); - } -}; - -// Splits device-to-host subview stores into contiguous copy chunks. -struct RewriteHostSubviewStorePattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override { - auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset()); - auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset()); - if (failed(dstOffset) || failed(srcOffset)) - return failure(); - - auto status = rewriteSubviewCopyLikeOp( - copyOp, - copyOp.getHostTarget(), - copyOp.getDeviceSource(), - *dstOffset, - *srcOffset, - copyOp.getSize(), - /*allowLoopRewrite=*/false, - rewriter, - [&]( - MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { - Value dstOffset = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset); - Value srcOffset = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset); - pim::PimMemCopyDevToHostOp::create(rewriter, - copyOp.getLoc(), - resultType, - dstOffset, - srcOffset, - dst, - src, - rewriter.getI32IntegerAttr(static_cast(sliceBytes))); - }); - if (failed(status)) - return failure(); - - rewriter.replaceOp(copyOp, copyOp.getHostTarget()); - return success(); - } -}; - // Folds constant subviews used as core weights into standalone globals. struct FoldConstantCoreSubviewPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -392,10 +55,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern(patterns.getContext()); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index 3aa45be..d7b3c88 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -7,12 +7,9 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/MathExtras.h" -#include - #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -85,31 +82,18 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, if (contiguousType != originalType) deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc); - Value copiedValue; - if constexpr (std::is_same_v) { - copiedValue = pim::PimMemCopyHostToDevBatchOp::create( - rewriter, - op->getLoc(), - originalType, - deviceDst, - getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(static_cast(resolvedAddress->byteOffset)), - rewriter.getI32IntegerAttr(static_cast(totalBytes))) - .getOutput(); - } - else { - copiedValue = pim::PimMemCopyHostToDevOp::create( - rewriter, - op->getLoc(), - originalType, - getOrCreateIndexConstant(constantFolder, op, 0), - getOrCreateIndexConstant(constantFolder, op, static_cast(resolvedAddress->byteOffset)), - deviceDst, - getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(static_cast(totalBytes))) - .getOutput(); - } + Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0); + Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset); + Value copiedValue = + pim::PimMemCopyHostToDevOp::create(rewriter, + op->getLoc(), + originalType, + zeroOffset, + hostOffset, + deviceDst, + getGlobalOp.getResult(), + rewriter.getI32IntegerAttr(static_cast(totalBytes))) + .getOutput(); cachedByType[originalType] = copiedValue; operand.set(copiedValue); diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 316648d..7fa8ac9 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -32,51 +32,30 @@ static bool isAddressOnlyHostOp(Operation* op) { memref::CopyOp>(op); } -// Looser than isCodegenAddressableValue: follows view ops without requiring contiguity. -// Used for memref.copy operands which may be non-contiguous subviews. -static bool isBaseAddressableValue(Value value) { - while (true) { - if (isa(value)) - return true; - Operation* defOp = value.getDefiningOp(); - if (!defOp) - return false; - if (isa(defOp)) - return true; - if (auto subview = dyn_cast(defOp)) { - value = subview.getSource(); - continue; - } - if (auto cast = dyn_cast(defOp)) { - value = cast.getSource(); - continue; - } - if (auto collapse = dyn_cast(defOp)) { - value = collapse.getSrc(); - continue; - } - if (auto expand = dyn_cast(defOp)) { - value = expand.getSrc(); - continue; - } - return false; - } -} - static bool isCodegenAddressableValue(Value value) { auto resolvedAddress = resolveContiguousAddress(value); - if (failed(resolvedAddress)) + if (succeeded(resolvedAddress)) + return isa(resolvedAddress->base) + || isa(resolvedAddress->base.getDefiningOp()); + + auto compiledAddress = compileContiguousAddressExpr(value); + if (failed(compiledAddress)) return false; - return isa(resolvedAddress->base) - || isa(resolvedAddress->base.getDefiningOp()); + return isa(compiledAddress->base) + || isa(compiledAddress->base.getDefiningOp()); } static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) { auto resolvedAddress = resolveContiguousAddress(value, knowledge); - if (failed(resolvedAddress)) + if (succeeded(resolvedAddress)) + return isa(resolvedAddress->base) + || isa(resolvedAddress->base.getDefiningOp()); + + auto compiledAddress = compileContiguousAddressExpr(value); + if (failed(compiledAddress)) return false; - return isa(resolvedAddress->base) - || isa(resolvedAddress->base.getDefiningOp()); + return isa(compiledAddress->base) + || isa(compiledAddress->base.getDefiningOp()); } static bool isConstantGlobalView(Value value) { @@ -138,7 +117,6 @@ static bool isCoreWeightBlockArgument(Value value) { static bool isSupportedCoreInstructionOp(Operation* op) { return isa(op); } -static LogicalResult verifyBatchOpSemantics(Operation& op, - const StaticValueKnowledge& knowledge, - pim::CappedDiagnosticReporter& diagnostics) { - bool hasFailure = false; - auto reportFailure = [&](auto emitDiagnostic) { - diagnostics.report(&op, [&](Operation* illegalOp) { emitDiagnostic(illegalOp); }); - hasFailure = true; - }; - - if (auto memcpHdBatchOp = dyn_cast(op)) { - if (!isCodegenAddressableValue(memcpHdBatchOp.getHostSource(), knowledge)) { - reportFailure([](Operation* illegalOp) { - illegalOp->emitOpError("host operand #1 is not backed by contiguous addressable storage"); - }); - } - return success(!hasFailure); - } - - return success(!hasFailure); -} - struct VerificationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) @@ -343,7 +300,7 @@ private: } auto resolvedAddress = resolveContiguousAddress(operand, knowledge); - if (failed(resolvedAddress)) { + if (failed(resolvedAddress) && failed(compileContiguousAddressExpr(operand))) { diagnostics.report(&op, [&](Operation* illegalOp) { illegalOp->emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; @@ -363,7 +320,9 @@ private: continue; } - if (!isa(resolvedAddress->base.getDefiningOp())) { + Value addressBase = + succeeded(resolvedAddress) ? resolvedAddress->base : compileContiguousAddressExpr(operand)->base; + if (!isa(addressBase.getDefiningOp())) { diagnostics.report(&op, [&](Operation* illegalOp) { illegalOp->emitOpError() << "operand #" << operandIndex << " must be backed by device-local memory; materialize host values with " @@ -392,9 +351,6 @@ private: hasFailure = true; } } - - if (failed(verifyBatchOpSemantics(op, knowledge, diagnostics))) - hasFailure = true; return success(!hasFailure); }); } @@ -409,7 +365,7 @@ private: if (auto expandOp = dyn_cast(op)) return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics); if (auto copyOp = dyn_cast(op)) { - if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) { + if (!isMemRefBaseAddressableValue(copyOp.getSource()) || !isMemRefBaseAddressableValue(copyOp.getTarget())) { diagnostics.report(op, [](Operation* illegalOp) { illegalOp->emitOpError("depends on a value that is not backed by addressable storage"); }); @@ -432,7 +388,7 @@ private: } static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) { - if (isBaseAddressableValue(source)) + if (isMemRefBaseAddressableValue(source)) return success(); diagnostics.report(op, [](Operation* illegalOp) {