#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H

// ODS (TableGen) definitions for the Spatial (`spat`) MLIR dialect. The file
// contains four groups of ops:
//   * Execution: compute regions, batched compute regions, and their
//     terminators.
//   * Planning: layout-aware planning/reconciliation ops.
//   * Communication: channel send/receive between cores.
//   * Math: vector/tensor arithmetic and activation ops.
//
// NOTE(review): in the copy reviewed, several angle-bracketed template-argument
// lists appear to have been stripped. Examples:
//   * `class SpatOp traits = []> : Op;`
//   * `Variadic:$inputs` and `Optional:$bias`
//   * `DeclareOpInterfaceMethods,`
//   * `std::optional> insertWeight(...)`
// As written these are not valid TableGen / C++. TODO: restore them from
// version control rather than re-deriving them by hand.

include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Dialect record: ops are printed/parsed under the `spat.` prefix and the
// generated C++ lives in namespace ::onnx_mlir::spatial.
def SpatialDialect : Dialect {
  let name = "spat";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
  let cppNamespace = "::onnx_mlir::spatial";
}

// Base class for every Spatial op. It is instantiated below as
// `SpatOp<"mnemonic", [traits...]>`.
// NOTE(review): the parameter list and the `Op<...>` arguments appear stripped.
// The conventional ODS form would be
//   class SpatOp<string mnemonic, list<Trait> traits = []> :
//       Op<SpatialDialect, mnemonic, traits>;
// TODO: confirm against the original file.
class SpatOp traits = []> : Op;

// Operand/result type used throughout this dialect. It accepts any memref or
// any ranked tensor, and is exposed to C++ as `::mlir::ShapedType`.
// TODO: consider removing this alias and using AnyRankedTensor directly.
// Note that doing so would stop accepting memref values.
def SpatTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//

// Shared shape of the non-batched compute ops (`graph_compute`,
// `scheduled_compute`):
//   * operands: variadic weights followed by variadic inputs;
//   * results: variadic outputs;
//   * body: exactly one single-block region.
// The verifier, folder, and custom parser/printer are implemented in C++
// (not visible here).
// NOTE(review): the class template parameter appears stripped. It is
// presumably `<string mnemonic>`, given uses such as
// `SpatComputeLikeBase<"graph_compute">`. The `SpatOp<...>` argument list and
// the element type of each `Variadic` (presumably `Variadic<SpatTensor>`)
// also appear stripped. TODO confirm.
class SpatComputeLikeBase : SpatOp]> {
  let summary = "Compute region with attached constant weights";

  let arguments = (ins
    Variadic:$weights,
    Variadic:$inputs
  );

  let results = (outs
    Variadic:$outputs
  );

  let regions = (region SizedRegion<1>:$body);

  let hasVerifier = 1;
  let hasFolder = 1;
  let hasCustomAssemblyFormat = 1;
}

// `spat.graph_compute`: helper methods are declared here and defined in C++.
// They cover:
//   * looking up the body block argument for a weight/input index;
//   * inserting a new weight/input/output at an index;
//   * collecting the crossbar weights.
// NOTE(review): this extraClassDeclaration is identical to
// SpatScheduledCompute's and could be hoisted into SpatComputeLikeBase. The
// return types of insertWeight/insertInput/insertOutput appear stripped
// (`std::optional>`, `::mlir::FailureOr>`). TODO restore.
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
    std::optional> insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
    std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
    ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
    ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
}

// `spat.scheduled_compute`: same operands, region, and C++ helpers as
// `spat.graph_compute`, under a different mnemonic.
// NOTE(review): the same stripped return types as in SpatGraphCompute apply
// here.
def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
    std::optional> insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
    std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
    ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
    ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
}

// Shared shape of the batched compute ops (`graph_compute_batch`,
// `scheduled_compute_batch`):
//   * an I32 `laneCount` attribute;
//   * variadic weights and inputs;
//   * variadic outputs;
//   * one single-block body region.
// Unlike SpatComputeLikeBase, no folder is declared.
// NOTE(review): the class template parameter, the `SpatOp<...>` arguments, and
// the `Variadic` element types appear stripped, as above. TODO restore.
class SpatComputeBatchLikeBase : SpatOp]> {
  let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";

