//===----------------------------------------------------------------------===//
// Spatial ("spat") dialect: TableGen/ODS definitions for the dialect record,
// the common op base class, the shared tensor type constraint, and the
// execution and communication ops.
//===----------------------------------------------------------------------===//

#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H

include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Dialect record. Ops are printed with the `spat.` prefix (from `name`) and the
// generated C++ classes are placed in ::onnx_mlir::spatial (from `cppNamespace`).
def SpatialDialect : Dialect {
  let name = "spat";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
  let cppNamespace = "::onnx_mlir::spatial";
}

// Common base class for every op defined in this file.
// NOTE(review): in the copy under review, the template parameter list of
// SpatOp and the template arguments of its `Op` parent are absent, so this line
// does not parse as written. The same applies below to:
//   - the bare `Variadic` operand/result types,
//   - the argument-less `DeclareOpInterfaceMethods`,
//   - the `std::optional>` / `::mlir::FailureOr>` return types in the
//     extraClassDeclaration blocks.
// Confirm against the repository copy before editing these declarations.
class SpatOp traits = []> : Op;

// Type constraint used for the tensor-like operands and results in this
// dialect. It accepts any memref or any ranked tensor; the generated C++
// accessors return ::mlir::ShapedType.
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//

// Single-block region op that computes `outputs` from `inputs` using attached
// `weights`. Both operand groups are variadic, so AttrSizedOperandSegments
// records where each group ends.
// hasVerifier / hasFolder / hasCustomAssemblyFormat mean the verifier, folder,
// parser and printer are hand-written in C++ (not visible here).
// The extra methods do two things:
//   - map a weight/input index to its body block argument (std::nullopt
//     presumably signals an invalid index — confirm in C++);
//   - insert a new weight, input, or output into the op.
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {
  let summary = "Compute region with attached constant weights";
  let arguments = (ins
    Variadic:$weights,
    Variadic:$inputs
  );
  let results = (outs
    Variadic:$outputs
  );
  let regions = (region SizedRegion<1>:$body);
  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
    std::optional> insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
    std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
    ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
  let hasVerifier = 1;
  let hasFolder = 1;
  let hasCustomAssemblyFormat = 1;
}

// Single-block region op describing `laneCount` equivalent compute lanes. Per
// the summary, the lanes share `weights` and take packed `inputs`. As in
// spat.compute, the two variadic operand groups require AttrSizedOperandSegments.
// getLaneArgument / getOutputArgument suggest the body also receives a lane
// block argument and per-output block arguments — confirm in the C++ code.
// The verifier, parser and printer are in C++. Unlike spat.compute, there is
// no folder.
def SpatComputeBatch : SpatOp<"compute_batch", [SingleBlock, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {
  let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
  let arguments = (ins
    I32Attr:$laneCount,
    Variadic:$weights,
    Variadic:$inputs
  );
  let results = (outs
    Variadic:$outputs
  );
  let regions = (region SizedRegion<1>:$body);
  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getLaneArgument();
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
    std::optional> insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
    std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
    ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// Terminator that may only appear directly inside spat.compute_batch
// (HasParent). Its trait list is concatenated with
// GraphRegionNoTerminator.traits from RegionKindInterface.td, so its own nested
// region does not need a terminator.
// skipDefaultBuilders drops the ODS-generated builders; the only builder takes
// no arguments and is implemented in C++.
// getYieldingOps returns a range over the ops of the nested block.
// getParentResult(idx) presumably returns the idx-th result of the enclosing
// spat.compute_batch — confirm in C++.
def SpatInParallelOp : SpatOp<"in_parallel", [
    Pure,
    Terminator,
    DeclareOpInterfaceMethods,
    HasParent<"SpatComputeBatch">,
  ] # GraphRegionNoTerminator.traits> {
  let summary = "Parallel combining terminator for resultful spat.compute_batch";
  let regions = (region SizedRegion<1>:$region);
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins)>,
  ];
  let extraClassDeclaration = [{
    ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
    ::mlir::OpResult getParentResult(int64_t idx);
  }];
}

