This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#include "AttributeUtils.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
@@ -11,8 +12,6 @@
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -209,8 +208,7 @@ auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
block->getArgument(0),
|
||||
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
||||
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
||||
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())
|
||||
};
|
||||
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
|
||||
|
||||
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
@@ -252,8 +250,8 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
|
||||
if (isCompileTimeComputable(input))
|
||||
return buildFn(input);
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
||||
mlir::Value result = buildFn(computeInput);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
#include "IndexingUtils.hpp"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
||||
#include "llvm/ADT/APInt.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -85,7 +84,8 @@ Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value val
|
||||
|
||||
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
|
||||
if (auto attr = dyn_cast<Attribute>(value))
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||
return getOrCreateIndexConstant(
|
||||
rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||
return cast<Value>(value);
|
||||
}
|
||||
|
||||
|
||||
@@ -26,8 +26,10 @@ mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
|
||||
mlir::AffineExpr expr,
|
||||
mlir::ValueRange operands);
|
||||
|
||||
mlir::Value
|
||||
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier);
|
||||
mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter,
|
||||
mlir::Operation* anchorOp,
|
||||
mlir::Value value,
|
||||
int64_t multiplier);
|
||||
|
||||
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
@@ -53,7 +53,9 @@ bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); }
|
||||
bool hasStaticPositiveShape(RankedTensorType type) {
|
||||
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
@@ -98,11 +100,8 @@ FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<Arr
|
||||
return permutation;
|
||||
}
|
||||
|
||||
Value transposeMaybeInCompute(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
Value transposeMaybeInCompute(
|
||||
Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||
};
|
||||
@@ -127,7 +126,8 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
|
||||
|
||||
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
|
||||
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()
|
||||
|| sourceType.getRank() != resultType.getRank())
|
||||
return false;
|
||||
|
||||
for (OpFoldResult stride : strides) {
|
||||
@@ -290,7 +290,8 @@ Value materializeContiguousTensorSlice(Value source,
|
||||
}
|
||||
|
||||
Value lower = zeroIndices[dim];
|
||||
Value upper = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
|
||||
Value upper =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
|
||||
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
@@ -316,7 +317,8 @@ Value extractAxisSlice(
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
||||
@@ -115,12 +115,8 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
int64_t axis,
|
||||
int64_t offset,
|
||||
int64_t size);
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
|
||||
Reference in New Issue
Block a user