  let arguments = (ins
    I32Attr:$laneCount,
    Variadic:$weights,
    Variadic:$inputs
  );

  let results = (outs
    Variadic:$outputs
  );

  let regions = (region SizedRegion<1>:$body);

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// `spat.graph_compute_batch`: in addition to the non-batched helpers, declares
// getLaneArgument() and getOutputArgument(idx) for body block arguments.
// NOTE(review): identical to SpatScheduledComputeBatch's declarations. The
// insert* return types appear stripped. TODO restore.
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getLaneArgument();
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
    std::optional> insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
    std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
    ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
    ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
}

// `spat.scheduled_compute_batch`: same operands, region, and C++ helpers as
// `spat.graph_compute_batch`, under a different mnemonic.
def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getLaneArgument();
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
    std::optional> insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
    std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
    ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
    ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
}

// `spat.in_parallel`: a Pure terminator whose single-block region carries the
// GraphRegionNoTerminator traits. Default builders are skipped; only a
// no-argument builder is declared. The C++ helpers expose the ops inside the
// region (getYieldingOps) and the parent op's result at an index
// (getParentResult).
// Its structure resembles upstream `scf.forall.in_parallel`.
// NOTE(review): the `DeclareOpInterfaceMethods` argument appears stripped. It
// is presumably `<ParallelCombiningOpInterface>`, whose .td is included above.
// TODO confirm.
def SpatInParallelOp : SpatOp<"in_parallel", [
    Pure,
    Terminator,
    DeclareOpInterfaceMethods,
  ] # GraphRegionNoTerminator.traits> {
  let summary = "Parallel combining terminator for resultful Spatial compute batches";

  let regions = (region SizedRegion<1>:$region);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
  let skipDefaultBuilders = 1;

  let builders = [
    OpBuilder<(ins)>,
  ];

  let extraClassDeclaration = [{
    ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
    ::mlir::OpResult getParentResult(int64_t idx);
  }];
}

// `spat.yield`: terminator carrying a variadic list of values out of a compute
// region.
// NOTE(review): the `Variadic` element type appears stripped. TODO restore.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";

  let arguments = (ins
    Variadic:$outputs
  );

  let hasCustomAssemblyFormat = 1;
}

// `spat.extract_rows`: one input, variadic results. The rank-2 requirement
// stated in the summary is presumably enforced by the C++ verifier, which is
// not visible here.
def SpatExtractRowsOp : SpatOp<"extract_rows", []> {
  let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    Variadic:$outputs
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

// `spat.concat`: concatenates a variadic list of inputs along the I64 `axis`
// attribute into a single result.
def SpatConcatOp : SpatOp<"concat", []> {
  let summary = "Concatenate tensors with compact Spatial operand syntax";

  let arguments = (ins
    I64Attr:$axis,
    Variadic:$inputs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Planning
//===----------------------------------------------------------------------===//

// `spat.conv2d_plan`: Conv2D planning op. It carries ONNX-style
// pads/strides/dilations/group attributes plus a string `logicalLayout`. The
// bias operand is optional.
// NOTE(review): `Optional` appears to be missing its element type (presumably
// `Optional<SpatTensor>`). TODO restore.
def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> {
  let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";

  let arguments = (ins
    SpatTensor:$input,
    SpatTensor:$weight,
    Optional:$bias,
    DenseI64ArrayAttr:$pads,
    DenseI64ArrayAttr:$strides,
    DenseI64ArrayAttr:$dilations,
    I64Attr:$group,
    StrAttr:$logicalLayout
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
}

// `spat.relu_plan`: ReLU planning op annotated with a string `logicalLayout`.
def SpatReluPlanOp : SpatOp<"relu_plan", []> {
  let summary = "Layout-aware ReLU planning op";

  let arguments = (ins
    SpatTensor:$input,
    StrAttr:$logicalLayout
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
}

// `spat.reconciliator`: records the chosen logical/physical layout pair. It
// also records fragment offsets/sizes and an index map, the latter stored as a
// string attribute.
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
  let summary = "Passive logical-to-physical layout selection record";

  let arguments = (ins
    SpatTensor:$input,
    StrAttr:$logicalLayout,
    StrAttr:$physicalLayout,
    DenseI64ArrayAttr:$fragmentOffsets,
    DenseI64ArrayAttr:$fragmentSizes,
    StrAttr:$indexMap
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
}

// `spat.materialize_layout`: converts `input` from `sourcePhysicalLayout` to
// `targetPhysicalLayout` for a given `logicalLayout`. All three layouts are
// string attributes.
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
  let summary = "Explicit layout conversion or materialization barrier";

  let arguments = (ins
    SpatTensor:$input,
    StrAttr:$logicalLayout,
    StrAttr:$sourcePhysicalLayout,
    StrAttr:$targetPhysicalLayout
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//

// `spat.channel_send`: sends `input` on a channel identified by three `index`
// operands (channel id, source core id, target core id). It has no results.
// No memory-effect traits are listed, presumably so the op is never treated as
// removable. Confirm.
def SpatChannelSendOp : SpatOp<"channel_send", []> {
  let summary = "Send a tensor through a logical channel";

  let arguments = (ins
    Index:$channelId,
    Index:$sourceCoreId,
    Index:$targetCoreId,
    SpatTensor:$input
  );

  let assemblyFormat = [{ $input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input) }];
}

// `spat.channel_receive`: counterpart of `spat.channel_send`. It produces one
// value from the channel named by the same three `index` operands.
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  let summary = "Receive a tensor from a logical channel";

  let arguments = (ins
    Index:$channelId,
    Index:$sourceCoreId,
    Index:$targetCoreId
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{ `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output) }];
}

//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//

// `spat.wvmm`: vector-matrix multiply. The weight operand is printed in square
// brackets and the input in parentheses, e.g. `spat.wvmm [%w] (%x) : ...`.
def SpatVMMOp : SpatOp<"wvmm", []> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";

  let arguments = (ins
    SpatTensor:$weight,
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{ `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output) }];
}

// `spat.vvdmul`: dot product of two runtime vectors.
def SpatVVDMulOp : SpatOp<"vvdmul", []> {
  let summary = "Dot product between two runtime vectors";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) }];
}

