constant fold linalg.map (generated from tensor.pad for padding)

refactor pim helpers in PimCommon
This commit is contained in:
NiccoloN
2026-03-20 20:51:20 +01:00
parent dbe646ac0d
commit 6933804003
14 changed files with 751 additions and 263 deletions

2
.gitignore vendored
View File

@@ -1,2 +1,4 @@
.idea
.claude
AGENTS.md
build

View File

@@ -20,7 +20,7 @@ add_onnx_mlir_library(OMPIMAccel
Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp
Pass/PimFoldHostConstantsPass.cpp
Pass/PimConstantFoldingPass.cpp
Pass/PimHostVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -12,7 +12,15 @@ using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
@@ -21,7 +29,11 @@ void createDirectory(const std::string& directory) {
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string dialectsDir = getOutputDir() + "/dialects";
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
@@ -143,4 +155,85 @@ FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, Rewr
}
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
} // namespace onnx_mlir

View File

@@ -6,6 +6,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Compiler/CompilerOptions.hpp"
@@ -32,4 +34,18 @@ mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::m
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
} // namespace onnx_mlir

View File

@@ -15,7 +15,6 @@
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -382,12 +381,9 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
size_t totalElements = srcType.getNumElements();
// Read permutation and compute its inverse
// Read permutation. Destination dim i corresponds to source dim perm[i].
SmallVector<int64_t> perm =
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
SmallVector<int64_t> permInv(rank);
for (size_t i = 0; i < rank; i++)
permInv[perm[i]] = i;
// Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank);
@@ -412,10 +408,10 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
remaining %= srcStrides[d];
}
// Compute flat destination index: dstIdx[d] = srcIdx[permInv[d]]
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
size_t dstFlat = 0;
for (size_t d = 0; d < rank; d++)
dstFlat += srcIdx[permInv[d]] * dstStrides[d];
dstFlat += srcIdx[perm[d]] * dstStrides[d];
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
}

View File

@@ -46,8 +46,8 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimFoldHostConstantsPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimHostVerificationPass());
pm.addPass(createMessagePass("Pim host verified"));
pm.addPass(createEmitPimJsonPass());

View File

@@ -54,7 +54,7 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue;
}
Operation* getEarliestUserWithinBlock(Value value) {
Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers();
assert(!users.empty());
@@ -67,23 +67,24 @@ Operation* getEarliestUserWithinBlock(Value value) {
return earliestUser;
}
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) {
auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> {
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
auto operandsAndUses =
map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
return {operand, std::distance(operand.use_begin(), operand.use_end())};
});
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1);
Value result = operation->getResult(0);
mlir::Value result = operation->getResult(0);
auto resultType = result.getType();
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<Value> operands = getOpOperandsSortedByUses(operation);
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands =
make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; });
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end())

View File

@@ -2,6 +2,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
@@ -39,71 +40,13 @@ mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operat
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
static bool isMemoryContiguous(const mlir::ArrayRef<int64_t> srcShape,
const mlir::ArrayRef<int64_t> offsets,
const mlir::ArrayRef<int64_t> sizes,
const mlir::ArrayRef<int64_t> strides) {
// Check that all strides are 1
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
// Check offsets from right to left:
// The first offset_n at position n different from 0:
// - limits all sizes to the left to 1
// - limits size_n to dimension_n - offset_n
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
// Check sizes from right to left:
// The first size_n at position n different from shape_n limits all sizes to the left to 1
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
}
inline bool isAConcatOp(mlir::Operation* op) {
return isa<mlir::tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op);
return llvm::isa<mlir::tensor::ConcatOp>(op) || llvm::isa<spatial::SpatImgConcatOp>(op);
}
} // namespace onnx_mlir

View File

