879 lines
34 KiB
C++
879 lines
34 KiB
C++
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <memory>
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
/// Peels every `memref.cast` off the front of `value` and returns the first
/// value in the chain that is not itself produced by a `memref.cast`.
static Value stripMemRefCasts(Value value) {
  Value current = value;
  memref::CastOp cast = current.getDefiningOp<memref::CastOp>();
  while (cast) {
    current = cast.getSource();
    cast = current.getDefiningOp<memref::CastOp>();
  }
  return current;
}
|
|
|
|
/// Walks backwards through `memref.cast`, `memref.collapse_shape` and
/// `memref.expand_shape` producers (in any order and any number) and returns
/// the first value whose defining op is none of those, or which is a block
/// argument.
static Value stripMemRefViewOps(Value value) {
  for (;;) {
    Operation* producer = value.getDefiningOp();
    if (!producer)
      return value;

    if (auto castOp = dyn_cast<memref::CastOp>(producer))
      value = castOp.getSource();
    else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(producer))
      value = collapseOp.getSrc();
    else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(producer))
      value = expandOp.getSrc();
    else
      return value;
  }
}
|
|
|
|
/// Inserts a new private, constant `memref.global` at the start of `moduleOp`'s
/// body holding `denseAttr` with type `globalType`.
///
/// The symbol name is `nameStem` if free, otherwise the first free name of the
/// form `nameStem_1`, `nameStem_2`, ... . `alignment` is forwarded unchanged
/// (a null attribute means "no explicit alignment").
///
/// NOTE(review): the op is created with a plain OpBuilder, so the pattern
/// rewriter driving the caller is not notified of this insertion.
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
                                           Location loc,
                                           MemRefType globalType,
                                           DenseElementsAttr denseAttr,
                                           StringRef nameStem,
                                           IntegerAttr alignment = {}) {
  std::string symbolName = nameStem.str();
  for (unsigned counter = 1; moduleOp.lookupSymbol(symbolName); ++counter)
    symbolName = (nameStem + "_" + std::to_string(counter)).str();

  OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
  StringAttr privateVisibility = StringAttr::get(moduleOp.getContext(), "private");
  return memref::GlobalOp::create(builder,
                                  loc,
                                  symbolName,
                                  privateVisibility,
                                  globalType,
                                  denseAttr,
                                  /*constant=*/true,
                                  alignment);
}
|
|
|
|
/// Resolves `value` (looking through `memref.cast`s) to a `memref.get_global`
/// of a constant `memref.global` in `moduleOp` whose initial value is a
/// DenseElementsAttr, and returns that attribute. Fails otherwise.
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
  auto getGlobal = stripMemRefCasts(value).getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobal)
    return failure();

  auto global = lookupGlobalForGetGlobal(moduleOp, getGlobal);
  if (global && global.getConstant()) {
    if (auto initialValue = global.getInitialValue()) {
      if (auto dense = dyn_cast<DenseElementsAttr>(*initialValue))
        return dense;
    }
  }
  return failure();
}
|
|
|
|
/// Returns a copy of the ranked `denseAttr` with its dimensions permuted so
/// that output dimension `d` is input dimension `perms[d]`.
///
/// Fails if the attribute is not a ranked tensor, or if `perms` is not a
/// permutation of `[0, rank)`. Splat inputs produce a splat of the permuted
/// type.
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto inputType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!inputType)
    return failure();

  ArrayRef<int64_t> inputShape = inputType.getShape();
  const int64_t rank = inputType.getRank();
  if (rank != static_cast<int64_t>(perms.size()))
    return failure();

  // Validate the permutation while building the permuted shape.
  llvm::SmallBitVector used(rank);
  SmallVector<int64_t> outputShape(rank);
  for (int64_t axis = 0; axis < rank; ++axis) {
    const int64_t sourceAxis = perms[axis];
    if (sourceAxis < 0 || sourceAxis >= rank || used.test(sourceAxis))
      return failure();
    used.set(sourceAxis);
    outputShape[axis] = inputShape[sourceAxis];
  }

  auto outputType = RankedTensorType::get(outputShape, inputType.getElementType());
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(outputType, denseAttr.getSplatValue<Attribute>());

  // Row-major strides of the input and output layouts.
  SmallVector<int64_t> inputStrides(rank, 1);
  SmallVector<int64_t> outputStrides(rank, 1);
  for (int64_t axis = rank - 1; axis > 0; --axis) {
    inputStrides[axis - 1] = inputStrides[axis] * inputShape[axis];
    outputStrides[axis - 1] = outputStrides[axis] * outputShape[axis];
  }

  SmallVector<Attribute> inputValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> outputValues(inputValues.size());
  SmallVector<int64_t> coordinate(rank);
  for (int64_t flat = 0, count = inputValues.size(); flat < count; ++flat) {
    // Decompose the flat input position into a multi-dimensional coordinate.
    int64_t rest = flat;
    for (int64_t axis = 0; axis < rank; ++axis) {
      coordinate[axis] = rest / inputStrides[axis];
      rest %= inputStrides[axis];
    }

    // Output coordinate along axis d is the input coordinate along perms[d].
    int64_t target = 0;
    for (int64_t axis = 0; axis < rank; ++axis)
      target += coordinate[perms[axis]] * outputStrides[axis];

    outputValues[target] = inputValues[flat];
  }

  return DenseElementsAttr::get(outputType, outputValues);
}
|
|
|
|
/// One `memref.copy` from a constant global into a static `memref.subview` of
/// an allocation that foldConstantAlloc is trying to turn into a constant.
/// Field order matters: instances are brace-initialized positionally.
struct ConstantSubviewCopy {
  // Dense value of the copy's source global.
  DenseElementsAttr source;
  // Constant per-dimension offsets of the target subview.
  SmallVector<int64_t> offsets;
  // Constant per-dimension strides of the target subview.
  SmallVector<int64_t> strides;
  // The `memref.copy` itself; used to apply copies in program order.
  Operation* copyOp = nullptr;
};
|
|
|
|
/// Recognizes a "fill" `linalg.map`: one with no inputs whose body yields a
/// single value that matches a constant. Returns that constant attribute, or
/// failure if the map does not have this shape.
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
  if (!mapOp.getInputs().empty())
    return failure();

  Operation* terminator = mapOp.getMapper().front().getTerminator();
  auto yieldOp = dyn_cast<linalg::YieldOp>(terminator);

  Attribute constantValue;
  const bool yieldsOneConstant = yieldOp && yieldOp.getNumOperands() == 1
                              && matchPattern(yieldOp.getValues().front(), m_Constant(&constantValue));
  if (!yieldsOneConstant)
    return failure();
  return constantValue;
}
|
|
|
|
/// Rewrites a constant-fill `linalg.map` nested in a `pim.core` into a host to
/// device copy from a new private splat global.
///
/// The splat global is created at module scope, a `memref.get_global` for it is
/// placed right before the enclosing `pim.core`, and the map is replaced by a
/// `pim` host-to-device copy of the whole init buffer.
///
/// All match conditions are checked before any IR is created, so a failed
/// match never leaves a stray global or get_global behind.
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
    auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
    if (!coreOp)
      return failure();

    auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
    if (!initType || !initType.hasStaticShape())
      return failure();

    // The copy size is expressed in bytes. getElementTypeBitWidth() asserts on
    // non int/float element types (e.g. index), and sub-byte types have no
    // whole-byte width.
    if (!initType.getElementType().isIntOrFloat())
      return failure();
    const int64_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
    if (elementByteWidth == 0)
      return failure();
    const int64_t totalBytes = initType.getNumElements() * elementByteWidth;
    // The size is stored in an i32 attribute; refuse instead of truncating.
    if (totalBytes != static_cast<int32_t>(totalBytes))
      return failure();

    auto fillValue = getConstantMapYield(mapOp);
    if (failed(fillValue))
      return failure();

    auto moduleOp = mapOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // Every check has passed: from here on the IR is modified and we succeed.
    auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
    DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
    auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(coreOp);
    auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());

    rewriter.setInsertionPoint(mapOp);
    pim::PimMemCopyHostToDevOp::create(rewriter,
                                       mapOp.getLoc(),
                                       initType,
                                       mapOp.getInit(),
                                       getGlobalOp.getResult(),
                                       rewriter.getI32IntegerAttr(0),
                                       rewriter.getI32IntegerAttr(0),
                                       rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
    rewriter.eraseOp(mapOp);
    return success();
  }
};
|
|
|
|
/// Fully static description of a `memref.subview`, as produced by
/// getStaticSubviewInfo.
struct StaticSubviewInfo {
  // The subview's source with any `memref.cast`s stripped.
  Value source;
  // Static shape of `source`.
  SmallVector<int64_t> sourceShape;
  // Constant offsets, sizes and strides of the subview, one entry per
  // dimension of the source.
  SmallVector<int64_t> offsets;
  SmallVector<int64_t> sizes;
  SmallVector<int64_t> strides;
};
|
|
|
|
/// Looks through casts and collapse/expand reshapes of `value` for a producing
/// `memref.subview`, and returns its fully static description.
///
/// Fails unless all of the following hold:
/// - the source and result shapes are static;
/// - every offset, size and stride is a constant;
/// - the (cast-stripped) source has an identity layout.
///
/// Callers linearize indices with dense row-major strides of `sourceShape`.
/// That is only correct for an identity-layout source, so strided sources
/// (e.g. a subview of a subview) are rejected.
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  auto subviewOp = stripMemRefViewOps(value).getDefiningOp<memref::SubViewOp>();
  if (!subviewOp)
    return failure();

  Value source = stripMemRefCasts(subviewOp.getSource());
  auto sourceType = dyn_cast<MemRefType>(source.getType());
  auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
  if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
    return failure();
  if (!sourceType.getLayout().isIdentity())
    return failure();

  StaticSubviewInfo info;
  info.source = source;
  info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());

  // Appends the constant value of every entry to `out`; fails on the first
  // dynamic entry.
  auto collectConstants = [](ArrayRef<OpFoldResult> values, SmallVectorImpl<int64_t>& out) -> LogicalResult {
    for (OpFoldResult entry : values) {
      auto constant = getConstantIntValue(entry);
      if (!constant)
        return failure();
      out.push_back(*constant);
    }
    return success();
  };

  if (failed(collectConstants(subviewOp.getMixedOffsets(), info.offsets))
      || failed(collectConstants(subviewOp.getMixedSizes(), info.sizes))
      || failed(collectConstants(subviewOp.getMixedStrides(), info.strides)))
    return failure();

  return info;
}
|
|
|
|
static int64_t
|
|
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
|
SmallVector<int64_t> sourceIndices;
|
|
sourceIndices.reserve(info.sourceShape.size());
|
|
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
|
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
|
sourceIndices.push_back(info.offsets.back());
|
|
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
|
}
|
|
|
|
/// Splits a `pim.memcp` inside a `pim.core` whose source and/or destination is
/// a static, unit-stride `memref.subview` that is not contiguous in memory.
///
/// The copy becomes one `pim.memcp` per innermost row of the copied shape.
/// A split side addresses the subview's underlying buffer directly by byte
/// offset. A side that is not split keeps its original operand and advances
/// by one row's bytes per slice.
///
/// All slice offsets are computed and validated (they must fit the i32 offset
/// attributes) before any op is created, so a failed match leaves the IR
/// untouched.
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    if (!copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
    auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
    const bool splitSrc = succeeded(srcSubview)
                       && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
    const bool splitDst = succeeded(dstSubview)
                       && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
    if (!splitSrc && !splitDst)
      return failure();

    auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
    auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
    if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
      return failure();
    if (sourceType.getElementType() != dstType.getElementType())
      return failure();
    // getElementTypeBitWidth() asserts on non int/float element types.
    if (!sourceType.getElementType().isIntOrFloat())
      return failure();

    // Only unit strides keep each innermost row contiguous in the source buffer.
    auto hasNonUnitStride = [](const StaticSubviewInfo& info) {
      return llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; });
    };
    if ((splitSrc && hasNonUnitStride(*srcSubview)) || (splitDst && hasNonUnitStride(*dstSubview)))
      return failure();

    ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
    if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
      return failure();

    const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
    if (elementByteWidth <= 0)
      return failure();

    // The copy must move exactly the subview's bytes.
    const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
    if (copyOp.getSize() != totalBytes)
      return failure();

    const int64_t sliceBytes = copyShape.back() * elementByteWidth;
    if (sliceBytes <= 0)
      return failure();

    SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
    auto outerStrides = computeRowMajorStrides(outerShape);
    const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);

    // Offsets are stored in i32 attributes; reject rather than truncate.
    auto fitsInI32 = [](int64_t value) { return value == static_cast<int32_t>(value); };
    SmallVector<int64_t> srcByteOffsets;
    SmallVector<int64_t> dstByteOffsets;
    srcByteOffsets.reserve(numSlices);
    dstByteOffsets.reserve(numSlices);
    for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
      SmallVector<int64_t> outerIndices =
          outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
      const int64_t srcByteOffset = copyOp.getSrcOffset()
                                  + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
                                              : linearIndex * sliceBytes);
      const int64_t dstByteOffset = copyOp.getDstOffset()
                                  + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
                                              : linearIndex * sliceBytes);
      if (!fitsInI32(srcByteOffset) || !fitsInI32(dstByteOffset))
        return failure();
      srcByteOffsets.push_back(srcByteOffset);
      dstByteOffsets.push_back(dstByteOffset);
    }

    MemRefType sliceResultType = splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType;
    Value sliceDst = splitDst ? dstSubview->source : copyOp.getDst();
    Value sliceSrc = splitSrc ? srcSubview->source : copyOp.getSrc();

    rewriter.setInsertionPoint(copyOp);
    for (size_t slice = 0; slice < srcByteOffsets.size(); ++slice) {
      pim::PimMemCopyOp::create(rewriter,
                                copyOp.getLoc(),
                                sliceResultType,
                                sliceDst,
                                sliceSrc,
                                rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffsets[slice])),
                                rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffsets[slice])),
                                rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
    }

    rewriter.replaceOp(copyOp, copyOp.getDst());
    return success();
  }
};
|
|
|
|
/// Splits a host-to-device `pim` copy whose host source and/or device
/// destination is a static, unit-stride `memref.subview` that is not
/// contiguous in memory.
///
/// The copy becomes one copy per innermost row of the copied shape. A split
/// side addresses the subview's underlying buffer directly by byte offset.
/// A side that is not split keeps its original operand and advances by one
/// row's bytes per slice.
///
/// All slice offsets are computed and validated (they must fit the i32 offset
/// attributes) before any op is created, so a failed match leaves the IR
/// untouched.
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
    auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
    auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
    const bool splitSrc = succeeded(srcSubview)
                       && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
    const bool splitDst = succeeded(dstSubview)
                       && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
    if (!splitSrc && !splitDst)
      return failure();

    auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
    auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
    if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
      return failure();
    if (sourceType.getElementType() != dstType.getElementType())
      return failure();
    // getElementTypeBitWidth() asserts on non int/float element types.
    if (!sourceType.getElementType().isIntOrFloat())
      return failure();

    // Only unit strides keep each innermost row contiguous in the source buffer.
    auto hasNonUnitStride = [](const StaticSubviewInfo& info) {
      return llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; });
    };
    if ((splitSrc && hasNonUnitStride(*srcSubview)) || (splitDst && hasNonUnitStride(*dstSubview)))
      return failure();

    ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
    if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
      return failure();

    const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
    if (elementByteWidth <= 0)
      return failure();

    // The copy must move exactly the subview's bytes.
    const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
    if (copyOp.getSize() != totalBytes)
      return failure();

    const int64_t sliceBytes = copyShape.back() * elementByteWidth;
    if (sliceBytes <= 0)
      return failure();

    SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
    auto outerStrides = computeRowMajorStrides(outerShape);
    const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);

    // Offsets are stored in i32 attributes; reject rather than truncate.
    auto fitsInI32 = [](int64_t value) { return value == static_cast<int32_t>(value); };
    SmallVector<int64_t> srcByteOffsets;
    SmallVector<int64_t> dstByteOffsets;
    srcByteOffsets.reserve(numSlices);
    dstByteOffsets.reserve(numSlices);
    for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
      SmallVector<int64_t> outerIndices =
          outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
      const int64_t srcByteOffset = copyOp.getHostSrcOffset()
                                  + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
                                              : linearIndex * sliceBytes);
      const int64_t dstByteOffset = copyOp.getDeviceDstOffset()
                                  + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
                                              : linearIndex * sliceBytes);
      if (!fitsInI32(srcByteOffset) || !fitsInI32(dstByteOffset))
        return failure();
      srcByteOffsets.push_back(srcByteOffset);
      dstByteOffsets.push_back(dstByteOffset);
    }

    MemRefType sliceResultType = splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType;
    Value sliceDst = splitDst ? dstSubview->source : copyOp.getDeviceDst();
    Value sliceSrc = splitSrc ? srcSubview->source : copyOp.getHostSrc();

    rewriter.setInsertionPoint(copyOp);
    for (size_t slice = 0; slice < srcByteOffsets.size(); ++slice) {
      pim::PimMemCopyHostToDevOp::create(rewriter,
                                         copyOp.getLoc(),
                                         sliceResultType,
                                         sliceDst,
                                         sliceSrc,
                                         rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffsets[slice])),
                                         rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffsets[slice])),
                                         rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
    }

    rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
    return success();
  }
};
|
|
|
|
/// Computes the constant contents of a statically shaped `memref.alloc`, if
/// they can be determined at compile time.
///
/// The accepted writers are:
/// - input-less `linalg.map` ops that yield one constant (fills); all fills
///   must yield the same constant;
/// - `memref.copy` ops from a constant global into a static
///   `memref.subview` of the allocation (or of a `memref.cast` of it).
///
/// Other accepted users are `pim.core`, `memref.dealloc`, and `memref.cast`.
/// The users of a `memref.cast` are examined the same way.
///
/// Folding models "fill everything, then apply copies in program order". It
/// is therefore rejected when the writes cannot be ordered that way: they must
/// all be in one block, and every fill must precede every copy. Ordering ops
/// from different blocks with isBeforeInBlock is not allowed.
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
  auto allocType = dyn_cast<MemRefType>(allocOp.getType());
  if (!allocType || !allocType.hasStaticShape())
    return failure();

  auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
  const int64_t numElements = resultTensorType.getNumElements();
  if (numElements < 0)
    return failure();

  Attribute fillValue;
  SmallVector<Operation*> fillOps;
  SmallVector<ConstantSubviewCopy> copies;
  llvm::SmallPtrSet<Operation*, 8> visitedAliases;
  SmallVector<Value> pendingAliases;
  pendingAliases.push_back(allocOp.getResult());

  while (!pendingAliases.empty()) {
    Value alias = pendingAliases.pop_back_val();
    for (Operation* user : alias.getUsers()) {
      if (!visitedAliases.insert(user).second)
        continue;

      if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
        if (mapOp.getInit() != alias)
          return failure();
        auto maybeFillValue = getConstantMapYield(mapOp);
        if (failed(maybeFillValue))
          return failure();
        if (fillValue && fillValue != *maybeFillValue)
          return failure();
        fillValue = *maybeFillValue;
        fillOps.push_back(mapOp.getOperation());
        continue;
      }

      if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
        SmallVector<int64_t> offsets;
        SmallVector<int64_t> strides;
        offsets.reserve(subviewOp.getMixedOffsets().size());
        strides.reserve(subviewOp.getMixedStrides().size());
        for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
          auto staticOffset = getConstantIntValue(offset);
          if (!staticOffset)
            return failure();
          offsets.push_back(*staticOffset);
        }
        for (OpFoldResult stride : subviewOp.getMixedStrides()) {
          auto staticStride = getConstantIntValue(stride);
          if (!staticStride)
            return failure();
          strides.push_back(*staticStride);
        }

        // The only accepted use of a subview is as the target of a copy from a
        // constant global.
        for (Operation* subviewUser : subviewOp->getUsers()) {
          auto copyOp = dyn_cast<memref::CopyOp>(subviewUser);
          if (!copyOp || copyOp.getTarget() != subviewOp.getResult())
            return failure();

          auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
          if (failed(denseAttr))
            return failure();
          copies.push_back({*denseAttr, offsets, strides, copyOp});
        }
        continue;
      }

      if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
        continue;

      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        pendingAliases.push_back(castOp.getResult());
        continue;
      }

      return failure();
    }
  }

  if (!fillValue)
    return failure();

  // Every write must be in one block (so program order is defined), and all
  // fills must come before all copies. Otherwise a later fill would overwrite
  // copied data, which the "fill then copies" model below cannot express.
  Block* writeBlock = fillOps.front()->getBlock();
  auto inWriteBlock = [writeBlock](Operation* op) { return op->getBlock() == writeBlock; };
  if (!llvm::all_of(fillOps, inWriteBlock))
    return failure();
  for (const ConstantSubviewCopy& copy : copies) {
    if (!inWriteBlock(copy.copyOp))
      return failure();
    if (llvm::any_of(fillOps, [&](Operation* fillOp) { return !fillOp->isBeforeInBlock(copy.copyOp); }))
      return failure();
  }

  SmallVector<Attribute> resultValues(numElements, fillValue);
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());

  // Apply copies in program order so later copies win on overlap.
  llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
    return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
  });

  for (const ConstantSubviewCopy& copy : copies) {
    auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
        || sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
      return failure();

    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
    for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
      SmallVector<int64_t> sourceIndices =
          delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
      SmallVector<int64_t> resultIndices;
      resultIndices.reserve(sourceIndices.size());
      for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
        resultIndices.push_back(offset + sourceIndex * stride);

      int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
      resultValues[resultLinearIndex] = value;
    }
  }

  return DenseElementsAttr::get(resultTensorType, resultValues);
}
|
|
|
|
/// Folds `pim.transpose` of a `memref.get_global` of a constant dense global.
///
/// A new private global holding the pre-permuted data is created. Its name is
/// derived from the source global with a "__folded_transpose" suffix, and it
/// reuses the source's alignment. The transpose result is replaced by a
/// `memref.get_global` of it.
///
/// When every user of the transpose is a `pim.core`, both the new global and
/// the get_global are tagged with markWeightAlways.
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
    auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();

    auto dataGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
    auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
    if (!dataGetGlobal || !moduleOp)
      return failure();

    auto dataGlobal = lookupGlobalForGetGlobal(moduleOp, dataGetGlobal);
    if (!dataGlobal || !dataGlobal.getConstant() || !dataGlobal.getInitialValue())
      return failure();
    auto dataAttr = dyn_cast<DenseElementsAttr>(*dataGlobal.getInitialValue());
    if (!dataAttr)
      return failure();

    SmallVector<int64_t> permutation = llvm::to_vector(
        llvm::map_range(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](IntegerAttr permAttr) { return permAttr.getInt(); }));

    FailureOr<DenseElementsAttr> transposed = transposeDenseElements(dataAttr, permutation);
    if (failed(transposed))
      return failure();
    if (!llvm::equal(cast<RankedTensorType>(transposed->getType()).getShape(), resultType.getShape()))
      return failure();

    auto foldedGlobal = createFoldedGlobal(moduleOp,
                                           transposeOp.getLoc(),
                                           resultType,
                                           *transposed,
                                           dataGlobal.getName().str() + "__folded_transpose",
                                           dataGlobal.getAlignmentAttr());

    rewriter.setInsertionPoint(transposeOp);
    auto foldedGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, foldedGlobal.getName());

    const bool onlyFeedsCores =
        !transposeOp->use_empty()
        && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
    if (onlyFeedsCores) {
      markWeightAlways(foldedGlobal);
      markWeightAlways(foldedGetGlobal);
    }

    rewriter.replaceOp(transposeOp, foldedGetGlobal.getResult());
    return success();
  }
};
|
|
|
|
/// Replaces a `memref.alloc` whose contents foldConstantAlloc can compute with
/// a `memref.get_global` of a new private constant global
/// ("pim_folded_constant").
///
/// Handling of the alloc's users:
/// - `linalg.map` fills, `memref.dealloc`s, and `memref.subview`s (together
///   with their `memref.copy` users) are erased;
/// - `pim.core` users are redirected to the get_global;
/// - each `memref.cast` is re-created on top of the get_global.
///
/// A cast keeps its users, so it is accepted only if it (transitively through
/// further casts) feeds nothing but `pim.core` ops. Redirecting a cast that
/// feeds a fill, a subview copy, or a dealloc would make those ops write to,
/// or free, the constant global.
///
/// All checks run before any IR is created, so a failed match leaves the IR
/// untouched.
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
    auto moduleOp = allocOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
    if (failed(foldedAttr))
      return failure();

    // Classify direct users.
    SmallVector<Operation*> opsToErase;
    SmallVector<memref::CastOp> castsToReplace;
    for (Operation* user : allocOp->getUsers()) {
      if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
        opsToErase.push_back(user);
        continue;
      }
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        castsToReplace.push_back(castOp);
        continue;
      }
      if (!isa<pim::PimCoreOp>(user))
        return failure();
    }

    // A re-created cast keeps its users, so they must all be readers.
    auto castOnlyReachesCores = [](memref::CastOp rootCast) {
      SmallVector<Operation*> worklist;
      worklist.push_back(rootCast.getOperation());
      while (!worklist.empty()) {
        Operation* current = worklist.pop_back_val();
        for (Operation* user : current->getUsers()) {
          if (isa<memref::CastOp>(user))
            worklist.push_back(user);
          else if (!isa<pim::PimCoreOp>(user))
            return false;
        }
      }
      return true;
    };
    if (!llvm::all_of(castsToReplace, castOnlyReachesCores))
      return failure();

    const bool allLiveUsersAreCoreOps = llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
      return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
    });

    // Every check has passed: from here on the IR is modified and we succeed.
    auto allocType = cast<MemRefType>(allocOp.getType());
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());

    if (allLiveUsersAreCoreOps) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }

    // Redirect pim.core users; ops that are erased or re-created below keep
    // using the alloc for now.
    llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
    for (memref::CastOp castOp : castsToReplace)
      preservedUsers.insert(castOp);
    rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);

    for (memref::CastOp castOp : castsToReplace) {
      rewriter.setInsertionPoint(castOp);
      Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
      rewriter.replaceOp(castOp, replacementCast);
      if (allLiveUsersAreCoreOps)
        markWeightAlways(replacementCast.getDefiningOp());
    }

    // Writers into the alloc are now folded into the global; drop them
    // (subview copies first, then the subview itself).
    for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
        for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
          rewriter.eraseOp(subviewUser);
      if (op->use_empty())
        rewriter.eraseOp(op);
    }

    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
