#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"

// The Spatial dialect. Ops print with the `spat.` prefix and the generated C++
// is placed in the ::onnx_mlir::spatial namespace.
def SpatialDialect : Dialect {
  let name = "spat";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
  let cppNamespace = "::onnx_mlir::spatial";
}

// Common base class for every Spatial op: binds the op to SpatialDialect under
// the given mnemonic, with an optional list of traits.
class SpatOp<string mnemonic, list<Trait> traits = []> :
    Op<SpatialDialect, mnemonic, traits>;

// Type constraint used for Spatial operands and results: any memref or any
// ranked tensor, exposed to C++ as ::mlir::ShapedType.
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//

// Single-block region op. Its operands form two variadic groups (weights first,
// then inputs); AttrSizedOperandSegments is required to record the size of
// each group.
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compute region with attached constant weights";

  let arguments = (ins
    Variadic<SpatTensor>:$weights,
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let regions = (region SizedRegion<1>:$body);

  let hasVerifier = 1;
  let hasFolder = 1;
  let hasCustomAssemblyFormat = 1;
}

// Batched counterpart of SpatCompute; `laneCount` is the number of lanes the
// single body stands for. Same weights/inputs operand grouping as SpatCompute.
def SpatComputeBatch : SpatOp<"compute_batch", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compressed batch of independent equivalent compute lanes";

  let arguments = (ins
    I32Attr:$laneCount,
    Variadic<SpatTensor>:$weights,
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let regions = (region SizedRegion<1>:$body);

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// Region terminator.
// NOTE(review): no HasParent/ParentOneOf trait is declared, so ODS does not
// restrict which ops may contain it — confirm the C++ side checks this if needed.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";

  let arguments = (ins
    Variadic<SpatTensor>:$outputs
  );

  let hasCustomAssemblyFormat = 1;
}

def SpatExtractRowsOp : SpatOp<"extract_rows", []> {
  let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// Concatenates `inputs` along `axis`.
def SpatConcatOp : SpatOp<"concat", []> {
  let summary = "Concatenate tensors with compact Spatial operand syntax";

  let arguments = (ins
    I64Attr:$axis,
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//

// Point-to-point send. The channel id and the source/target core ids are
// printed inside the attribute dictionary.
def SpatChannelSendOp : SpatOp<"channel_send", []> {
  let summary = "Send a tensor through a logical channel";

  let arguments = (ins
    I64Attr:$channelId,
    I32Attr:$sourceCoreId,
    I32Attr:$targetCoreId,
    SpatTensor:$input
  );

  let assemblyFormat = [{
    $input attr-dict `:` type($input)
  }];
}

// Point-to-point receive; counterpart of SpatChannelSendOp.
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  let summary = "Receive a tensor from a logical channel";

  let arguments = (ins
    I64Attr:$channelId,
    I32Attr:$sourceCoreId,
    I32Attr:$targetCoreId
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    attr-dict `:` type($output)
  }];
}

// Multi-tensor send; channel/core ids are given as parallel arrays.
// NOTE(review): presumably one entry per input — confirm in the C++ verifier.
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
  let summary = "Send multiple tensors through logical channels";

  let arguments = (ins
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds,
    Variadic<SpatTensor>:$inputs
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// Multi-tensor receive; counterpart of SpatChannelSendManyOp.
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
  let summary = "Receive multiple tensors from logical channels";

  let arguments = (ins
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// Batch-body send: a single tensor operand with arrays of channel/core ids.
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
  let summary = "Send per-lane tensors through logical channels in a batch body";

  let arguments = (ins
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds,
    SpatTensor:$input
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// Batch-body receive; counterpart of SpatChannelSendBatchOp.
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
  let summary = "Receive a per-lane tensor through logical channels in a batch body";

  let arguments = (ins
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//

// `weightIndex` presumably selects one of the enclosing compute op's weights —
// confirm against the C++ verifier.
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";

  let arguments = (ins
    I32Attr:$weightIndex,
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Matrix-vector variant of SpatWeightedVMMOp; same operand layout.
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
  let summary = "Matrix-vector multiplication within a weighted compute operation";

  let arguments = (ins
    I32Attr:$weightIndex,
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

def SpatVAddOp : SpatOp<"vadd", []> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

// NOTE(review): unlike SpatVAddOp this op sets no hasVerifier, so the
// "rhs must match lhs or be 1x1" rule in its summary is not enforced by an op
// verifier — confirm whether that is intentional.
def SpatVMulOp : SpatOp<"vmul", []> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

def SpatSumOp : SpatOp<"sum", []> {
  let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

def SpatVAvgOp : SpatOp<"vavg", []> {
  let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

def SpatSigmoidOp : SpatOp<"sigmoid", []> {
  let summary = "Element-wise sigmoid activation";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

def SpatSoftmaxOp : SpatOp<"softmax", []> {
  let summary = "Softmax over the full input tensor slice";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

def SpatReluOp : SpatOp<"relu", []> {
  let summary = "Element-wise ReLU activation";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

def SpatVMaxOp : SpatOp<"vmax", []> {
  let summary = "Element-wise max between two tensors";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

#endif // SPATIAL_DIALECT_H