// Spatial ("spat") dialect definitions, written in MLIR ODS / TableGen.
#ifndef SPATIAL_DIALECT_H
|
|
#define SPATIAL_DIALECT_H
|
|
|
|
include "mlir/IR/OpBase.td"
|
|
include "mlir/IR/BuiltinTypes.td"
|
|
include "mlir/IR/AttrTypeBase.td"
|
|
|
|
// The "spat" dialect. Generated C++ declarations are placed in the
// ::onnx_mlir::spatial namespace.
def SpatialDialect : Dialect {
  let name = "spat";
  let cppNamespace = "::onnx_mlir::spatial";

  let summary = "Dialect designed for deep learning computation in a spatial architecture";

  // Ask ODS to generate the dialect-level type parser/printer; it dispatches
  // on the `mnemonic` of each TypeDef declared for this dialect.
  let useDefaultTypePrinterParser = 1;
}
|
|
|
|
// Base class for every op of the "spat" dialect: it simply pins the op to
// SpatialDialect and forwards the mnemonic and trait list to `Op`.
class SpatOp<string mnemonic, list<Trait> traits = []>
    : Op<SpatialDialect, mnemonic, traits>;
|
|
|
|
// Value constraint used by all spat ops: accepts any memref or any ranked
// tensor, and exposes matching values to C++ as ::mlir::ShapedType.
// TODO maybe remove and use AnyRankedTensor directly (note that this would
// also stop accepting the AnyMemRef alternative listed here).
def SpatTensor
    : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
|
|
|
// Base class for types of the "spat" dialect. `typeMnemonic` is the keyword
// used by the default type printer/parser for this type.
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
    : TypeDef<SpatialDialect, name, traits> {
  let mnemonic = typeMnemonic;
}
|
|
|
|
// Type of a virtual channel value (mnemonic "ch"). It declares no
// parameters, so every channel shares the same type.
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
  let summary = "Virtual channel type";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Execution
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spat.compute: a region of computation that receives a group of weight
// operands and a group of input operands, and produces a variadic list of
// results.
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compute region with attached constant weights";

  // Two variadic operand groups; AttrSizedOperandSegments stores the size of
  // each group in an attribute so the operands can be split back apart.
  let arguments = (ins
    Variadic<SpatTensor>:$weights,
    Variadic<SpatTensor>:$inputs
  );
  let results = (outs Variadic<SpatTensor>:$outputs);

  // Exactly one region containing exactly one block.
  let regions = (region SizedRegion<1>:$body);

  // Additional checks live in a hand-written C++ verify() method.
  let hasVerifier = 1;

  // Shape of the textual form:
  //   [weights](inputs) attr-dict : [weight types](input types) -> result types <region>
  let assemblyFormat = [{
    `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
  }];
}
|
|
|
|
// Block terminator carrying a variadic list of values out of a region.
// NOTE(review): presumably only used inside spat.compute; no HasParent
// constraint is declared, so this is not enforced here — TODO confirm.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";

  let arguments = (ins Variadic<SpatTensor>:$outputs);

  let assemblyFormat = [{
    $outputs attr-dict `:` type($outputs)
  }];
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Communication
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spat.channel_new: creates a new virtual channel value of SpatChannelType.
// Takes no operands.
def SpatChannelNewOp : SpatOp<"channel_new", []> {
  let summary = "Create a new virtual channel";

  let results = (outs
    SpatChannelType:$channel
  );

  let builders = [
    // Argument-free convenience builder. The result type must come from the
    // uniqued `get` factory: a default-constructed `SpatChannelType()` is a
    // null mlir::Type, which would give the op an invalid result type.
    OpBuilder<(ins), [{
      $_state.addTypes(SpatChannelType::get($_builder.getContext()));
    }]>
  ];

  // The result type is the fixed, parameterless SpatChannelType, so only the
  // attribute dictionary appears in the textual form.
  let assemblyFormat = [{
    attr-dict
  }];
}
|
|
|
|
// spat.channel_send: sends $input over $channel. Declares no results.
def SpatChannelSendOp : SpatOp<"channel_send", []> {
  let summary = "Send a tensor through a channel";

  let arguments = (ins SpatChannelType:$channel, SpatTensor:$input);

  // The payload is printed before the channel even though $channel is the
  // first operand.
  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
  }];
}
|
|
|
|
// spat.channel_receive: reads a value of type SpatTensor from $channel.
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  let summary = "Receive a tensor from a channel";

  let arguments = (ins SpatChannelType:$channel);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
  }];
}
|
|
|
|
// spat.channel_broadcast_send: broadcast variant of channel_send. Its
// operands and textual form mirror spat.channel_send.
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
  let summary = "Broadcast a tensor through a shared channel buffer";

  let arguments = (ins SpatChannelType:$channel, SpatTensor:$input);

  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
  }];
}
|
|
|
|
// spat.channel_broadcast_receive: receiving side of the broadcast channel
// ops. Its operands, results and textual form mirror spat.channel_receive.
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
  let summary = "Receive a tensor from a shared channel buffer";

  let arguments = (ins SpatChannelType:$channel);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
  }];
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Math
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spat.Wvmm: vector-matrix multiply of $input with a weight identified by
// $weightIndex. The index presumably refers to the enclosing spat.compute's
// $weights operands — TODO confirm against the C++ verifier.
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";

  let arguments = (ins I32Attr:$weightIndex, SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  // Additional checks live in a hand-written C++ verify() method.
  let hasVerifier = 1;

  // $weightIndex has no dedicated directive, so it is printed inside attr-dict.
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.Wmvm: matrix-vector counterpart of spat.Wvmm. It takes the same
// $weightIndex attribute and $input operand and uses the same textual form.
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
  let summary = "Matrix-vector multiplication within a weighted compute operation";

  let arguments = (ins I32Attr:$weightIndex, SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  // Additional checks live in a hand-written C++ verify() method.
  let hasVerifier = 1;

  // $weightIndex has no dedicated directive, so it is printed inside attr-dict.
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.vadd: element-wise lhs + rhs. The shape rule stated in the summary
// is presumably enforced by the C++ verifier — TODO confirm.
def SpatVAddOp : SpatOp<"vadd", []> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  // Additional checks live in a hand-written C++ verify() method.
  let hasVerifier = 1;

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}
|
|
|
|
// spat.vmul: element-wise lhs * rhs.
// NOTE(review): unlike spat.vadd, this op declares no verifier, so the
// "rhs must match lhs or be 1x1" rule in its summary is not checked by the op.
def SpatVMulOp : SpatOp<"vmul", []> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}
|
|
|
|
// spat.sum: full reduction of $input. No verifier is declared, so the
// result's shape is only constrained by SpatTensor here.
def SpatSumOp : SpatOp<"sum", []> {
  let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.sigmoid: unary element-wise activation on $input.
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
  let summary = "Element-wise sigmoid activation";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.relu: unary element-wise activation on $input.
def SpatReluOp : SpatOp<"relu", []> {
  let summary = "Element-wise ReLU activation";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.vmax: element-wise maximum of $lhs and $rhs.
def SpatVMaxOp : SpatOp<"vmax", []> {
  let summary = "Element-wise max between two tensors";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  // Additional checks live in a hand-written C++ verify() method.
  let hasVerifier = 1;

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}
|
|
|
|
#endif // SPATIAL_DIALECT_H
|