This commit is contained in:
@@ -7,9 +7,13 @@
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -49,6 +53,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::
|
||||
template <typename Fn>
|
||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||
|
||||
struct SpatComputeBatchBodyArgs {
|
||||
mlir::Value lane;
|
||||
mlir::ValueRange weights;
|
||||
mlir::ValueRange inputs;
|
||||
mlir::ValueRange outputs;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename RewriterT>
|
||||
@@ -159,6 +170,96 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
int64_t laneCount,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
|
||||
auto batchOp = spatial::SpatComputeBatch::create(
|
||||
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs);
|
||||
|
||||
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||
for (mlir::Value weight : weights) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
for (mlir::Value input : inputs) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(input.getLoc());
|
||||
}
|
||||
for (mlir::Type resultType : resultTypes) {
|
||||
blockArgTypes.push_back(resultType);
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* block =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
detail::SpatComputeBatchBodyArgs args {
|
||||
block->getArgument(0),
|
||||
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
||||
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
||||
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())
|
||||
};
|
||||
|
||||
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(args);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(args);
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
rewriter.eraseOp(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
}
|
||||
}
|
||||
|
||||
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> sizes,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> strides) {
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
template <typename BodyFn>
|
||||
mlir::Value materializeOrComputeUnary(mlir::Value input,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
BodyFn&& build) {
|
||||
auto&& buildFn = build;
|
||||
if (isCompileTimeComputable(input))
|
||||
return buildFn(input);
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
||||
mlir::Value result = buildFn(computeInput);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user