// `spat.vadd`: element-wise add. The lhs/rhs shape rule in the summary is
// checked by the C++ verifier (hasVerifier = 1).
def SpatVAddOp : SpatOp<"vadd", []> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) }];
}

// `spat.vsub`: element-wise subtract. It follows the same shape rule as
// `spat.vadd`, also verified in C++.
def SpatVSubOp : SpatOp<"vsub", []> {
  let summary = "Element-wise subtraction between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) }];
}

// `spat.vmul`: element-wise multiply.
// NOTE(review): the summary states the same "rhs must match lhs or be 1x1" rule
// as vadd/vsub. Unlike them, this op does not set `hasVerifier = 1`, so no
// custom verifier enforces that rule. Consider adding one; this would also
// require a C++ `verify()` implementation.
def SpatVMulOp : SpatOp<"vmul", []> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) }];
}

// `spat.vavg`: reduces all elements of the input to their average, returned
// inside a tensor.
def SpatVAvgOp : SpatOp<"vavg", []> {
  let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{ `(` $input `)` attr-dict `:` type($input) `->` type($output) }];
}

// `spat.sigmoid`: unary element-wise activation.
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
  let summary = "Element-wise sigmoid activation";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{ `(` $input `)` attr-dict `:` type($input) `->` type($output) }];
}

// `spat.softmax`: softmax over the whole input operand.
def SpatSoftmaxOp : SpatOp<"softmax", []> {
  let summary = "Softmax over the full input tensor slice";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{ `(` $input `)` attr-dict `:` type($input) `->` type($output) }];
}

// `spat.relu`: unary element-wise activation. See `spat.relu_plan` for the
// layout-aware planning form.
def SpatReluOp : SpatOp<"relu", []> {
  let summary = "Element-wise ReLU activation";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{ `(` $input `)` attr-dict `:` type($input) `->` type($output) }];
}

// `spat.vmax`: element-wise maximum of two operands, with a C++ verifier.
def SpatVMaxOp : SpatOp<"vmax", []> {
  let summary = "Element-wise max between two tensors";

  let arguments = (ins
    SpatTensor:$lhs,
    SpatTensor:$rhs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) }];
}

#endif // SPATIAL_DIALECT_H