This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
@@ -9,7 +9,7 @@ using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value sumTensors(ArrayRef<Value> tensors, PatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
}
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute<NumInputs>(
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphComputeBatch(
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -83,7 +83,7 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size());
@@ -129,7 +129,7 @@ SmallVector<Value> sliceTensor(
}
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(vectorToSlice);
assert("Not a vector" && isVectorShape(shape));
size_t axis = shape[0] != 1 ? 0 : 1;
@@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr
}
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) {
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
@@ -89,18 +89,18 @@ llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewr
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);