Faster PIM host constant folding
All checks were successful
Validate Operations / validate-operations (push) Successful in 17m33s
This commit is contained in:
@@ -1,10 +1,45 @@
|
|||||||
#include "Common.hpp"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"

#include <mutex>
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct DenseSubviewKey {
|
||||||
|
DenseElementsAttr source;
|
||||||
|
SmallVector<int64_t> offsets;
|
||||||
|
SmallVector<int64_t> resultShape;
|
||||||
|
|
||||||
|
bool operator==(const DenseSubviewKey& other) const {
|
||||||
|
return source == other.source && offsets == other.offsets && resultShape == other.resultShape;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DenseSubviewKeyInfo {
|
||||||
|
static inline DenseSubviewKey getEmptyKey() {
|
||||||
|
return {DenseElementsAttr(), {DenseMapInfo<int64_t>::getEmptyKey()}, {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline DenseSubviewKey getTombstoneKey() {
|
||||||
|
return {DenseElementsAttr(), {DenseMapInfo<int64_t>::getTombstoneKey()}, {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned getHashValue(const DenseSubviewKey& key) {
|
||||||
|
return static_cast<unsigned>(
|
||||||
|
llvm::hash_combine(key.source, llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()),
|
||||||
|
llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end())));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const DenseSubviewKey& lhs, const DenseSubviewKey& rhs) { return lhs == rhs; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
Value stripMemRefCasts(Value value) {
|
Value stripMemRefCasts(Value value) {
|
||||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||||
value = castOp.getSource();
|
value = castOp.getSource();
|
||||||
@@ -35,6 +70,16 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
|||||||
DenseElementsAttr denseAttr,
|
DenseElementsAttr denseAttr,
|
||||||
StringRef nameStem,
|
StringRef nameStem,
|
||||||
IntegerAttr alignment) {
|
IntegerAttr alignment) {
|
||||||
|
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
|
||||||
|
if (!globalOp.getConstant() || globalOp.getType() != globalType || globalOp.getAlignmentAttr() != alignment
|
||||||
|
|| !globalOp.getInitialValue())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto existingDenseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
if (existingDenseAttr == denseAttr)
|
||||||
|
return globalOp;
|
||||||
|
}
|
||||||
|
|
||||||
auto globalName = nameStem.str();
|
auto globalName = nameStem.str();
|
||||||
unsigned suffix = 0;
|
unsigned suffix = 0;
|
||||||
while (moduleOp.lookupSymbol(globalName))
|
while (moduleOp.lookupSymbol(globalName))
|
||||||
@@ -53,6 +98,43 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
|||||||
alignment);
|
alignment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FailureOr<DenseElementsAttr> foldDenseSubview(DenseElementsAttr denseAttr,
|
||||||
|
ArrayRef<int64_t> staticOffsets,
|
||||||
|
ArrayRef<int64_t> resultShape) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() != static_cast<int64_t>(staticOffsets.size())
|
||||||
|
|| sourceType.getRank() != static_cast<int64_t>(resultShape.size()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
static DenseMap<DenseSubviewKey, DenseElementsAttr, DenseSubviewKeyInfo> cache;
|
||||||
|
DenseSubviewKey key {denseAttr, SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
|
||||||
|
SmallVector<int64_t>(resultShape.begin(), resultShape.end())};
|
||||||
|
if (auto cached = cache.find(key); cached != cache.end())
|
||||||
|
return cached->second;
|
||||||
|
|
||||||
|
auto resultTensorType = RankedTensorType::get(resultShape, sourceType.getElementType());
|
||||||
|
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||||
|
if (numResultElements < 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||||
|
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> resultValues(numResultElements);
|
||||||
|
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||||
|
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||||
|
SmallVector<int64_t> sourceIndices;
|
||||||
|
sourceIndices.reserve(resultIndices.size());
|
||||||
|
for (auto [offset, index] : llvm::zip_equal(staticOffsets, resultIndices))
|
||||||
|
sourceIndices.push_back(offset + index);
|
||||||
|
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||||
|
cache.try_emplace(std::move(key), foldedAttr);
|
||||||
|
return foldedAttr;
|
||||||
|
}
|
||||||
|
|
||||||
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||||
value = stripMemRefCasts(value);
|
value = stripMemRefCasts(value);
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
|||||||
llvm::StringRef nameStem,
|
llvm::StringRef nameStem,
|
||||||
mlir::IntegerAttr alignment = {});
|
mlir::IntegerAttr alignment = {});
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::DenseElementsAttr> foldDenseSubview(mlir::DenseElementsAttr denseAttr,
|
||||||
|
llvm::ArrayRef<int64_t> staticOffsets,
|
||||||
|
llvm::ArrayRef<int64_t> resultShape);
|
||||||
|
|
||||||
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
||||||
|
|
||||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||||
|
|||||||
@@ -419,32 +419,16 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
|
|
||||||
DenseElementsAttr foldedAttr;
|
DenseElementsAttr foldedAttr;
|
||||||
if (succeeded(srcSubview)) {
|
if (succeeded(srcSubview)) {
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
return failure();
|
return failure();
|
||||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||||
if (failed(staticOffsets))
|
if (failed(staticOffsets))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
||||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
if (failed(maybeFoldedAttr))
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
return failure();
|
||||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
foldedAttr = *maybeFoldedAttr;
|
||||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
|
||||||
SmallVector<Attribute> resultValues(numResultElements);
|
|
||||||
|
|
||||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
|
||||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
|
||||||
SmallVector<int64_t> sourceIndices;
|
|
||||||
sourceIndices.reserve(resultIndices.size());
|
|
||||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
|
||||||
sourceIndices.push_back(off + idx);
|
|
||||||
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
|
|
||||||
resultValues[i] = sourceValues[srcLinear];
|
|
||||||
}
|
|
||||||
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||||
|
|||||||
@@ -249,32 +249,15 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
|||||||
if (failed(staticOffsets))
|
if (failed(staticOffsets))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||||
auto resultMemRefType =
|
auto resultMemRefType =
|
||||||
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
|
||||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
if (failed(foldedAttr))
|
||||||
|
return failure();
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
|
||||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
|
||||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
|
||||||
SmallVector<Attribute> resultValues(numResultElements);
|
|
||||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
|
||||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
|
||||||
SmallVector<int64_t> sourceIndices;
|
|
||||||
sourceIndices.reserve(resultIndices.size());
|
|
||||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
|
||||||
sourceIndices.push_back(off + idx);
|
|
||||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
|
||||||
}
|
|
||||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
|
||||||
|
|
||||||
auto newGlobal =
|
auto newGlobal =
|
||||||
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, *foldedAttr, "pim_folded_subview");
|
||||||
markWeightAlways(newGlobal);
|
markWeightAlways(newGlobal);
|
||||||
|
|
||||||
rewriter.setInsertionPoint(subviewOp);
|
rewriter.setInsertionPoint(subviewOp);
|
||||||
|
|||||||
Reference in New Issue
Block a user