Faster PIM host constant folding
All checks were successful
Validate Operations / validate-operations (push) Successful in 17m33s

This commit is contained in:
NiccoloN
2026-04-14 19:58:26 +02:00
parent 95ae93e07d
commit ae93d1c563
4 changed files with 94 additions and 41 deletions

View File

@@ -1,10 +1,45 @@
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include <mutex>
using namespace mlir;
namespace onnx_mlir {
namespace {
// Cache key identifying one folded subview: the source constant it was
// sliced from, the static per-dimension offsets of the slice, and the shape
// of the extracted result (unit strides are assumed by the callers).
struct DenseSubviewKey {
  DenseElementsAttr source;
  SmallVector<int64_t> offsets;
  SmallVector<int64_t> resultShape;

  // Two keys are equal only when all three components match.
  bool operator==(const DenseSubviewKey &other) const {
    if (source != other.source)
      return false;
    if (offsets != other.offsets)
      return false;
    return resultShape == other.resultShape;
  }
};
// DenseMapInfo implementation for DenseSubviewKey. The two sentinel keys use
// a null source attribute together with a single sentinel offset element, so
// they can never compare equal to a real key (real keys always carry a
// non-null source attribute) and they differ from each other.
struct DenseSubviewKeyInfo {
  static inline DenseSubviewKey getEmptyKey() {
    DenseSubviewKey key;
    key.offsets.push_back(DenseMapInfo<int64_t>::getEmptyKey());
    return key;
  }
  static inline DenseSubviewKey getTombstoneKey() {
    DenseSubviewKey key;
    key.offsets.push_back(DenseMapInfo<int64_t>::getTombstoneKey());
    return key;
  }
  static unsigned getHashValue(const DenseSubviewKey &key) {
    // Combine the source attribute with range hashes of both vectors; the
    // structure mirrors operator== so equal keys hash identically.
    auto offsetsHash =
        llvm::hash_combine_range(key.offsets.begin(), key.offsets.end());
    auto shapeHash = llvm::hash_combine_range(
        key.resultShape.begin(), key.resultShape.end());
    return static_cast<unsigned>(
        llvm::hash_combine(key.source, offsetsHash, shapeHash));
  }
  static bool isEqual(const DenseSubviewKey &lhs, const DenseSubviewKey &rhs) {
    return lhs == rhs;
  }
};
} // namespace
Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
@@ -35,6 +70,16 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment) {
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
if (!globalOp.getConstant() || globalOp.getType() != globalType || globalOp.getAlignmentAttr() != alignment
|| !globalOp.getInitialValue())
continue;
auto existingDenseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (existingDenseAttr == denseAttr)
return globalOp;
}
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
@@ -53,6 +98,43 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
alignment);
}
/// Folds a static, unit-stride subview of a dense constant by gathering the
/// selected elements into a new DenseElementsAttr of shape `resultShape`.
/// Returns failure when the source is not a statically-shaped ranked tensor,
/// when the ranks of `staticOffsets`/`resultShape` disagree with the source,
/// or when the requested window does not fit inside the source shape.
/// Results are memoized process-wide so repeated folds of the same slice are
/// cheap (the point of this commit).
FailureOr<DenseElementsAttr> foldDenseSubview(DenseElementsAttr denseAttr,
    ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> resultShape) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape()
      || sourceType.getRank() != static_cast<int64_t>(staticOffsets.size())
      || sourceType.getRank() != static_cast<int64_t>(resultShape.size()))
    return failure();
  // Reject out-of-bounds windows up front: the gather loop below indexes the
  // flattened source values directly and would otherwise read past the end.
  for (auto [offset, size, dim] :
      llvm::zip_equal(staticOffsets, resultShape, sourceType.getShape())) {
    if (offset < 0 || size < 0 || offset + size > dim)
      return failure();
  }
  auto resultTensorType =
      RankedTensorType::get(resultShape, sourceType.getElementType());
  // Fast path: any subview of a splat constant is the same splat reshaped.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(
        resultTensorType, denseAttr.getSplatValue<Attribute>());
  // Process-wide memoization. Guarded by a mutex because MLIR's pass manager
  // may run rewrite patterns from multiple threads.
  // NOTE(review): this static cache pins the cached attributes (and thereby
  // their MLIRContext uniquer storage) for the lifetime of the process —
  // confirm that is acceptable for the embedding application.
  static std::mutex cacheMutex;
  static DenseMap<DenseSubviewKey, DenseElementsAttr, DenseSubviewKeyInfo>
      cache;
  DenseSubviewKey key {denseAttr,
      SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
      SmallVector<int64_t>(resultShape.begin(), resultShape.end())};
  {
    std::lock_guard<std::mutex> lock(cacheMutex);
    if (auto cached = cache.find(key); cached != cache.end())
      return cached->second;
  }
  const int64_t numResultElements = resultTensorType.getNumElements();
  if (numResultElements < 0)
    return failure();
  auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
  // Materialize the source values once; the gather below needs random access.
  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(numResultElements);
  for (int64_t i = 0; i < numResultElements; ++i) {
    // Map the linear result index to multi-dim indices, shift by the static
    // offsets, and re-linearize into the source.
    auto resultIndices =
        delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
    SmallVector<int64_t> sourceIndices;
    sourceIndices.reserve(resultIndices.size());
    for (auto [offset, index] : llvm::zip_equal(staticOffsets, resultIndices))
      sourceIndices.push_back(offset + index);
    resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
  }
  auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
  std::lock_guard<std::mutex> lock(cacheMutex);
  cache.try_emplace(std::move(key), foldedAttr);
  return foldedAttr;
}
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);

View File

@@ -30,6 +30,10 @@ mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
llvm::StringRef nameStem,
mlir::IntegerAttr alignment = {});
llvm::FailureOr<mlir::DenseElementsAttr> foldDenseSubview(mlir::DenseElementsAttr denseAttr,
llvm::ArrayRef<int64_t> staticOffsets,
llvm::ArrayRef<int64_t> resultShape);
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);

View File

@@ -419,32 +419,16 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
if (failed(staticOffsets))
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
sourceIndices.push_back(off + idx);
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
resultValues[i] = sourceValues[srcLinear];
}
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
if (failed(maybeFoldedAttr))
return failure();
foldedAttr = *maybeFoldedAttr;
}
else {
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());

View File

@@ -249,32 +249,15 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
if (failed(staticOffsets))
return failure();
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
auto resultMemRefType =
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
sourceIndices.push_back(off + idx);
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
}
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
if (failed(foldedAttr))
return failure();
auto newGlobal =
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, *foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp);