Files
Raptor/src/PIM/Dialect/Spatial/Spatial.td
NiccoloN 93e20c1dfc standardize spatial and pim dialects
remove old unused stuff
2026-03-23 21:21:31 +01:00

294 lines
6.5 KiB
TableGen

#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
// The Spatial dialect. Its IR namespace is `spat` and its generated C++
// classes live in ::onnx_mlir::spatial.
def SpatialDialect : Dialect {
  let name = "spat";
  let cppNamespace = "::onnx_mlir::spatial";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";

  // Ask ODS to generate the dialect-level type parse/print hooks.
  let useDefaultTypePrinterParser = 1;
}
// Common base for every Spatial operation: fixes the owning dialect so that a
// concrete op only has to supply its mnemonic and (optionally) its traits.
class SpatOp<string mnemonic, list<Trait> traits = []>
    : Op<SpatialDialect, mnemonic, traits>;
// Value constraint shared by the ops below: matches either a memref or a
// ranked tensor and is surfaced to C++ as ::mlir::ShapedType.
// TODO: consider removing this alias and using AnyRankedTensor directly.
def SpatTensor
    : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
// Common base for Spatial types: binds the dialect and forwards
// `typeMnemonic` to the TypeDef `mnemonic` field.
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
    : TypeDef<SpatialDialect, name, traits> {
  let mnemonic = typeMnemonic;
}
// Parameter-less type with mnemonic `ch`; it is the type of the value produced
// by `channel_new` and consumed by the send/receive ops in this file.
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
  let summary = "Virtual channel type";
}
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
// `spat.compute`: an op carrying a single-block region plus two variadic
// operand groups (weights, inputs) and a variadic list of results.
def SpatWeightedCompute
    : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compute region with attached constant weights";

  // Two variadic operand groups are declared, hence the
  // AttrSizedOperandSegments trait to record where one ends and the next
  // begins.
  let arguments = (ins Variadic<SpatTensor>:$weights,
                       Variadic<SpatTensor>:$inputs);
  let results = (outs Variadic<SpatTensor>:$outputs);

  // Exactly one block is allowed in the body.
  let regions = (region SizedRegion<1>:$body);

  // Weights are printed in square brackets, inputs in parentheses, followed
  // by the matching type lists, the result types and finally the region.
  let assemblyFormat = [{
    `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
  }];

  // A hand-written C++ verify() is expected for this op.
  let hasVerifier = 1;
}
// `spat.yield`: block terminator that takes any number of values.
// NOTE(review): no HasParent constraint is declared here, so the op
// definition itself does not restrict where the terminator may appear.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";

  let arguments = (ins Variadic<SpatTensor>:$outputs);

  let assemblyFormat = [{
    $outputs attr-dict `:` type($outputs)
  }];
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
// `spat.channel_new`: takes no operands and produces one value of the
// channel type.
def SpatChannelNewOp : SpatOp<"channel_new", []> {
  let summary = "Create a new virtual channel";

  let results = (outs
    SpatChannelType:$channel
  );

  // Convenience builder with no arguments: the single result type is always
  // the channel type, so callers do not have to spell it out.
  //
  // The type must be obtained from the context through the builder. A
  // default-constructed `SpatChannelType()` is a null type handle and would
  // give the op a null result type.
  let builders = [
    OpBuilder<(ins ), [{
      $_state.addTypes($_builder.getType<SpatChannelType>());
    }]>
  ];

  // Only the attribute dictionary is printed; no result type is spelled out.
  let assemblyFormat = [{
    attr-dict
  }];
}
// `spat.channel_send`: consumes a channel and a tensor/memref value and
// produces no results.
def SpatChannelSendOp : SpatOp<"channel_send", []> {
  let summary = "Send a tensor through a channel";

  let arguments = (ins SpatChannelType:$channel, SpatTensor:$input);

  // The payload is printed first, then the `to` keyword and the channel, even
  // though the channel is declared as the first operand.
  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
  }];
}
// `spat.channel_receive`: consumes a channel and produces one tensor/memref
// value.
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  let summary = "Receive a tensor from a channel";

  let arguments = (ins SpatChannelType:$channel);
  let results = (outs SpatTensor:$output);

  // The result type is written explicitly after `->`.
  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
  }];
}
// `spat.channel_broadcast_send`: same operands and textual form as
// `channel_send`, registered under a separate mnemonic.
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
  let summary = "Broadcast a tensor through a shared channel buffer";

  let arguments = (ins SpatChannelType:$channel, SpatTensor:$input);

  // The payload is printed before the channel, separated by `to`.
  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
  }];
}
// `spat.channel_broadcast_receive`: same operand, result and textual form as
// `channel_receive`, registered under a separate mnemonic.
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
  let summary = "Receive a tensor from a shared channel buffer";

  let arguments = (ins SpatChannelType:$channel);
  let results = (outs SpatTensor:$output);

  // The result type is written explicitly after `->`.
  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
  }];
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
// `spat.Wvmm`: one tensor/memref operand, one result, plus a 32-bit integer
// attribute `weightIndex`.
// NOTE(review): `weightIndex` presumably selects one of the enclosing
// `spat.compute` op's weights -- confirm against the C++ verifier.
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";

  let arguments = (ins I32Attr:$weightIndex, SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  // `weightIndex` has no dedicated slot in the format; it is carried by
  // attr-dict.
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];

  // A hand-written C++ verify() is expected for this op.
  let hasVerifier = 1;
}
// `spat.Wmvm`: one tensor/memref operand, one result, plus a 32-bit integer
// attribute `weightIndex`.
// NOTE(review): `weightIndex` presumably selects one of the enclosing
// `spat.compute` op's weights -- confirm against the C++ verifier.
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
  let summary = "Matrix-vector multiplication within a weighted compute operation";

  let arguments = (ins I32Attr:$weightIndex, SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  // `weightIndex` has no dedicated slot in the format; it is carried by
  // attr-dict.
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];

  // A hand-written C++ verify() is expected for this op.
  let hasVerifier = 1;
}
// `spat.vadd`: binary op on two tensor/memref operands with one result.
def SpatVAddOp : SpatOp<"vadd", []> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  // Both operand types are printed in parentheses, then the result type.
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];

  // A hand-written C++ verify() is expected for this op; the operand
  // constraints declared here do not encode the shape rule from the summary.
  let hasVerifier = 1;
}
// `spat.vmul`: binary op on two tensor/memref operands with one result.
// NOTE(review): the summary states the same rhs shape rule as `vadd`, but this
// op does not set `hasVerifier`, so no custom verifier is declared here to
// enforce it -- confirm whether that is intentional.
def SpatVMulOp : SpatOp<"vmul", []> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  // Both operand types are printed in parentheses, then the result type.
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}
// `spat.sum`: unary op with one tensor/memref operand and one result. No
// custom verifier is declared, so the result shape described in the summary
// is not checked by this definition.
def SpatSumOp : SpatOp<"sum", []> {
  let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
// `spat.sigmoid`: unary op with one tensor/memref operand and one result.
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
  let summary = "Element-wise sigmoid activation";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
// `spat.relu`: unary op with one tensor/memref operand and one result.
def SpatReluOp : SpatOp<"relu", []> {
  let summary = "Element-wise ReLU activation";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
// `spat.vmax`: binary op on two tensor/memref operands with one result.
def SpatVMaxOp : SpatOp<"vmax", []> {
  let summary = "Element-wise max between two tensors";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  // Both operand types are printed in parentheses, then the result type.
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];

  // A hand-written C++ verify() is expected for this op.
  let hasVerifier = 1;
}
#endif // SPATIAL_DIALECT_H