normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
||||
#include "llvm/ADT/APInt.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -47,46 +42,4 @@ FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> ax
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
||||
}
|
||||
|
||||
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
|
||||
if (multiplier == 0)
|
||||
return getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
if (multiplier == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
|
||||
}
|
||||
|
||||
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||
if (divisor == 1)
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||
}
|
||||
|
||||
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||
if (divisor == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
|
||||
}
|
||||
|
||||
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
|
||||
if (auto attr = dyn_cast<Attribute>(value))
|
||||
return getOrCreateIndexConstant(
|
||||
rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||
return cast<Value>(value);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/FoldInterfaces.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
@@ -21,21 +17,4 @@ int64_t normalizeIndex(int64_t index, int64_t dimSize);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||
|
||||
mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::AffineExpr expr,
|
||||
mlir::ValueRange operands);
|
||||
|
||||
mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter,
|
||||
mlir::Operation* anchorOp,
|
||||
mlir::Value value,
|
||||
int64_t multiplier);
|
||||
|
||||
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
||||
|
||||
mlir::Value
|
||||
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
||||
|
||||
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::OpFoldResult value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user