faster pim host constant folding
All checks were successful
Validate Operations / validate-operations (push) Successful in 17m33s
This commit is contained in:
@@ -1,10 +1,45 @@
|
||||
#include "Common.hpp"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"

#include <mutex>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct DenseSubviewKey {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> resultShape;
|
||||
|
||||
bool operator==(const DenseSubviewKey& other) const {
|
||||
return source == other.source && offsets == other.offsets && resultShape == other.resultShape;
|
||||
}
|
||||
};
|
||||
|
||||
struct DenseSubviewKeyInfo {
|
||||
static inline DenseSubviewKey getEmptyKey() {
|
||||
return {DenseElementsAttr(), {DenseMapInfo<int64_t>::getEmptyKey()}, {}};
|
||||
}
|
||||
|
||||
static inline DenseSubviewKey getTombstoneKey() {
|
||||
return {DenseElementsAttr(), {DenseMapInfo<int64_t>::getTombstoneKey()}, {}};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const DenseSubviewKey& key) {
|
||||
return static_cast<unsigned>(
|
||||
llvm::hash_combine(key.source, llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()),
|
||||
llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end())));
|
||||
}
|
||||
|
||||
static bool isEqual(const DenseSubviewKey& lhs, const DenseSubviewKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
@@ -35,6 +70,16 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
DenseElementsAttr denseAttr,
|
||||
StringRef nameStem,
|
||||
IntegerAttr alignment) {
|
||||
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
|
||||
if (!globalOp.getConstant() || globalOp.getType() != globalType || globalOp.getAlignmentAttr() != alignment
|
||||
|| !globalOp.getInitialValue())
|
||||
continue;
|
||||
|
||||
auto existingDenseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (existingDenseAttr == denseAttr)
|
||||
return globalOp;
|
||||
}
|
||||
|
||||
auto globalName = nameStem.str();
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
@@ -53,6 +98,43 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
alignment);
|
||||
}
|
||||
|
||||
/// Constant-folds a unit-stride static subview of a dense constant: extracts
/// the `resultShape`-sized sub-block of `denseAttr` that starts at
/// `staticOffsets`.
///
/// Returns failure() when the source is not a static ranked tensor, when the
/// ranks of the offsets/shape don't match the source rank, or when the
/// requested slice falls outside the source extents.
FailureOr<DenseElementsAttr> foldDenseSubview(DenseElementsAttr denseAttr,
    ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> resultShape) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() ||
      sourceType.getRank() != static_cast<int64_t>(staticOffsets.size()) ||
      sourceType.getRank() != static_cast<int64_t>(resultShape.size()))
    return failure();

  // Reject negative or out-of-bounds offsets/extents up front; without this
  // check the source-index computation below can read past the end of
  // `sourceValues` (undefined behavior). This also subsumes the old
  // `numResultElements < 0` guard.
  for (auto [offset, size, dim] :
       llvm::zip_equal(staticOffsets, resultShape, sourceType.getShape()))
    if (offset < 0 || size < 0 || offset + size > dim)
      return failure();

  auto resultTensorType =
      RankedTensorType::get(resultShape, sourceType.getElementType());

  // Splat constants fold without materializing individual elements.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(
        resultTensorType, denseAttr.getSplatValue<Attribute>());

  // Memoize folds: the same weight constant is often sliced repeatedly.
  // The cache is process-global while MLIR may run function passes in
  // parallel, so all accesses are serialized with a mutex.
  // NOTE(review): entries pin attributes for the lifetime of the process;
  // this assumes a single long-lived MLIRContext — TODO confirm.
  static std::mutex cacheMutex;
  static DenseMap<DenseSubviewKey, DenseElementsAttr, DenseSubviewKeyInfo> cache;
  DenseSubviewKey key{denseAttr,
      SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
      SmallVector<int64_t>(resultShape.begin(), resultShape.end())};
  {
    std::lock_guard<std::mutex> lock(cacheMutex);
    if (auto cached = cache.find(key); cached != cache.end())
      return cached->second;
  }

  const int64_t numResultElements = resultTensorType.getNumElements();
  auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(numResultElements);
  for (int64_t i = 0; i < numResultElements; ++i) {
    // Map the linear result index to multi-dimensional indices, shift each
    // dimension by its static offset, then re-linearize into the source.
    auto resultIndices =
        delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
    SmallVector<int64_t> sourceIndices;
    sourceIndices.reserve(resultIndices.size());
    for (auto [offset, index] : llvm::zip_equal(staticOffsets, resultIndices))
      sourceIndices.push_back(offset + index);
    resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
  }

  auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
  {
    // Racing computations may insert the same key twice; try_emplace keeps
    // the first value, which is identical anyway (deterministic fold).
    std::lock_guard<std::mutex> lock(cacheMutex);
    cache.try_emplace(std::move(key), foldedAttr);
  }
  return foldedAttr;
}
|
||||
|
||||
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||
value = stripMemRefCasts(value);
|
||||
|
||||
|
||||
@@ -30,6 +30,10 @@ mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||
llvm::StringRef nameStem,
|
||||
mlir::IntegerAttr alignment = {});
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr> foldDenseSubview(mlir::DenseElementsAttr denseAttr,
|
||||
llvm::ArrayRef<int64_t> staticOffsets,
|
||||
llvm::ArrayRef<int64_t> resultShape);
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
@@ -419,32 +419,16 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
|
||||
resultValues[i] = sourceValues[srcLinear];
|
||||
}
|
||||
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
||||
if (failed(maybeFoldedAttr))
|
||||
return failure();
|
||||
foldedAttr = *maybeFoldedAttr;
|
||||
}
|
||||
else {
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
|
||||
@@ -249,32 +249,15 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||
auto resultMemRefType =
|
||||
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
auto newGlobal =
|
||||
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, *foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
|
||||
Reference in New Issue
Block a user