From 717ad160cd174685d56bf0f413743b02575a04a7 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 4 May 2026 09:20:43 +0200 Subject: [PATCH] Refactor PIM/Common (splitting in files, adding helpers, adding brief docs) --- .gitignore | 2 +- src/PIM/Common/CMakeLists.txt | 9 +- src/PIM/Common/IR/AddressAnalysis.cpp | 258 +++++++++ src/PIM/Common/IR/AddressAnalysis.hpp | 43 ++ src/PIM/Common/IR/CoreBlockUtils.cpp | 67 +++ src/PIM/Common/IR/CoreBlockUtils.hpp | 24 + src/PIM/Common/IR/EntryPointUtils.cpp | 45 ++ src/PIM/Common/IR/EntryPointUtils.hpp | 13 + src/PIM/Common/IR/ShapeUtils.cpp | 89 ++++ src/PIM/Common/IR/ShapeUtils.hpp | 22 + src/PIM/Common/IR/WeightUtils.cpp | 101 ++++ src/PIM/Common/IR/WeightUtils.hpp | 29 ++ src/PIM/Common/PimCommon.cpp | 575 --------------------- src/PIM/Common/PimCommon.hpp | 80 +-- src/PIM/Common/Support/DebugDump.cpp | 27 + src/PIM/Common/Support/DebugDump.hpp | 13 + src/PIM/Common/Support/Diagnostics.cpp | 41 ++ src/PIM/Common/Support/Diagnostics.hpp | 38 ++ src/PIM/Common/Support/FileSystemUtils.cpp | 24 + src/PIM/Common/Support/FileSystemUtils.hpp | 13 + 20 files changed, 863 insertions(+), 650 deletions(-) create mode 100644 src/PIM/Common/IR/AddressAnalysis.cpp create mode 100644 src/PIM/Common/IR/AddressAnalysis.hpp create mode 100644 src/PIM/Common/IR/CoreBlockUtils.cpp create mode 100644 src/PIM/Common/IR/CoreBlockUtils.hpp create mode 100644 src/PIM/Common/IR/EntryPointUtils.cpp create mode 100644 src/PIM/Common/IR/EntryPointUtils.hpp create mode 100644 src/PIM/Common/IR/ShapeUtils.cpp create mode 100644 src/PIM/Common/IR/ShapeUtils.hpp create mode 100644 src/PIM/Common/IR/WeightUtils.cpp create mode 100644 src/PIM/Common/IR/WeightUtils.hpp delete mode 100644 src/PIM/Common/PimCommon.cpp create mode 100644 src/PIM/Common/Support/DebugDump.cpp create mode 100644 src/PIM/Common/Support/DebugDump.hpp create mode 100644 src/PIM/Common/Support/Diagnostics.cpp create mode 100644 src/PIM/Common/Support/Diagnostics.hpp create mode 
100644 src/PIM/Common/Support/FileSystemUtils.cpp create mode 100644 src/PIM/Common/Support/FileSystemUtils.hpp diff --git a/.gitignore b/.gitignore index 8bb3643..bf5221d 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,4 @@ build cmake-build-debug cmake-build-release -**/__pycache__ +**/__* diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 7c370d2..8f1b44c 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -1,5 +1,12 @@ add_pim_library(OMPimCommon - PimCommon.cpp + IR/AddressAnalysis.cpp + IR/CoreBlockUtils.cpp + IR/EntryPointUtils.cpp + IR/ShapeUtils.cpp + IR/WeightUtils.cpp + Support/DebugDump.cpp + Support/Diagnostics.cpp + Support/FileSystemUtils.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp new file mode 100644 index 0000000..b296d29 --- /dev/null +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -0,0 +1,258 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" + +namespace onnx_mlir { + +mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) { + if (!moduleOp || !getGlobalOp) + return {}; + return moduleOp.lookupSymbol(getGlobalOp.getName()); +} + +namespace { + +mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) { + if (!knowledge) + return value; + + auto iter = knowledge->aliases.find(value); + while (iter != knowledge->aliases.end()) { + value = iter->second; + iter = knowledge->aliases.find(value); + } + return value; +} + +mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { + value = resolveAlias(value, knowledge); + + if (mlir::isa(value)) + return value; + + mlir::Operation* 
definingOp = value.getDefiningOp(); + if (!definingOp) + return value; + + if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) { + if (auto result = mlir::dyn_cast(value)) + if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result)) + return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge); + } + + if (auto castOp = mlir::dyn_cast(definingOp)) + return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge); + if (auto collapseOp = mlir::dyn_cast(definingOp)) + return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge); + if (auto expandOp = mlir::dyn_cast(definingOp)) + return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge); + + return value; +} + +llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge); + +llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { + value = resolveAlias(value, knowledge); + + if (knowledge) { + auto iter = knowledge->indexValues.find(value); + if (iter != knowledge->indexValues.end()) + return iter->second; + } + + auto constantOp = value.getDefiningOp(); + if (constantOp) { + if (auto integerAttr = mlir::dyn_cast(constantOp.getValue())) + return integerAttr.getInt(); + } + + mlir::Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return mlir::failure(); + + if (auto indexCastOp = mlir::dyn_cast(definingOp)) + return resolveIndexValueImpl(indexCastOp.getIn(), knowledge); + + if (auto addOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return mlir::failure(); + return *lhs + *rhs; + } + + if (auto subOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return mlir::failure(); + return *lhs - *rhs; + 
} + + if (auto mulOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return mlir::failure(); + return *lhs * *rhs; + } + + if (auto divOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs) || *rhs == 0) + return mlir::failure(); + return static_cast(static_cast(*lhs) / static_cast(*rhs)); + } + + if (auto remOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs) || *rhs == 0) + return mlir::failure(); + return static_cast(static_cast(*lhs) % static_cast(*rhs)); + } + + return mlir::failure(); +} + +llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) { + if (auto attr = mlir::dyn_cast(ofr)) { + auto integerAttr = mlir::dyn_cast(attr); + if (!integerAttr) + return mlir::failure(); + return integerAttr.getInt(); + } + + return resolveIndexValueImpl(mlir::cast(ofr), knowledge); +} + +llvm::FailureOr resolveContiguousAddressImpl(mlir::Value value, + const StaticValueKnowledge* knowledge) { + int64_t byteOffset = 0; + value = resolveAlias(value, knowledge); + + while (true) { + if (mlir::isa(value)) + return ResolvedContiguousAddress {value, byteOffset}; + + mlir::Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return mlir::failure(); + + if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) { + mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast(value)); + if (!tiedOperand) + return mlir::failure(); + value = resolveAlias(tiedOperand->get(), knowledge); + continue; + } + + if (auto forOp = mlir::dyn_cast(definingOp)) { + auto result = mlir::dyn_cast(value); + if 
(!result) + return mlir::failure(); + + auto yieldOp = mlir::cast(forOp.getBody()->getTerminator()); + mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); + if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { + if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 + && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { + value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); + continue; + } + } + + value = yieldedValue; + continue; + } + + if (auto subviewOp = mlir::dyn_cast(definingOp)) { + auto sourceType = mlir::dyn_cast(subviewOp.getSource().getType()); + auto subviewType = mlir::dyn_cast(subviewOp.getType()); + if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) + return mlir::failure(); + + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; + offsets.reserve(subviewOp.getMixedOffsets().size()); + sizes.reserve(subviewOp.getMixedSizes().size()); + strides.reserve(subviewOp.getMixedStrides().size()); + + for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) { + auto resolvedOffset = resolveOpFoldResult(offset, knowledge); + if (failed(resolvedOffset)) + return mlir::failure(); + offsets.push_back(*resolvedOffset); + } + + for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) { + auto resolvedSize = resolveOpFoldResult(size, knowledge); + if (failed(resolvedSize)) + return mlir::failure(); + sizes.push_back(*resolvedSize); + } + + for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) { + auto resolvedStride = resolveOpFoldResult(stride, knowledge); + if (failed(resolvedStride)) + return mlir::failure(); + strides.push_back(*resolvedStride); + } + + if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides)) + return mlir::failure(); + + auto sourceStrides = 
computeRowMajorStrides(sourceType.getShape()); + byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; + value = resolveAlias(subviewOp.getSource(), knowledge); + continue; + } + + if (auto castOp = mlir::dyn_cast(definingOp)) { + value = resolveAlias(castOp.getSource(), knowledge); + continue; + } + if (auto collapseOp = mlir::dyn_cast(definingOp)) { + value = resolveAlias(collapseOp.getSrc(), knowledge); + continue; + } + if (auto expandOp = mlir::dyn_cast(definingOp)) { + value = resolveAlias(expandOp.getSrc(), knowledge); + continue; + } + + if (mlir::isa(definingOp)) + return ResolvedContiguousAddress {value, byteOffset}; + + return mlir::failure(); + } +} + +} // namespace + +llvm::FailureOr resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); } + +llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) { + return resolveIndexValueImpl(value, &knowledge); +} + +llvm::FailureOr resolveContiguousAddress(mlir::Value value) { + return resolveContiguousAddressImpl(value, nullptr); +} + +llvm::FailureOr resolveContiguousAddress(mlir::Value value, + const StaticValueKnowledge& knowledge) { + return resolveContiguousAddressImpl(value, &knowledge); +} + +mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) { + return resolveLoopCarriedAliasImpl(value, &knowledge); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/AddressAnalysis.hpp b/src/PIM/Common/IR/AddressAnalysis.hpp new file mode 100644 index 0000000..eb099bd --- /dev/null +++ b/src/PIM/Common/IR/AddressAnalysis.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/DenseMap.h" + +namespace onnx_mlir { + +/// Describes a value as a base addressable object plus a statically known +/// byte offset after peeling aliases, casts, and contiguous subviews. 
+struct ResolvedContiguousAddress {
+  mlir::Value base;       ///< Value at which the address resolution stopped.
+  int64_t byteOffset = 0; ///< Offset accumulated from the peeled subviews, in bytes.
+};
+
+/// Records compile-time facts used when interpreting address arithmetic and
+/// loop-carried aliases inside PIM regions.
+struct StaticValueKnowledge {
+  llvm::DenseMap<mlir::Value, int64_t> indexValues; ///< SSA value -> statically known integer value.
+  llvm::DenseMap<mlir::Value, mlir::Value> aliases; ///< SSA value -> value it currently stands for.
+  // NOTE(review): the user-provided constructor keeps this a non-aggregate; presumably deliberate - confirm.
+  StaticValueKnowledge() {}
+};
+
+mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
+
+/// Resolves a value to contiguous backing storage when that storage can be
+/// proven statically from aliases, DPS ties, casts, and subviews.
+llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
+llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
+                                                                    const StaticValueKnowledge& knowledge);
+
+/// Statically evaluates index-like SSA values, including simple integer
+/// arithmetic and loop facts recorded in `knowledge`.
+llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
+llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
+
+/// Follows alias, view, and DPS chains to recover the backing value of a
+/// loop-carried memref/result.
+mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/CoreBlockUtils.cpp b/src/PIM/Common/IR/CoreBlockUtils.cpp new file mode 100644 index 0000000..09327bb --- /dev/null +++ b/src/PIM/Common/IR/CoreBlockUtils.cpp @@ -0,0 +1,67 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +namespace onnx_mlir { + +bool isCoreStaticAddressOp(mlir::Operation* op) { + return mlir::isa(op); +} + +mlir::LogicalResult +walkPimCoreBlock(mlir::Block& block, + const StaticValueKnowledge& knowledge, + llvm::function_ref callback) { + bool hasFailure = false; + for (mlir::Operation& op : block) { + if (mlir::isa(op) || isCoreStaticAddressOp(&op)) + continue; + + if (auto forOp = mlir::dyn_cast(op)) { + mlir::Block& loopBody = forOp.getRegion().front(); + auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge); + auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge); + auto step = resolveIndexValue(forOp.getStep(), knowledge); + if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) { + forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen"); + hasFailure = true; + continue; + } + + llvm::SmallVector iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end()); + for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) { + StaticValueKnowledge loopKnowledge = knowledge; + loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue; + for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues)) + loopKnowledge.aliases[iterArg] = iterValue; + + if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback))) + hasFailure = true; + + auto yieldOp = 
mlir::cast(loopBody.getTerminator()); + for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands())) + iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge); + } + continue; + } + + if (failed(callback(op, knowledge))) + hasFailure = true; + } + return mlir::success(!hasFailure); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/CoreBlockUtils.hpp b/src/PIM/Common/IR/CoreBlockUtils.hpp new file mode 100644 index 0000000..91fb7cf --- /dev/null +++ b/src/PIM/Common/IR/CoreBlockUtils.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "mlir/IR/Block.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/STLFunctionalExtras.h" + +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" + +namespace onnx_mlir { + +/// Returns true for ops in a `pim.core` body that only participate in static +/// address or index computation and therefore do not emit PIM instructions. +bool isCoreStaticAddressOp(mlir::Operation* op); + +/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when +/// their bounds are known and invoking `callback` only on instruction-emitting +/// operations. 
+mlir::LogicalResult
+walkPimCoreBlock(mlir::Block& block,
+                 const StaticValueKnowledge& knowledge,
+                 llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Common/IR/EntryPointUtils.cpp b/src/PIM/Common/IR/EntryPointUtils.cpp
new file mode 100644
index 0000000..88fd10c
--- /dev/null
+++ b/src/PIM/Common/IR/EntryPointUtils.cpp
@@ -0,0 +1,45 @@
+#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+namespace onnx_mlir {
+
+llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
+  if (!moduleOp)
+    return mlir::failure();
+  // First choice: the function named by the module's ONNX entry-point op; several such ops are rejected.
+  llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
+  if (entryPoints.size() > 1) {
+    moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
+    return mlir::failure();
+  }
+  if (!entryPoints.empty()) {
+    auto entryPointAttr =
+        entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
+    if (!entryPointAttr) {
+      entryPoints.front().emitOpError("is missing the entry point function attribute");
+      return mlir::failure();
+    }
+    auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
+    if (!entryFunc) {
+      entryPoints.front().emitOpError("references an unknown entry function ")
+          << entryPointAttr.getLeafReference().getValue();
+      return mlir::failure();
+    }
+    return entryFunc;
+  }
+  // Second choice: a function literally named `main_graph`.
+  if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
+    return mainGraphFunc;
+  // Last resort: accept the module's only non-external function, if there is exactly one.
+  llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
+  for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
+    if (!funcOp.isExternal())
+      nonExternalFuncs.push_back(funcOp);
+  if (nonExternalFuncs.size() == 1)
+    return nonExternalFuncs.front();
+  // None of the strategies produced a single candidate.
+  moduleOp.emitError("could not resolve a unique PIM entry function");
+  return mlir::failure();
+}
+
+} // namespace onnx_mlir
diff --git a/src/PIM/Common/IR/EntryPointUtils.hpp b/src/PIM/Common/IR/EntryPointUtils.hpp
new file mode 100644
index 0000000..907c1a7
--- /dev/null
+++ 
b/src/PIM/Common/IR/EntryPointUtils.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" + +namespace onnx_mlir { + +/// Resolves the function the PIM pipeline should treat as its entry point. +/// Prefers ONNX entry-point metadata, then `main_graph`, then the only +/// non-external function if the module is otherwise unambiguous. +llvm::FailureOr getPimEntryFunc(mlir::ModuleOp moduleOp); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ShapeUtils.cpp b/src/PIM/Common/IR/ShapeUtils.cpp new file mode 100644 index 0000000..33253cb --- /dev/null +++ b/src/PIM/Common/IR/ShapeUtils.cpp @@ -0,0 +1,89 @@ +#include "llvm/ADT/STLExtras.h" + +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" + +namespace onnx_mlir { + +llvm::SmallVector computeRowMajorStrides(llvm::ArrayRef shape) { + llvm::SmallVector strides(shape.size(), 1); + for (int64_t dim = static_cast(shape.size()) - 2; dim >= 0; --dim) + strides[dim] = strides[dim + 1] * shape[dim + 1]; + return strides; +} + +llvm::SmallVector +delinearizeIndex(int64_t linearIndex, llvm::ArrayRef shape, llvm::ArrayRef strides) { + llvm::SmallVector indices(shape.size(), 0); + for (auto [dim, stride] : llvm::enumerate(strides)) { + indices[dim] = linearIndex / stride; + linearIndex %= stride; + } + return indices; +} + +int64_t linearizeIndex(llvm::ArrayRef indices, llvm::ArrayRef strides) { + int64_t linearIndex = 0; + for (auto [index, stride] : llvm::zip_equal(indices, strides)) + linearIndex += index * stride; + return linearIndex; +} + +int64_t getNumElements(llvm::ArrayRef shape) { + int64_t numElements = 1; + for (int64_t dim : shape) + numElements *= dim; + return numElements; +} + +bool isMemoryContiguous(llvm::ArrayRef srcShape, + llvm::ArrayRef offsets, + llvm::ArrayRef sizes, + llvm::ArrayRef strides) { + if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; })) + return false; + + auto 
offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()), + llvm::make_range(sizes.rbegin(), sizes.rend()), + llvm::make_range(srcShape.rbegin(), srcShape.rend())); + + auto firstNonZeroOffset = std::find_if( + offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool { + auto [offset, _size, _dimension] = offsetAndSizeAndShape; + return offset != 0; + }); + + if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) { + auto [offset, size, dimension] = *firstNonZeroOffset; + if (size > dimension - offset) + return false; + ++firstNonZeroOffset; + + if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool { + auto [_offset, size, _dimension] = offsetAndSizeAndShape; + return size != 1; + })) + return false; + } + + auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()), + llvm::make_range(srcShape.rbegin(), srcShape.rend())); + + auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool { + auto [size, dimension] = sizeAndShape; + return size != dimension; + }); + + if (firstDifferentSize != sizesAndShape.end()) { + ++firstDifferentSize; + + if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool { + auto [size, _dimension] = sizeAndShape; + return size != 1; + })) + return false; + } + + return true; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ShapeUtils.hpp b/src/PIM/Common/IR/ShapeUtils.hpp new file mode 100644 index 0000000..41d666a --- /dev/null +++ b/src/PIM/Common/IR/ShapeUtils.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +namespace onnx_mlir { + +llvm::SmallVector computeRowMajorStrides(llvm::ArrayRef shape); + +llvm::SmallVector +delinearizeIndex(int64_t linearIndex, llvm::ArrayRef shape, llvm::ArrayRef strides); + +int64_t 
linearizeIndex(llvm::ArrayRef indices, llvm::ArrayRef strides); + +int64_t getNumElements(llvm::ArrayRef shape); + +bool isMemoryContiguous(llvm::ArrayRef srcShape, + llvm::ArrayRef offsets, + llvm::ArrayRef sizes, + llvm::ArrayRef strides); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp new file mode 100644 index 0000000..64104ba --- /dev/null +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -0,0 +1,101 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" + +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +namespace onnx_mlir { + +bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; } + +void markWeightAlways(mlir::Operation* op) { + assert(op && "expected valid op"); + op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext())); +} + +namespace { + +template +bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { + bool found = false; + parentOp.walk([&](mlir::Operation* op) { + if (auto mvmOp = mlir::dyn_cast(op)) + found |= mvmOp.getWeightIndex() == weightIndex; + else if (auto vmmOp = mlir::dyn_cast(op)) + found |= vmmOp.getWeightIndex() == weightIndex; + }); + return found; +} + +template +void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref callback) { + auto weights = parentOp.getWeights(); + llvm::SmallSet visited; + auto walkWeightIndex = [&](unsigned weightIndex) { + if (weightIndex < weights.size() && visited.insert(weightIndex).second) + callback(parentOp->getOpOperand(weightIndex)); + }; + + parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); + parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); +} + +} // namespace + +bool 
isSpatialMvmVmmWeightUse(mlir::OpOperand& use) { + mlir::Operation* user = use.getOwner(); + unsigned operandIndex = use.getOperandNumber(); + + auto computeOp = mlir::dyn_cast(user); + if (!computeOp || operandIndex >= computeOp.getWeights().size()) + return false; + + return hasMvmVmmWeightUse(computeOp, operandIndex); +} + +bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { + llvm::SmallPtrSet visited; + auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool { + if (!visited.insert(currentValue).second) + return true; + if (currentValue.use_empty()) + return false; + + return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) { + if (isSpatialMvmVmmWeightUse(use)) + return true; + + mlir::Operation* user = use.getOwner(); + if (auto extractSliceOp = mlir::dyn_cast(user)) + return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self); + if (auto expandShapeOp = mlir::dyn_cast(user)) + return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self); + if (auto collapseShapeOp = mlir::dyn_cast(user)) + return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self); + if (auto transposeOp = mlir::dyn_cast(user)) + return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self); + + return false; + }); + }; + + return walkUses(value, walkUses); +} + +void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback) { + assert(root && "expected valid root op"); + root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses(coreOp, callback); }); + root->walk([&](pim::PimCoreBatchOp coreBatchOp) { + auto weights = coreBatchOp.getWeights(); + for (auto weight : weights) + for (mlir::OpOperand& use : weight.getUses()) + if (use.getOwner() == coreBatchOp.getOperation()) + callback(use); + }); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/WeightUtils.hpp b/src/PIM/Common/IR/WeightUtils.hpp new file mode 100644 
index 0000000..f0a1b2f --- /dev/null +++ b/src/PIM/Common/IR/WeightUtils.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" + +inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; + +namespace onnx_mlir { + +bool hasWeightAlways(mlir::Operation* op); + +/// Tags an op as producing a value that should stay materialized as a reusable +/// weight across later PIM lowering/codegen stages. +void markWeightAlways(mlir::Operation* op); + +bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use); + +/// Returns true when a value flows only into Spatial weighted MVM/VMM operands, +/// allowing later passes to preserve it as a dedicated weight-like object. +bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value); + +/// Visits weight operands consumed by Pim core ops/core batches so downstream +/// passes can identify globals that must remain weight-backed. +void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp deleted file mode 100644 index 0615d35..0000000 --- a/src/PIM/Common/PimCommon.cpp +++ /dev/null @@ -1,575 +0,0 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" - -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Support/raw_os_ostream.h" - -#include -#include - -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Compiler/CompilerOptions.hpp" -#include 
"src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -std::string getOutputDir() { - if (outputBaseName.empty() || outputBaseName == "-") - return {}; - - size_t lastSlash = outputBaseName.find_last_of('/'); - if (lastSlash == std::string::npos) - return "."; - return outputBaseName.substr(0, lastSlash); -} - -void createDirectory(const std::string& directory) { - std::error_code errorCode; - std::filesystem::create_directories(directory, errorCode); - assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data()); -} - -void dumpModule(ModuleOp moduleOp, const std::string& name) { - std::string outputDir = getOutputDir(); - if (outputDir.empty()) - return; - - std::string dialectsDir = outputDir + "/dialects"; - createDirectory(dialectsDir); - - std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); - llvm::raw_os_ostream os(file); - OpPrintingFlags flags; - flags.elideLargeElementsAttrs(); - moduleOp.print(os, flags); - os.flush(); - file.close(); -} - -FailureOr getPimEntryFunc(ModuleOp moduleOp) { - if (!moduleOp) - return failure(); - - SmallVector entryPoints(moduleOp.getOps()); - if (entryPoints.size() > 1) { - moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size(); - return failure(); - } - if (!entryPoints.empty()) { - auto entryPointAttr = - entryPoints.front()->getAttrOfType(ONNXEntryPointOp::getEntryPointFuncAttrName()); - if (!entryPointAttr) { - entryPoints.front().emitOpError("is missing the entry point function attribute"); - return failure(); - } - auto entryFunc = moduleOp.lookupSymbol(entryPointAttr.getLeafReference().getValue()); - if (!entryFunc) { - entryPoints.front().emitOpError("references an unknown entry function ") - << entryPointAttr.getLeafReference().getValue(); - return failure(); - } - return entryFunc; - } - - if (auto mainGraphFunc = moduleOp.lookupSymbol("main_graph")) - return mainGraphFunc; - - SmallVector 
nonExternalFuncs; - for (auto funcOp : moduleOp.getOps()) - if (!funcOp.isExternal()) - nonExternalFuncs.push_back(funcOp); - if (nonExternalFuncs.size() == 1) - return nonExternalFuncs.front(); - - moduleOp.emitError("could not resolve a unique PIM entry function"); - return failure(); -} - -bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; } - -void markWeightAlways(Operation* op) { - assert(op && "expected valid op"); - op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext())); -} - -namespace { - -template -bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { - bool found = false; - parentOp.walk([&](Operation* op) { - if (auto mvmOp = dyn_cast(op)) - found |= mvmOp.getWeightIndex() == weightIndex; - else if (auto vmmOp = dyn_cast(op)) - found |= vmmOp.getWeightIndex() == weightIndex; - }); - return found; -} - -template -void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref callback) { - auto weights = parentOp.getWeights(); - llvm::SmallSet visited; - auto walkWeightIndex = [&](unsigned weightIndex) { - if (weightIndex < weights.size() && visited.insert(weightIndex).second) - callback(parentOp->getOpOperand(weightIndex)); - }; - - parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); - parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); -} - -} // namespace - -bool isSpatialMvmVmmWeightUse(OpOperand& use) { - Operation* user = use.getOwner(); - unsigned operandIndex = use.getOperandNumber(); - - auto computeOp = dyn_cast(user); - if (!computeOp || operandIndex >= computeOp.getWeights().size()) - return false; - - return hasMvmVmmWeightUse(computeOp, operandIndex); -} - -bool hasOnlySpatialMvmVmmWeightUses(Value value) { - SmallPtrSet visited; - auto walkUses = [&](Value currentValue, auto& self) -> bool { - if (!visited.insert(currentValue).second) - return true; - if (currentValue.use_empty()) - return false; - - return 
llvm::all_of(currentValue.getUses(), [&](OpOperand& use) { - if (isSpatialMvmVmmWeightUse(use)) - return true; - - Operation* user = use.getOwner(); - if (auto extractSliceOp = dyn_cast(user)) - return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self); - if (auto expandShapeOp = dyn_cast(user)) - return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self); - if (auto collapseShapeOp = dyn_cast(user)) - return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self); - if (auto transposeOp = dyn_cast(user)) - return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self); - - return false; - }); - }; - - return walkUses(value, walkUses); -} - -void walkPimMvmVmmWeightUses(Operation* root, function_ref callback) { - assert(root && "expected valid root op"); - root->walk([&](pim::PimCoreOp coreOp) { - walkMvmVmmWeightUses(coreOp, callback); - }); - root->walk([&](pim::PimCoreBatchOp coreBatchOp) { - auto weights = coreBatchOp.getWeights(); - for (auto weight : weights) - for (OpOperand& use : weight.getUses()) - if (use.getOwner() == coreBatchOp.getOperation()) - callback(use); - }); -} - -memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) { - if (!moduleOp || !getGlobalOp) - return {}; - return moduleOp.lookupSymbol(getGlobalOp.getName()); -} - -SmallVector computeRowMajorStrides(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - for (int64_t dim = static_cast(shape.size()) - 2; dim >= 0; --dim) - strides[dim] = strides[dim + 1] * shape[dim + 1]; - return strides; -} - -SmallVector delinearizeIndex(int64_t linearIndex, ArrayRef shape, ArrayRef strides) { - SmallVector indices(shape.size(), 0); - for (auto [dim, stride] : llvm::enumerate(strides)) { - indices[dim] = linearIndex / stride; - linearIndex %= stride; - } - return indices; -} - -int64_t linearizeIndex(ArrayRef indices, ArrayRef strides) { - int64_t 
linearIndex = 0; - for (auto [index, stride] : llvm::zip_equal(indices, strides)) - linearIndex += index * stride; - return linearIndex; -} - -int64_t getNumElements(ArrayRef shape) { - int64_t numElements = 1; - for (int64_t dim : shape) - numElements *= dim; - return numElements; -} - -bool isMemoryContiguous(ArrayRef srcShape, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; })) - return false; - - auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()), - llvm::make_range(sizes.rbegin(), sizes.rend()), - llvm::make_range(srcShape.rbegin(), srcShape.rend())); - - auto firstNonZeroOffset = std::find_if( - offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool { - auto [offset, _size, _dimension] = offsetAndSizeAndShape; - return offset != 0; - }); - - if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) { - auto [offset, size, dimension] = *firstNonZeroOffset; - if (size > dimension - offset) - return false; - ++firstNonZeroOffset; - - if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool { - auto [_offset, size, _dimension] = offsetAndSizeAndShape; - return size != 1; - })) - return false; - } - - auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()), - llvm::make_range(srcShape.rbegin(), srcShape.rend())); - - auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool { - auto [size, dimension] = sizeAndShape; - return size != dimension; - }); - - if (firstDifferentSize != sizesAndShape.end()) { - ++firstDifferentSize; - - if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool { - auto [size, _dimension] = sizeAndShape; - return size != 1; - })) - return false; - } - - return true; -} - -static 
Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) { - if (!knowledge) - return value; - - auto iter = knowledge->aliases.find(value); - while (iter != knowledge->aliases.end()) { - value = iter->second; - iter = knowledge->aliases.find(value); - } - return value; -} - -// Walks through view-like ops and DPS tied operands to find the "underlying" memref value -// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop -// and when propagating yielded values across iterations during static unrolling. -static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) { - value = resolveAlias(value, knowledge); - - if (auto blockArgument = dyn_cast(value)) - return value; - - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) - return value; - - if (auto dpsDefiningOp = dyn_cast(definingOp)) { - if (auto result = dyn_cast(value)) - if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result)) - return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge); - } - - if (auto castOp = dyn_cast(definingOp)) - return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge); - if (auto collapseOp = dyn_cast(definingOp)) - return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge); - if (auto expandOp = dyn_cast(definingOp)) - return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge); - - return value; -} - -static FailureOr resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge); - -static FailureOr resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) { - value = resolveAlias(value, knowledge); - - if (knowledge) { - auto iter = knowledge->indexValues.find(value); - if (iter != knowledge->indexValues.end()) - return iter->second; - } - - auto constantOp = value.getDefiningOp(); - if (constantOp) { - if (auto integerAttr = dyn_cast(constantOp.getValue())) - return integerAttr.getInt(); - } - - Operation* 
definingOp = value.getDefiningOp(); - if (!definingOp) - return failure(); - - if (auto indexCastOp = dyn_cast(definingOp)) - return resolveIndexValueImpl(indexCastOp.getIn(), knowledge); - - if (auto addOp = dyn_cast(definingOp)) { - auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge); - auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge); - if (failed(lhs) || failed(rhs)) - return failure(); - return *lhs + *rhs; - } - - if (auto subOp = dyn_cast(definingOp)) { - auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge); - auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge); - if (failed(lhs) || failed(rhs)) - return failure(); - return *lhs - *rhs; - } - - if (auto mulOp = dyn_cast(definingOp)) { - auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge); - auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge); - if (failed(lhs) || failed(rhs)) - return failure(); - return *lhs * *rhs; - } - - if (auto divOp = dyn_cast(definingOp)) { - auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge); - auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge); - if (failed(lhs) || failed(rhs) || *rhs == 0) - return failure(); - return static_cast(static_cast(*lhs) / static_cast(*rhs)); - } - - if (auto remOp = dyn_cast(definingOp)) { - auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); - auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); - if (failed(lhs) || failed(rhs) || *rhs == 0) - return failure(); - return static_cast(static_cast(*lhs) % static_cast(*rhs)); - } - - return failure(); -} - -static FailureOr resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) { - if (auto attr = dyn_cast(ofr)) { - auto integerAttr = dyn_cast(attr); - if (!integerAttr) - return failure(); - return integerAttr.getInt(); - } - - return resolveIndexValueImpl(cast(ofr), knowledge); -} - -static FailureOr resolveContiguousAddressImpl(Value value, - const StaticValueKnowledge* knowledge) { - int64_t 
byteOffset = 0; - value = resolveAlias(value, knowledge); - - while (true) { - if (isa(value)) - return ResolvedContiguousAddress {value, byteOffset}; - - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) - return failure(); - - if (auto dpsDefiningOp = dyn_cast(definingOp)) { - OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast(value)); - if (!tiedOperand) - return failure(); - value = resolveAlias(tiedOperand->get(), knowledge); - continue; - } - - if (auto forOp = dyn_cast(definingOp)) { - auto result = dyn_cast(value); - if (!result) - return failure(); - - // Trace the loop carry back to its underlying memref, then if that memref is the - // loop's own iter-arg we know the base comes from the corresponding init arg - // (every iteration yields the same backing memory in the DPS sense). - auto yieldOp = cast(forOp.getBody()->getTerminator()); - Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); - if (auto blockArgument = dyn_cast(yieldedValue)) { - if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 - && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { - value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); - continue; - } - } - - value = yieldedValue; - continue; - } - - if (auto subviewOp = dyn_cast(definingOp)) { - auto sourceType = dyn_cast(subviewOp.getSource().getType()); - auto subviewType = dyn_cast(subviewOp.getType()); - if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) - return failure(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(subviewOp.getMixedOffsets().size()); - sizes.reserve(subviewOp.getMixedSizes().size()); - strides.reserve(subviewOp.getMixedStrides().size()); - - for (OpFoldResult offset : subviewOp.getMixedOffsets()) { - auto resolvedOffset = resolveOpFoldResult(offset, 
knowledge); - if (failed(resolvedOffset)) - return failure(); - offsets.push_back(*resolvedOffset); - } - - for (OpFoldResult size : subviewOp.getMixedSizes()) { - auto resolvedSize = resolveOpFoldResult(size, knowledge); - if (failed(resolvedSize)) - return failure(); - sizes.push_back(*resolvedSize); - } - - for (OpFoldResult stride : subviewOp.getMixedStrides()) { - auto resolvedStride = resolveOpFoldResult(stride, knowledge); - if (failed(resolvedStride)) - return failure(); - strides.push_back(*resolvedStride); - } - - if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides)) - return failure(); - - auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); - byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; - value = resolveAlias(subviewOp.getSource(), knowledge); - continue; - } - - if (auto castOp = dyn_cast(definingOp)) { - value = resolveAlias(castOp.getSource(), knowledge); - continue; - } - if (auto collapseOp = dyn_cast(definingOp)) { - value = resolveAlias(collapseOp.getSrc(), knowledge); - continue; - } - if (auto expandOp = dyn_cast(definingOp)) { - value = resolveAlias(expandOp.getSrc(), knowledge); - continue; - } - - if (isa(definingOp)) - return ResolvedContiguousAddress {value, byteOffset}; - - return failure(); - } -} - -FailureOr resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); } - -FailureOr resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) { - return resolveIndexValueImpl(value, &knowledge); -} - -FailureOr resolveContiguousAddress(Value value) { - return resolveContiguousAddressImpl(value, nullptr); -} - -FailureOr resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) { - return resolveContiguousAddressImpl(value, &knowledge); -} - -Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) { - return resolveLoopCarriedAliasImpl(value, &knowledge); -} - -bool 
isCoreStaticAddressOp(Operation* op) { - return isa(op); -} - -LogicalResult walkPimCoreBlock(Block& block, - const StaticValueKnowledge& knowledge, - llvm::function_ref callback) { - bool hasFailure = false; - for (Operation& op : block) { - if (isa(op) || isCoreStaticAddressOp(&op)) - continue; - - if (auto forOp = dyn_cast(op)) { - Block& loopBody = forOp.getRegion().front(); - auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge); - auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge); - auto step = resolveIndexValue(forOp.getStep(), knowledge); - if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) { - forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen"); - hasFailure = true; - continue; - } - - SmallVector iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end()); - for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) { - StaticValueKnowledge loopKnowledge = knowledge; - loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue; - for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues)) - loopKnowledge.aliases[iterArg] = iterValue; - - if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback))) - hasFailure = true; - - auto yieldOp = cast(loopBody.getTerminator()); - for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands())) - iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge); - } - continue; - } - - if (failed(callback(op, knowledge))) - hasFailure = true; - } - return success(!hasFailure); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index 0ad4fd3..59c5ee1 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -11,83 +11,17 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" 
+#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" +#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp" #include "src/Compiler/CompilerOptions.hpp" -inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; - namespace onnx_mlir { inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id"; -struct ResolvedContiguousAddress { - mlir::Value base; - int64_t byteOffset = 0; -}; - -struct StaticValueKnowledge { - llvm::DenseMap indexValues; - llvm::DenseMap aliases; - - StaticValueKnowledge() {} -}; - -std::string getOutputDir(); - -void createDirectory(const std::string& directory); - -void dumpModule(mlir::ModuleOp moduleOp, const std::string& name); - -llvm::FailureOr getPimEntryFunc(mlir::ModuleOp moduleOp); - -bool hasWeightAlways(mlir::Operation* op); - -void markWeightAlways(mlir::Operation* op); - -bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use); -bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value); - -void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback); - -mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); - -llvm::SmallVector computeRowMajorStrides(llvm::ArrayRef shape); - -llvm::SmallVector -delinearizeIndex(int64_t linearIndex, llvm::ArrayRef shape, llvm::ArrayRef strides); - -int64_t linearizeIndex(llvm::ArrayRef indices, llvm::ArrayRef strides); - -int64_t getNumElements(llvm::ArrayRef shape); - -bool isMemoryContiguous(llvm::ArrayRef srcShape, - llvm::ArrayRef offsets, - llvm::ArrayRef sizes, - llvm::ArrayRef strides); - -llvm::FailureOr resolveContiguousAddress(mlir::Value value); -llvm::FailureOr resolveContiguousAddress(mlir::Value value, - const StaticValueKnowledge& 
knowledge); - -llvm::FailureOr resolveIndexValue(mlir::Value value); -llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge); - -/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for -/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries. -mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge); - -/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and -/// only contribute to static addressing or index computations (arith integer math, -/// memref view ops, memref.alloc, arith.constant). -bool isCoreStaticAddressOp(mlir::Operation* op); - -/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically -/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op -/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is -/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback -/// failure so callers can collect multiple diagnostics, but propagates the overall result. 
-mlir::LogicalResult -walkPimCoreBlock(mlir::Block& block, - const StaticValueKnowledge& knowledge, - llvm::function_ref callback); - } // namespace onnx_mlir diff --git a/src/PIM/Common/Support/DebugDump.cpp b/src/PIM/Common/Support/DebugDump.cpp new file mode 100644 index 0000000..c6a3593 --- /dev/null +++ b/src/PIM/Common/Support/DebugDump.cpp @@ -0,0 +1,27 @@ +#include "llvm/Support/raw_os_ostream.h" + +#include + +#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" +#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp" + +namespace onnx_mlir { + +void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) { + std::string outputDir = getOutputDir(); + if (outputDir.empty()) + return; + + std::string dialectsDir = outputDir + "/dialects"; + createDirectory(dialectsDir); + + std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); + llvm::raw_os_ostream os(file); + mlir::OpPrintingFlags flags; + flags.elideLargeElementsAttrs(); + moduleOp.print(os, flags); + os.flush(); + file.close(); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/Support/DebugDump.hpp b/src/PIM/Common/Support/DebugDump.hpp new file mode 100644 index 0000000..9f55182 --- /dev/null +++ b/src/PIM/Common/Support/DebugDump.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "mlir/IR/BuiltinOps.h" + +#include + +namespace onnx_mlir { + +/// Emits a MLIR snapshot under the current compiler output +/// directory for pass-level debugging. 
+void dumpModule(mlir::ModuleOp moduleOp, const std::string& name); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/Support/Diagnostics.cpp b/src/PIM/Common/Support/Diagnostics.cpp new file mode 100644 index 0000000..3a5d0f3 --- /dev/null +++ b/src/PIM/Common/Support/Diagnostics.cpp @@ -0,0 +1,41 @@ +#include "llvm/ADT/STLExtras.h" + +#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" + +namespace onnx_mlir::pim { + +mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) { + return op->emitOpError() << "requires statically shaped " << valueDescription; +} + +mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op, + llvm::StringRef valueDescription, + int64_t actualRank, + llvm::ArrayRef supportedRanks) { + auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription; + if (supportedRanks.empty()) + return diag; + + diag << "; supported rank"; + if (supportedRanks.size() != 1) + diag << 's'; + diag << ' '; + + llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; }); + return diag; +} + +mlir::InFlightDiagnostic +emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) { + return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`"; +} + +mlir::LogicalResult emitFileSystemError(mlir::Location loc, + llvm::StringRef action, + llvm::StringRef path, + const std::error_code& errorCode) { + mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message(); + return mlir::failure(); +} + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Common/Support/Diagnostics.hpp b/src/PIM/Common/Support/Diagnostics.hpp new file mode 100644 index 0000000..0e9e884 --- /dev/null +++ b/src/PIM/Common/Support/Diagnostics.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include "mlir/IR/Diagnostics.h" +#include 
"mlir/IR/Operation.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" + +#include + +namespace onnx_mlir::pim { + +/// Emits a consistent diagnostic for target paths that require static shapes. +mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription); + +/// Emits a consistent diagnostic for unsupported ranks while listing the ranks +/// accepted by the current lowering/codegen path. +mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op, + llvm::StringRef valueDescription, + int64_t actualRank, + llvm::ArrayRef supportedRanks); + +/// Emits a consistent diagnostic for missing symbol/global references. +mlir::InFlightDiagnostic +emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName); + +/// Converts a filesystem error into an MLIR failure diagnostic anchored at +/// the relevant IR location. +mlir::LogicalResult +emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode); + +template +mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr& value) { + return mlir::success(succeeded(value)); +} + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Common/Support/FileSystemUtils.cpp b/src/PIM/Common/Support/FileSystemUtils.cpp new file mode 100644 index 0000000..84717ea --- /dev/null +++ b/src/PIM/Common/Support/FileSystemUtils.cpp @@ -0,0 +1,24 @@ +#include + +#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp" +#include "src/Compiler/CompilerOptions.hpp" + +namespace onnx_mlir { + +std::string getOutputDir() { + if (outputBaseName.empty() || outputBaseName == "-") + return {}; + + size_t lastSlash = outputBaseName.find_last_of('/'); + if (lastSlash == std::string::npos) + return "."; + return outputBaseName.substr(0, lastSlash); +} + +void createDirectory(const std::string& directory) { + 
std::error_code errorCode; + std::filesystem::create_directories(directory, errorCode); + assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/Support/FileSystemUtils.hpp b/src/PIM/Common/Support/FileSystemUtils.hpp new file mode 100644 index 0000000..34386d0 --- /dev/null +++ b/src/PIM/Common/Support/FileSystemUtils.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace onnx_mlir { + +/// Returns the directory that should hold PIM artifacts/debug dumps for the +/// current compiler invocation. +std::string getOutputDir(); + +void createDirectory(const std::string& directory); + +} // namespace onnx_mlir