|
|
|
|
/// Folds a host-level `pim.memcp` into a statically shaped `memref.alloc`.
///
/// The copy source must resolve to a constant dense global, either directly or
/// through a static unit-stride `memref.subview` (possibly reshaped). Both
/// byte offsets must be zero. The alloc is then replaced by a `memref.get_global`
/// of a new private global ("pim_folded_memcp") holding the copied values, and
/// the copy is erased.
///
/// The alloc's other users may only be `memref.dealloc`, `pim.core`, or
/// `memref.subview`. If any subview user exists, the new global is not tagged
/// with markWeightAlways.
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only match top-level memcp (not inside pim.core).
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    // dst must be an alloc with static shape.
    auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
    if (!allocOp)
      return failure();
    auto allocType = dyn_cast<MemRefType>(allocOp.getType());
    if (!allocType || !allocType.hasStaticShape())
      return failure();

    // The copy must start at the beginning of both buffers.
    if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
      return failure();

    // Resolve the source, through an optional subview, to a get_global.
    auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
    Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());

    auto moduleOp = copyOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
    if (failed(denseAttr))
      return failure();

    // The source values are re-labelled with the alloc's type, which asserts
    // in DenseElementsAttr::get when the element types differ (e.g. a
    // bit-casting copy).
    if (denseAttr->getElementType() != allocType.getElementType())
      return failure();

    auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());

    DenseElementsAttr foldedAttr;
    if (succeeded(srcSubview)) {
      // Extract the sub-tensor from the source constant.
      auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
      if (!sourceType || !sourceType.hasStaticShape())
        return failure();
      if (sourceType.getShape() != ArrayRef<int64_t>(srcSubview->sourceShape))
        return failure();
      if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; }))
        return failure();

      // Walk the subview in its own (source-rank) index space. Collapse/expand
      // and rank reduction preserve row-major order, so the i-th subview
      // element is the i-th alloc element even when the ranks differ.
      ArrayRef<int64_t> subviewShape = srcSubview->sizes;
      const int64_t numResultElements = resultTensorType.getNumElements();
      if (static_cast<int64_t>(getNumElements(subviewShape)) != numResultElements)
        return failure();

      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      auto subviewStrides = computeRowMajorStrides(subviewShape);
      SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
      SmallVector<Attribute> resultValues;
      resultValues.reserve(numResultElements);

      for (int64_t i = 0; i < numResultElements; ++i) {
        auto subviewIndices = delinearizeIndex(i, subviewShape, subviewStrides);
        SmallVector<int64_t> sourceIndices;
        sourceIndices.reserve(subviewIndices.size());
        for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, subviewIndices))
          sourceIndices.push_back(off + idx);
        resultValues.push_back(sourceValues[linearizeIndex(sourceIndices, sourceStrides)]);
      }
      foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
    }
    else {
      // Direct copy from a global: reuse its dense attribute.
      if (resultTensorType != denseAttr->getType())
        return failure();
      foldedAttr = *denseAttr;
    }

    // Verify that the alloc's remaining users are supported ops.
    // NOTE(review): subview users are not inspected further; a write through
    // such a subview would now target the constant global.
    bool allLiveUsersAreCores = true;
    for (Operation* user : allocOp->getUsers()) {
      if (user == copyOp || isa<memref::DeallocOp, pim::PimCoreOp>(user))
        continue;
      if (isa<memref::SubViewOp>(user)) {
        allLiveUsersAreCores = false;
        continue;
      }
      return failure();
    }

    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
    if (allLiveUsersAreCores)
      markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCores)
      markWeightAlways(newGetGlobal);

    rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
