480 lines
13 KiB
TableGen
480 lines
13 KiB
TableGen
#ifndef SPATIAL_DIALECT_H
|
|
#define SPATIAL_DIALECT_H
|
|
|
|
include "mlir/IR/OpBase.td"
|
|
include "mlir/IR/OpAsmInterface.td"
|
|
include "mlir/IR/BuiltinTypes.td"
|
|
include "mlir/IR/AttrTypeBase.td"
|
|
include "mlir/IR/RegionKindInterface.td"
|
|
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
|
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
|
|
|
// Dialect record for the Spatial dialect. Its operations carry the `spat`
// dialect name and the generated C++ is placed in ::onnx_mlir::spatial.
def SpatialDialect : Dialect {
  let name = "spat";
  let cppNamespace = "::onnx_mlir::spatial";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
}
|
|
|
|
// Common base for every Spatial operation: pins the dialect to SpatialDialect
// and forwards the mnemonic and the (optional) trait list straight to Op.
class SpatOp<string mnemonic, list<Trait> traits = []>
    : Op<SpatialDialect, mnemonic, traits>;
|
|
|
|
// Operand/result constraint shared by the operations in this file: accepts any
// memref or any ranked tensor and is exposed to C++ as ::mlir::ShapedType.
// No explicit summary string is supplied.
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor
    : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Execution
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spat.compute: single-block region operation with two variadic operand
// groups (weights first, then inputs) and a variadic list of results.
// AttrSizedOperandSegments is required because both operand groups are
// variadic. The verifier, folder, parser/printer, the OpAsmOpInterface hook
// and the helpers listed in extraClassDeclaration are only declared here;
// their definitions live in C++ outside this file.
def SpatCompute : SpatOp<"compute", [
    SingleBlock,
    AttrSizedOperandSegments,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>
  ]> {
  let summary = "Compute region with attached constant weights";

  let arguments = (ins Variadic<SpatTensor>:$weights,
                       Variadic<SpatTensor>:$inputs);
  let results = (outs Variadic<SpatTensor>:$outputs);
  let regions = (region SizedRegion<1>:$body);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
  let hasFolder = 1;

  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);

    std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
    insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);

    std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
    insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);

    ::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
    insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
}
|
|
|
|
// spat.compute_batch: single-block region operation carrying an I32 attribute
// `laneCount` plus two variadic operand groups (weights, inputs) and a
// variadic list of results. AttrSizedOperandSegments is required because both
// operand groups are variadic. The verifier, parser/printer, the
// OpAsmOpInterface hook and the helpers in extraClassDeclaration are declared
// here and defined in C++ outside this file.
def SpatComputeBatch : SpatOp<"compute_batch", [
    SingleBlock,
    AttrSizedOperandSegments,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>
  ]> {
  let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";

  let arguments = (ins I32Attr:$laneCount,
                       Variadic<SpatTensor>:$weights,
                       Variadic<SpatTensor>:$inputs);
  let results = (outs Variadic<SpatTensor>:$outputs);
  let regions = (region SizedRegion<1>:$body);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    std::optional<::mlir::BlockArgument> getLaneArgument();
    std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
    std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);

    std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
    insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);

    std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
    insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);

    ::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
    insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
  }];
}
|
|
|
|
// spat.in_parallel: terminator whose parent must be SpatComputeBatch
// (HasParent). The trait list is the four explicit traits concatenated with
// the traits bundled in GraphRegionNoTerminator. The op owns one single-block
// region and has no operands or results declared here.
// Default builders are suppressed; only a zero-argument builder is declared,
// and it must be defined in C++ together with the verifier, the custom
// parser/printer and the two helpers in extraClassDeclaration.
def SpatInParallelOp : SpatOp<"in_parallel",
    [Pure,
     Terminator,
     DeclareOpInterfaceMethods<InParallelOpInterface>,
     HasParent<"SpatComputeBatch">]
    # GraphRegionNoTerminator.traits> {
  let summary = "Parallel combining terminator for resultful spat.compute_batch";

  let regions = (region SizedRegion<1>:$region);

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;

  let skipDefaultBuilders = 1;
  let builders = [OpBuilder<(ins)>];

  let extraClassDeclaration = [{
    ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
    ::mlir::OpResult getParentResult(int64_t idx);
  }];
}
|
|
|
|
// spat.yield: terminator taking a variadic list of SpatTensor operands and
// producing no results. Parsing/printing is hand-written in C++.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";

  let arguments = (ins Variadic<SpatTensor>:$outputs);

  let hasCustomAssemblyFormat = 1;
}
|
|
|
|
// spat.extract_rows: one SpatTensor operand, a variadic list of SpatTensor
// results. The rank requirement mentioned in the summary is not expressed in
// the type constraints here; a C++ verifier is declared (hasVerifier), along
// with a hand-written parser/printer.
def SpatExtractRowsOp : SpatOp<"extract_rows"> {
  let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";

  let arguments = (ins SpatTensor:$input);
  let results = (outs Variadic<SpatTensor>:$outputs);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
|
|
|
|
// spat.concat: an I64 attribute `axis` plus a variadic list of SpatTensor
// operands, yielding a single SpatTensor result. Verification and the
// assembly syntax are implemented in C++.
def SpatConcatOp : SpatOp<"concat"> {
  let summary = "Concatenate tensors with compact Spatial operand syntax";

  let arguments = (ins I64Attr:$axis, Variadic<SpatTensor>:$inputs);
  let results = (outs SpatTensor:$output);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Communication
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spat.channel_send: takes three `index` operands (channelId, sourceCoreId,
// targetCoreId) followed by the SpatTensor payload `input`; produces no
// results. The textual form is generated from the declarative assemblyFormat.
def SpatChannelSendOp : SpatOp<"channel_send"> {
  let summary = "Send a tensor through a logical channel";

  let arguments = (ins Index:$channelId,
                       Index:$sourceCoreId,
                       Index:$targetCoreId,
                       SpatTensor:$input);

  let assemblyFormat = [{
    $input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
  }];
}
|
|
|
|
// spat.channel_receive: takes three `index` operands (channelId,
// sourceCoreId, targetCoreId) and returns one SpatTensor result. The textual
// form is generated from the declarative assemblyFormat.
def SpatChannelReceiveOp : SpatOp<"channel_receive"> {
  let summary = "Receive a tensor from a logical channel";

  let arguments = (ins Index:$channelId,
                       Index:$sourceCoreId,
                       Index:$targetCoreId);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
  }];
}
|
|
|
|
// spat.channel_send_tensor: three variadic `index` operand groups
// (channelIds, sourceCoreIds, targetCoreIds) followed by one SpatTensor
// payload; no results. AttrSizedOperandSegments is needed to tell the
// variadic groups apart. Additional checks live in the C++ verifier.
def SpatChannelSendTensorOp
    : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
  let summary = "Send equal contiguous chunks of one tensor through logical channels";

  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds,
                       SpatTensor:$input);

  let assemblyFormat = [{
    $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// spat.channel_receive_tensor: three variadic `index` operand groups
// (channelIds, sourceCoreIds, targetCoreIds) and one SpatTensor result.
// AttrSizedOperandSegments is needed to tell the variadic groups apart.
// Additional checks live in the C++ verifier.
def SpatChannelReceiveTensorOp
    : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
  let summary = "Receive equal contiguous chunks of one tensor from logical channels";

  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// spat.channel_send_batch: three variadic `index` operand groups (channelIds,
// sourceCoreIds, targetCoreIds) followed by one SpatTensor payload; no
// results. AttrSizedOperandSegments is needed to tell the variadic groups
// apart. Additional checks live in the C++ verifier.
def SpatChannelSendBatchOp
    : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
  let summary = "Send per-lane tensors through logical channels in a batch body";

  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds,
                       SpatTensor:$input);

  let assemblyFormat = [{
    $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// spat.channel_send_tensor_batch: three variadic `index` operand groups
// (channelIds, sourceCoreIds, targetCoreIds) followed by one SpatTensor
// payload; no results. AttrSizedOperandSegments is needed to tell the
// variadic groups apart. Additional checks live in the C++ verifier.
def SpatChannelSendTensorBatchOp
    : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
  let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";

  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds,
                       SpatTensor:$input);

  let assemblyFormat = [{
    $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// spat.channel_receive_batch: three variadic `index` operand groups
// (channelIds, sourceCoreIds, targetCoreIds) and one SpatTensor result.
// AttrSizedOperandSegments is needed to tell the variadic groups apart.
// Additional checks live in the C++ verifier.
def SpatChannelReceiveBatchOp
    : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
  let summary = "Receive a per-lane tensor through logical channels in a batch body";

  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// spat.channel_receive_tensor_batch: three variadic `index` operand groups
// (channelIds, sourceCoreIds, targetCoreIds) and one SpatTensor result.
// AttrSizedOperandSegments is needed to tell the variadic groups apart.
// Additional checks live in the C++ verifier.
def SpatChannelReceiveTensorBatchOp
    : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
  let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";

  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];

  let hasVerifier = 1;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Math
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spat.wvmm: two SpatTensor operands (`weight`, then `input`) and one
// SpatTensor result. The weight is printed in square brackets and the input
// in parentheses. Additional checks live in the C++ verifier.
def SpatVMMOp : SpatOp<"wvmm"> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";

  let arguments = (ins SpatTensor:$weight, SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// Matrix-vector counterpart of SpatVMMOp: two SpatTensor operands (`weight`,
// then `input`) and one SpatTensor result, with the same assembly layout.
// Additional checks live in the C++ verifier.
// NOTE(review): the mnemonic is spelled "Wmvm" with a capital W, whereas
// SpatVMMOp uses lowercase "wvmm" and every other mnemonic in this file is
// lowercase. Confirm whether this is intentional before renaming it, since
// the mnemonic is part of the textual IR.
def SpatMVMOp : SpatOp<"Wmvm"> {
  let summary = "Matrix-vector multiplication within a weighted compute operation";

  let arguments = (ins SpatTensor:$weight, SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// spat.vadd: binary operation on two SpatTensor operands (`lhs`, `rhs`)
// producing one SpatTensor result. The shape rule quoted in the summary is
// not encoded in the type constraints; a C++ verifier is declared.
def SpatVAddOp : SpatOp<"vadd"> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];

  let hasVerifier = 1;
}
|
|
|
|
// spat.vmul: binary operation on two SpatTensor operands (`lhs`, `rhs`)
// producing one SpatTensor result.
// NOTE(review): the summary states a shape rule for rhs, but unlike
// SpatVAddOp no hasVerifier is set here, so only the SpatTensor type
// constraints are checked by this definition. Confirm whether the rule is
// enforced elsewhere or a verifier is missing.
def SpatVMulOp : SpatOp<"vmul"> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}
|
|
|
|
// spat.sum: unary operation, one SpatTensor operand and one SpatTensor
// result. No verifier is declared; only the type constraints are checked.
def SpatSumOp : SpatOp<"sum"> {
  let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.vavg: unary operation, one SpatTensor operand and one SpatTensor
// result. No verifier is declared; only the type constraints are checked.
def SpatVAvgOp : SpatOp<"vavg"> {
  let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.sigmoid: unary operation, one SpatTensor operand and one SpatTensor
// result. No verifier is declared; only the type constraints are checked.
def SpatSigmoidOp : SpatOp<"sigmoid"> {
  let summary = "Element-wise sigmoid activation";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.softmax: unary operation, one SpatTensor operand and one SpatTensor
// result. No verifier is declared; only the type constraints are checked.
def SpatSoftmaxOp : SpatOp<"softmax"> {
  let summary = "Softmax over the full input tensor slice";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.relu: unary operation, one SpatTensor operand and one SpatTensor
// result. No verifier is declared; only the type constraints are checked.
def SpatReluOp : SpatOp<"relu"> {
  let summary = "Element-wise ReLU activation";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
|
|
|
// spat.vmax: binary operation on two SpatTensor operands (`lhs`, `rhs`)
// producing one SpatTensor result. Additional checks live in the C++
// verifier.
def SpatVMaxOp : SpatOp<"vmax"> {
  let summary = "Element-wise max between two tensors";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];

  let hasVerifier = 1;
}
|
|
|
|
#endif // SPATIAL_DIALECT_H
|