#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"

// Dialect for deep-learning computation mapped onto a spatial architecture.
// Ops live in the "spat" namespace and C++ classes are generated into
// ::onnx_mlir::spatial. Dialect types use the default printer/parser, so a
// type is spelled `!spat.<mnemonic>`.
def SpatialDialect : Dialect {
  let name = "spat";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
  let cppNamespace = "::onnx_mlir::spatial";
  let useDefaultTypePrinterParser = 1;
}

// Base class for every op of this dialect: binds the op to SpatialDialect and
// forwards the mnemonic and trait list to the generic ODS `Op` class.
class SpatOp<string mnemonic, list<Trait> traits = []> :
    Op<SpatialDialect, mnemonic, traits>;

// Operand/result constraint used by all spat ops: either a memref or a ranked
// tensor, exposed to C++ as ::mlir::ShapedType.
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

// Base class for every type of this dialect. `name` determines the generated
// C++ class name; `typeMnemonic` is the keyword used by the default type
// printer/parser.
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
    : TypeDef<SpatialDialect, name, traits> {
  let mnemonic = typeMnemonic;
}

// Parameterless channel type, spelled `!spat.ch` in textual IR.
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
  let summary = "Virtual channel type";
}

//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//

// Single-block region producing $outputs from $inputs, with $weights passed as
// a separate operand group (AttrSizedOperandSegments records the split between
// the two variadic groups). Extra invariants are checked by a C++ verify().
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compute region with attached constant weights";

  let arguments = (ins
    Variadic<SpatTensor>:$weights,
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let regions = (region SizedRegion<1>:$body);

  let hasVerifier = 1;

  let assemblyFormat = [{
    `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
  }];
}

// Block terminator that yields $outputs from the enclosing region.
// NOTE(review): no HasParent<...> trait, so nothing restricts this op to
// `spat.compute` bodies — confirm that is intended.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";

  let arguments = (ins
    Variadic<SpatTensor>:$outputs
  );

  let assemblyFormat = [{
    $outputs attr-dict `:` type($outputs)
  }];
}

//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//

// Creates a fresh `!spat.ch` value; takes no operands.
def SpatChannelNewOp : SpatOp<"channel_new", []> {
  let summary = "Create a new virtual channel";

  let results = (outs
    SpatChannelType:$channel
  );

  let builders = [
    OpBuilder<(ins ), [{
      // A default-constructed SpatChannelType is a null mlir::Type; fetch the
      // uniqued instance from the context instead.
      $_state.addTypes(SpatChannelType::get($_builder.getContext()));
    }]>
  ];

  let assemblyFormat = [{
    attr-dict
  }];
}

// Sends $input on $channel; produces no results.
def SpatChannelSendOp : SpatOp<"channel_send", []> {
  let summary = "Send a tensor through a channel";

  let arguments = (ins
    SpatChannelType:$channel,
    SpatTensor:$input
  );

  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
  }];
}

// Receives a value of the declared result type from $channel.
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  let summary = "Receive a tensor from a channel";

  let arguments = (ins
    SpatChannelType:$channel
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
  }];
}

// Broadcast counterpart of channel_send; same operands and syntax.
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
  let summary = "Broadcast a tensor through a shared channel buffer";

  let arguments = (ins
    SpatChannelType:$channel,
    SpatTensor:$input
  );

  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
  }];
}

// Broadcast counterpart of channel_receive; same operands and syntax.
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
  let summary = "Receive a tensor from a shared channel buffer";

  let arguments = (ins
    SpatChannelType:$channel
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
  }];
}

//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//

// Vector-matrix product of $input with a weight selected by $weightIndex.
// NOTE(review): $weightIndex presumably indexes the enclosing spat.compute's
// $weights operands — confirm against the C++ verifier.
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";

  let arguments = (ins
    I32Attr:$weightIndex,
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Matrix-vector product of a weight selected by $weightIndex with $input.
// NOTE(review): $weightIndex presumably indexes the enclosing spat.compute's
// $weights operands — confirm against the C++ verifier.
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
  let summary = "Matrix-vector multiplication within a weighted compute operation";

  let arguments = (ins
    I32Attr:$weightIndex,
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Element-wise lhs + rhs; the rhs shape rule is enforced by a C++ verify().
def SpatVAddOp : SpatOp<"vadd", []> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

// Element-wise lhs * rhs.
// NOTE(review): the summary states an rhs shape rule, but unlike vadd this op
// does not set hasVerifier, so that rule is not checked here — confirm.
def SpatVMulOp : SpatOp<"vmul", []> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

// Full reduction of $input; the result type is given explicitly in the syntax.
def SpatSumOp : SpatOp<"sum", []> {
  let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Unary element-wise sigmoid.
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
  let summary = "Element-wise sigmoid activation";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Unary element-wise ReLU.
def SpatReluOp : SpatOp<"relu", []> {
  let summary = "Element-wise ReLU activation";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}

// Element-wise max(lhs, rhs); operand compatibility is checked by a C++ verify().
def SpatVMaxOp : SpatOp<"vmax", []> {
  let summary = "Element-wise max between two tensors";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}

#endif // SPATIAL_DIALECT_H