add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation add spatial compute helper minor refactors
This commit is contained in:
@@ -1,15 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
@@ -58,14 +56,6 @@ inline auto getFilterCount(const ShapedType& shapedType) {
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
enum class MapOperations {
|
||||
None,
|
||||
ONNXSoftmaxOp,
|
||||
ONNXReluOp,
|
||||
ONNXLeakyReluOp,
|
||||
ONNXExpOp
|
||||
};
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
@@ -114,6 +104,38 @@ inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <size_t NumInputs, typename BodyFn>
|
||||
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
size_t axis,
|
||||
int64_t sliceSize,
|
||||
|
||||
Reference in New Issue
Block a user