compact spatial IR through different new operations and dedicated syntax

fast spatial node merging with batch operations
This commit is contained in:
NiccoloN
2026-05-03 14:14:14 +02:00
parent 15e8edb9c4
commit b605585b1f
34 changed files with 4419 additions and 1445 deletions

View File

@@ -9,7 +9,6 @@ def SpatialDialect : Dialect {
let name = "spat";
let summary = "Dialect designed for deep learning computation in a spatial architecture";
let cppNamespace = "::onnx_mlir::spatial";
let useDefaultTypePrinterParser = 1;
}
class SpatOp<string mnemonic, list<Trait> traits = []> :
@@ -19,15 +18,6 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
def SpatTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<SpatialDialect, name, traits> {
let mnemonic = typeMnemonic;
}
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
let summary = "Virtual channel type";
}
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
@@ -48,10 +38,27 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
let assemblyFormat = [{
`[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
}];
def SpatComputeBatch : SpatOp<"compute_batch",
[SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compressed batch of independent equivalent compute lanes";
let arguments = (ins
I32Attr:$laneCount,
Variadic<SpatTensor>:$weights,
Variadic<SpatTensor>:$inputs
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
@@ -61,51 +68,66 @@ def SpatYieldOp : SpatOp<"yield", [Terminator]> {
Variadic<SpatTensor>:$outputs
);
let assemblyFormat = [{
$outputs attr-dict `:` type($outputs)
}];
let hasCustomAssemblyFormat = 1;
}
def SpatExtractRowsOp : SpatOp<"extract_rows", []> {
let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";
let arguments = (ins
SpatTensor:$input
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatConcatOp : SpatOp<"concat", []> {
let summary = "Concatenate tensors with compact Spatial operand syntax";
let arguments = (ins
I64Attr:$axis,
Variadic<SpatTensor>:$inputs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
def SpatChannelNewOp : SpatOp<"channel_new", []> {
let summary = "Create a new virtual channel";
let results = (outs
SpatChannelType:$channel
);
let builders = [
OpBuilder<(ins ), [{
$_state.addTypes(SpatChannelType());
}]>
];
let assemblyFormat = [{
attr-dict
}];
}
def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a channel";
let summary = "Send a tensor through a logical channel";
let arguments = (ins
SpatChannelType:$channel,
I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId,
SpatTensor:$input
);
let assemblyFormat = [{
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
$input attr-dict `:` type($input)
}];
}
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a channel";
let summary = "Receive a tensor from a logical channel";
let arguments = (ins
SpatChannelType:$channel
I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId
);
let results = (outs
@@ -113,37 +135,70 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
);
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
attr-dict `:` type($output)
}];
}
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
let summary = "Broadcast a tensor through a shared channel buffer";
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
let summary = "Send multiple tensors through logical channels";
let arguments = (ins
SpatChannelType:$channel,
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<SpatTensor>:$inputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
let summary = "Receive multiple tensors from logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
let summary = "Send per-lane tensors through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
SpatTensor:$input
);
let assemblyFormat = [{
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
let summary = "Receive a tensor from a shared channel buffer";
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
let summary = "Receive a per-lane tensor through logical channels in a batch body";
let arguments = (ins
SpatChannelType:$channel
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//