@@ -129,7 +129,7 @@ void SpatialToPimPass::runOnOperation() {
}
// Dump to file for debug
dumpModule(moduleOp, "pim");
dumpModule(moduleOp, "pim0");
}
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {

View File

@@ -89,7 +89,7 @@ void PimBufferizationPass::runOnOperation() {
annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug
dumpModule(moduleOp, "pim_buf");
dumpModule(moduleOp, "pim1_buff");
}
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {

View File

@@ -0,0 +1,618 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment = {}) {
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(moduleBuilder,
loc,
globalName,
visibility,
globalType,
denseAttr,
/*constant=*/true,
alignment);
}
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return failure();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
return denseAttr;
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides(rank, 1);
SmallVector<int64_t> transposedStrides(rank, 1);
for (int64_t dim = rank - 2; dim >= 0; --dim) {
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
}
SmallVector<int64_t> originalIndices(rank);
SmallVector<int64_t> transposedIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
for (int64_t dim = 0; dim < rank; ++dim)
transposedIndices[dim] = originalIndices[perms[dim]];
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
struct ConstantSubviewCopy {
DenseElementsAttr source;
SmallVector<int64_t> offsets;
SmallVector<int64_t> strides;
Operation* copyOp = nullptr;
};
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
if (!mapOp.getInputs().empty())
return failure();
auto yieldOp = dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
Attribute attr;
if (!matchPattern(yieldOp.getValues().front(), m_Constant(&attr)))
return failure();
return attr;
}
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
if (!coreOp)
return failure();
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
if (!initType || !initType.hasStaticShape())
return failure();
auto fillValue = getConstantMapYield(mapOp);
if (failed(fillValue))
return failure();
auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
auto moduleOp = mapOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
if (elementByteWidth == 0)
return failure();
size_t totalBytes = initType.getNumElements() * elementByteWidth;
rewriter.setInsertionPoint(mapOp);
pim::PimMemCopyHostToDevOp::create(rewriter,
mapOp.getLoc(),
initType,
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
rewriter.eraseOp(mapOp);
return success();
}
};
struct StaticSubviewInfo {
Value source;
SmallVector<int64_t> sourceShape;
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
};
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
info.offsets.push_back(*staticOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
static int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (copyOp.getSize() != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = copyOp.getSrcOffset()
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = copyOp.getDstOffset()
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : copyOp.getDst(),
splitSrc ? srcSubview->source : copyOp.getSrc(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
}
rewriter.replaceOp(copyOp, copyOp.getDst());
return success();
}
};
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numElements = resultTensorType.getNumElements();
if (numElements < 0)
return failure();
Attribute fillValue;
SmallVector<ConstantSubviewCopy> copies;
llvm::SmallPtrSet<Operation*, 8> visitedAliases;
SmallVector<Value> pendingAliases;
pendingAliases.push_back(allocOp.getResult());
while (!pendingAliases.empty()) {
Value alias = pendingAliases.pop_back_val();
for (Operation* user : alias.getUsers()) {
if (!visitedAliases.insert(user).second)
continue;
if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
if (mapOp.getInit() != alias)
return failure();
auto maybeFillValue = getConstantMapYield(mapOp);
if (failed(maybeFillValue))
return failure();
if (fillValue && fillValue != *maybeFillValue)
return failure();
fillValue = *maybeFillValue;
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
SmallVector<int64_t> offsets;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
offsets.push_back(*staticOffset);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
strides.push_back(*staticStride);
}
for (Operation* subviewUser : subviewOp->getUsers()) {
if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
if (copyOp.getTarget() != subviewOp.getResult())
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
if (failed(denseAttr))
return failure();
copies.push_back({*denseAttr, offsets, strides, copyOp});
continue;
}
return failure();
}
continue;
}
if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
continue;
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
pendingAliases.push_back(castOp.getResult());
continue;
}
return failure();
}
}
if (!fillValue)
return failure();
SmallVector<Attribute> resultValues(numElements, fillValue);
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
});
for (const ConstantSubviewCopy& copy : copies) {
auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
|| sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
SmallVector<int64_t> sourceIndices =
delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
SmallVector<int64_t> resultIndices;
resultIndices.reserve(sourceIndices.size());
for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
resultIndices.push_back(offset + sourceIndex * stride);
int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
resultValues[resultLinearIndex] = value;
}
}
return DenseElementsAttr::get(resultTensorType, resultValues);
}
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal)
return failure();
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
if (!denseAttr)
return failure();
SmallVector<int64_t> perms;
perms.reserve(transposeOp.getPerms().size());
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
perms.push_back(attr.getInt());
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
if (failed(transposedAttr))
return failure();
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
if (!llvm::equal(transposedShape, resultType.getShape()))
return failure();
MemRefType globalType = resultType;
auto newGlobal = createFoldedGlobal(moduleOp,
transposeOp.getLoc(),
globalType,
*transposedAttr,
sourceGlobal.getName().str() + "__folded_transpose",
sourceGlobal.getAlignmentAttr());
rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
bool isAlwaysWeight =
!transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
if (isAlwaysWeight) {
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
}
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
return success();
}
};
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
auto moduleOp = allocOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
if (failed(foldedAttr))
return failure();
auto allocType = cast<MemRefType>(allocOp.getType());
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");
rewriter.setInsertionPoint(allocOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
SmallVector<Operation*> opsToErase;
SmallVector<memref::CastOp> castsToReplace;
bool allLiveUsersAreCoreOps = true;
for (Operation* user : llvm::make_early_inc_range(allocOp->getUsers())) {
if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
opsToErase.push_back(user);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
castsToReplace.push_back(castOp);
continue;
}
if (!isa<pim::PimCoreOp>(user))
return failure();
}
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
})) {
allLiveUsersAreCoreOps = false;
}
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
})) {
return failure();
}
if (allLiveUsersAreCoreOps) {
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
}
llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
for (memref::CastOp castOp : castsToReplace)
preservedUsers.insert(castOp);
rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);
for (memref::CastOp castOp : castsToReplace) {
rewriter.setInsertionPoint(castOp);
Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
rewriter.replaceOp(castOp, replacementCast);
if (allLiveUsersAreCoreOps)
markWeightAlways(replacementCast.getDefiningOp());
}
for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
rewriter.eraseOp(subviewUser);
if (op->use_empty())
rewriter.eraseOp(op);
}
if (allocOp.use_empty())
rewriter.eraseOp(allocOp);
return success();
}
};
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet owningPatterns(context);
for (auto* dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
owningPatterns
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
context);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
signalPassFailure();
return;
}
dumpModule(getOperation(), "pim2_folded");
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -1,181 +0,0 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides(rank, 1);
SmallVector<int64_t> transposedStrides(rank, 1);
for (int64_t dim = rank - 2; dim >= 0; --dim) {
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
}
SmallVector<int64_t> originalIndices(rank);
SmallVector<int64_t> transposedIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
for (int64_t dim = 0; dim < rank; ++dim)
transposedIndices[dim] = originalIndices[perms[dim]];
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal)
return failure();
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
if (!denseAttr)
return failure();
SmallVector<int64_t> perms;
perms.reserve(transposeOp.getPerms().size());
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
perms.push_back(attr.getInt());
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
if (failed(transposedAttr))
return failure();
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
if (!llvm::equal(transposedShape, resultType.getShape()))
return failure();
MemRefType globalType = resultType;
auto globalName = sourceGlobal.getName().str() + "__folded_transpose";
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = sourceGlobal.getName().str() + "__folded_transpose_" + std::to_string(++suffix);
auto visibility = rewriter.getStringAttr("private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
auto newGlobal = memref::GlobalOp::create(moduleBuilder,
transposeOp.getLoc(),
globalName,
visibility,
globalType,
*transposedAttr,
/*constant=*/true,
sourceGlobal.getAlignmentAttr());
rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
bool isAlwaysWeight =
!transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
if (isAlwaysWeight) {
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
}
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
return success();
}
};
struct PimFoldHostConstantsPass : PassWrapper<PimFoldHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimFoldHostConstantsPass)
StringRef getArgument() const override { return "fold-pim-host-constants-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet owningPatterns(context);
for (auto* dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
owningPatterns.add<FoldConstantTransposePattern>(context);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config)))
signalPassFailure();
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
std::unique_ptr<Pass> createPimFoldHostConstantsPass() { return std::make_unique<PimFoldHostConstantsPass>(); }
} // namespace onnx_mlir

View File

@@ -15,7 +15,7 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimFoldHostConstantsPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();

View File

@@ -73,7 +73,7 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass);
registerPass(createPimFoldHostConstantsPass);
registerPass(createPimConstantFoldingPass);
registerPass(createPimHostVerificationPass);
registerPass(createEmitPimJsonPass);
}