This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
}
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute<NumInputs>(
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphComputeBatch(
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
} // namespace onnx_mlir