|
|
|
|
/// Folds a static, unit-stride `memref.subview` of a constant dense global,
/// whose users are all `pim.core` ops, into a `memref.get_global` of a new
/// private global ("pim_folded_subview") holding just the selected elements.
///
/// The replacement memref has the subview result's shape, including any rank
/// reduction, but an identity layout. Building it from the unreduced sizes
/// would change the rank of the replaced value for rank-reducing subviews.
/// Both the new global and the get_global are tagged with markWeightAlways.
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
    // Only handle subviews whose users are all pim.core ops.
    if (subviewOp.use_empty())
      return failure();
    if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
      return failure();

    // Source must resolve to a constant get_global.
    auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
    if (failed(denseAttr))
      return failure();

    // Static subview info.
    auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
    if (failed(subviewInfo))
      return failure();
    if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
      return failure();

    auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getShape() != ArrayRef<int64_t>(subviewInfo->sourceShape))
      return failure();

    auto subviewType = cast<MemRefType>(subviewOp.getType());
    Type elementType = subviewType.getElementType();
    if (sourceType.getElementType() != elementType)
      return failure();

    // Contiguous replacement type with the subview result's (possibly
    // rank-reduced) shape.
    auto resultMemRefType = MemRefType::get(subviewType.getShape(), elementType);
    auto resultTensorType = RankedTensorType::get(subviewType.getShape(), elementType);
    const int64_t numResultElements = resultTensorType.getNumElements();

    // Elements are gathered in the subview's source-rank index space. Rank
    // reduction only drops unit dimensions, so the count and row-major order
    // match the result.
    ArrayRef<int64_t> extractShape = subviewInfo->sizes;
    if (static_cast<int64_t>(getNumElements(extractShape)) != numResultElements)
      return failure();

    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    auto extractStrides = computeRowMajorStrides(extractShape);
    SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
    SmallVector<Attribute> resultValues;
    resultValues.reserve(numResultElements);
    for (int64_t i = 0; i < numResultElements; ++i) {
      auto extractIndices = delinearizeIndex(i, extractShape, extractStrides);
      SmallVector<int64_t> sourceIndices;
      sourceIndices.reserve(extractIndices.size());
      for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, extractIndices))
        sourceIndices.push_back(off + idx);
      resultValues.push_back(sourceValues[linearizeIndex(sourceIndices, sourceStrides)]);
    }
    auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);

    auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
    markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(subviewOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
    markWeightAlways(newGetGlobal);

    rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
    return success();
  }
};
|
|
|
|
/// Module pass that greedily applies all loaded canonicalization patterns
/// together with the PIM constant-folding and copy-splitting patterns defined
/// in this file. On success it dumps the module under the tag "pim2_folded".
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)

  StringRef getArgument() const override { return "pim-constant-folding-pass"; }
  StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }

  /// Builds and freezes the pattern set once, so every run reuses it.
  LogicalResult initialize(MLIRContext* context) override {
    RewritePatternSet patternSet(context);

    // Canonicalizations of every loaded dialect and registered operation.
    for (Dialect* loadedDialect : context->getLoadedDialects())
      loadedDialect->getCanonicalizationPatterns(patternSet);
    for (RegisteredOperationName opName : context->getRegisteredOperations())
      opName.getCanonicalizationPatterns(patternSet, context);

    // PIM-specific folding / splitting patterns.
    patternSet.add<FoldConstantTransposePattern,
                   FoldConstantAllocPattern,
                   FoldConstantCoreMapPattern,
                   RewriteCoreSubviewCopyPattern,
                   RewriteHostSubviewLoadPattern,
                   FoldConstantMemCpPattern,
                   FoldConstantCoreSubviewPattern>(context);

    patterns = std::make_shared<FrozenRewritePatternSet>(std::move(patternSet));
    return success();
  }

  void runOnOperation() override {
    GreedyRewriteConfig config;
    config.enableFolding();

    if (succeeded(applyPatternsGreedily(getOperation(), *patterns, config)))
      dumpModule(getOperation(), "pim2_folded");
    else
      signalPassFailure();
  }

  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
|
|
|
|
} // namespace
|
|
|
|
/// Creates a fresh instance of the PIM constant-folding module pass.
std::unique_ptr<Pass> createPimConstantFoldingPass() {
  std::unique_ptr<Pass> pass = std::make_unique<PimConstantFoldingPass>();
  return pass;
}
|
|
|
|
} // namespace onnx_mlir
|