Dynamic gemm/conv

This commit is contained in:
ilgeco
2026-05-28 18:00:14 +02:00
parent cbf7b235f1
commit 1ab489fe0a
17 changed files with 704 additions and 69 deletions
+22
View File
@@ -1,8 +1,10 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -79,4 +81,24 @@ Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFo
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
}
Value createAffineApplyOrFoldedConstant(
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* anchorOp) {
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
APInt constantValue;
if (!matchPattern(operand, m_ConstantInt(&constantValue)))
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
operandConstants.push_back(rewriter.getIndexAttr(constantValue.getSExtValue()));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateHostIndexConstant(anchorOp, constantResult.getInt(), rewriter);
}
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
}
} // namespace onnx_mlir
+7
View File
@@ -1,5 +1,6 @@
#pragma once
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
@@ -29,4 +30,10 @@ mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value,
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineMap map,
mlir::ValueRange operands,
mlir::Operation* anchorOp);
} // namespace onnx_mlir