// Terminator that yields a variadic list of `outputs` (per the summary, from a
// compute region). The parser and printer are hand-written in C++.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";
  let arguments = (ins
    Variadic:$outputs
  );
  let hasCustomAssemblyFormat = 1;
}

// Splits `input` into one result per row. Per the summary, the input is rank-2
// and every row result is also a rank-2 tensor. There is a C++ verifier and a
// custom parser/printer.
def SpatExtractRowsOp : SpatOp<"extract_rows", []> {
  let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";
  let arguments = (ins
    SpatTensor:$input
  );
  let results = (outs
    Variadic:$outputs
  );
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// Concatenates the variadic `inputs` along the dimension given by the `axis`
// attribute into a single `output`. There is a C++ verifier and a custom
// parser/printer.
def SpatConcatOp : SpatOp<"concat", []> {
  let summary = "Concatenate tensors with compact Spatial operand syntax";
  let arguments = (ins
    I64Attr:$axis,
    Variadic:$inputs
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//

// Sends `input` over logical channel `channelId` from core `sourceCoreId` to
// core `targetCoreId`. It produces no results. The format is declarative, and
// there is no op-specific verifier.
def SpatChannelSendOp : SpatOp<"channel_send", []> {
  let summary = "Send a tensor through a logical channel";
  let arguments = (ins
    Index:$channelId,
    Index:$sourceCoreId,
    Index:$targetCoreId,
    SpatTensor:$input
  );
  let assemblyFormat = [{
    $input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
  }];
}

// Receives `output` over logical channel `channelId`, sent from core
// `sourceCoreId` to core `targetCoreId`. There is no op-specific verifier.
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  let summary = "Receive a tensor from a logical channel";
  let arguments = (ins
    Index:$channelId,
    Index:$sourceCoreId,
    Index:$targetCoreId
  );
  let results = (outs
    SpatTensor:$output
  );
  let assemblyFormat = [{
    `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
  }];
}

// Multi-channel send: per the summary, `input` is sent as equal contiguous
// chunks, presumably one chunk per listed channel. The channel/source/target id
// lists are three separate variadic groups, hence AttrSizedOperandSegments.
// The C++ verifier presumably checks that the lists and the chunking agree —
// confirm.
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
  let summary = "Send equal contiguous chunks of one tensor through logical channels";
  let arguments = (ins
    Variadic:$channelIds,
    Variadic:$sourceCoreIds,
    Variadic:$targetCoreIds,
    SpatTensor:$input
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
  }];
}

// Multi-channel receive: per the summary, `output` is assembled from equal
// contiguous chunks received over the listed channels. There are three variadic
// id groups, hence AttrSizedOperandSegments, and a C++ verifier.
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
  let summary = "Receive equal contiguous chunks of one tensor from logical channels";
  let arguments = (ins
    Variadic:$channelIds,
    Variadic:$sourceCoreIds,
    Variadic:$targetCoreIds
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];
}

// Batch-body variant of spat.channel_send. Per the summary, it sends per-lane
// tensors through the listed channels; presumably each lane uses its own
// (channel, source, target) entry — confirm in the C++ verifier. There are
// three variadic id groups, hence AttrSizedOperandSegments.
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
  let summary = "Send per-lane tensors through logical channels in a batch body";
  let arguments = (ins
    Variadic:$channelIds,
    Variadic:$sourceCoreIds,
    Variadic:$targetCoreIds,
    SpatTensor:$input
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
  }];
}

// Batch-body variant of spat.channel_send_tensor. Per the summary, each lane's
// tensor is sent as equal contiguous chunks through the listed channels. It has
// a C++ verifier.
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
  let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
  let arguments = (ins
    Variadic:$channelIds,
    Variadic:$sourceCoreIds,
    Variadic:$targetCoreIds,
    SpatTensor:$input
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
  }];
}

// Batch-body variant of spat.channel_receive. Per the summary, it yields one
// per-lane tensor received through the listed channels. It has a C++ verifier.
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
  let summary = "Receive a per-lane tensor through logical channels in a batch body";
  let arguments = (ins
    Variadic:$channelIds,
    Variadic:$sourceCoreIds,
    Variadic:$targetCoreIds
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];
}

// Batch-body variant of spat.channel_receive_tensor. Per the summary, each
// lane's tensor is assembled from equal contiguous chunks received through the
// listed channels. It has a C++ verifier.
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
  let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
  let arguments = (ins
    Variadic:$channelIds,
    Variadic:$sourceCoreIds,
    Variadic:$targetCoreIds
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];
}

//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
// NOTE(review): every math op below has an empty trait list, so none declares
// side-effect traits. MLIR therefore conservatively treats them as having
// unknown effects and will not erase or CSE them. Adding `Pure` would enable
// that, but only if no op writes through a memref operand (SpatTensor admits
// memrefs).

// Vector-matrix multiply. `weight` is printed in square brackets and `input` in
// parentheses, e.g. `spat.wvmm [%w] (%x) : (type, type) -> type`. It has a C++
// verifier.
def SpatVMMOp : SpatOp<"wvmm", []> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";
  let arguments = (ins
    SpatTensor:$weight,
    SpatTensor:$input
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
  }];
}

// Matrix-vector multiply. It uses the same operand layout and format as wvmm
// and has a C++ verifier.
// NOTE(review): the mnemonic "Wmvm" starts with an uppercase letter, unlike
// "wvmm" and every other mnemonic in this dialect. If that is unintended,
// renaming it changes the textual op name (`spat.Wmvm`) seen by existing IR,
// tests and C++ string matches, so confirm before changing.
def SpatMVMOp : SpatOp<"Wmvm", []> {
  let summary = "Matrix-vector multiplication within a weighted compute operation";
  let arguments = (ins
    SpatTensor:$weight,
    SpatTensor:$input
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
  }];
}

// Element-wise `lhs + rhs`. Per the summary, `rhs` must match `lhs` or be 1x1.
// The C++ verifier presumably enforces this — confirm.
def SpatVAddOp : SpatOp<"vadd", []> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

// Element-wise `lhs * rhs`.
// NOTE(review): the summary states the same rhs shape rule as spat.vadd, but
// unlike vadd this op does not set `hasVerifier`, so no op-specific verifier is
// generated to enforce the rule. Adding one also requires a C++ verify()
// implementation.
def SpatVMulOp : SpatOp<"vmul", []> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );
  let results = (outs
    SpatTensor:$output
  );
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

// Sum of all elements of `input`, returned as a single value wrapped in a
// tensor (per the summary). There is no op-specific verifier.
def SpatSumOp : SpatOp<"sum", []> {
  let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
  let arguments = (ins
    SpatTensor:$input
  );
  let results = (outs
    SpatTensor:$output
  );
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Mean of all elements of `input`, returned as a single value wrapped in a
// tensor (per the summary). There is no op-specific verifier.
def SpatVAvgOp : SpatOp<"vavg", []> {
  let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
  let arguments = (ins
    SpatTensor:$input
  );
  let results = (outs
    SpatTensor:$output
  );
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Element-wise sigmoid of `input`. There is no op-specific verifier.
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
  let summary = "Element-wise sigmoid activation";
  let arguments = (ins
    SpatTensor:$input
  );
  let results = (outs
    SpatTensor:$output
  );
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Softmax computed over the whole `input` tensor slice (per the summary).
// There is no op-specific verifier.
def SpatSoftmaxOp : SpatOp<"softmax", []> {
  let summary = "Softmax over the full input tensor slice";
  let arguments = (ins
    SpatTensor:$input
  );
  let results = (outs
    SpatTensor:$output
  );
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Element-wise ReLU of `input`. There is no op-specific verifier.
def SpatReluOp : SpatOp<"relu", []> {
  let summary = "Element-wise ReLU activation";
  let arguments = (ins
    SpatTensor:$input
  );
  let results = (outs
    SpatTensor:$output
  );
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Element-wise max of `lhs` and `rhs`. It has a C++ verifier.
def SpatVMaxOp : SpatOp<"vmax", []> {
  let summary = "Element-wise max between two tensors";
  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

#endif // SPATIAL_DIALECT_H