// Raptor/src/PIM/Dialect/Spatial/Spatial.td
//
// (Residue of the repository file viewer captured with this file:
//  "Files", "T", 2026-06-24 15:52:07 +02:00, 489 lines, 13 KiB, TableGen.)

#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// The Spatial dialect. Ops are spelled `spat.<mnemonic>` in IR, and the
// generated C++ classes are placed in ::onnx_mlir::spatial.
def SpatialDialect : Dialect {
  let name = "spat";
  let cppNamespace = "::onnx_mlir::spatial";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
}
// Base class of every Spatial op: attaches the op to SpatialDialect and
// forwards the mnemonic and trait list unchanged (no traits by default).
class SpatOp<string mnemonic, list<Trait> traits = []>
    : Op<SpatialDialect, mnemonic, traits>;
// Type constraint shared by all Spatial operands/results: any memref or any
// ranked tensor, surfaced in generated C++ as ::mlir::ShapedType. The summary
// is left empty so ODS derives the diagnostic text from the allowed types.
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor
    : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
// Shared definition of the single-region compute ops.
//
// Operands come in two variadic groups, $weights followed by $inputs.
// AttrSizedOperandSegments records how many operands belong to each group.
// The op produces a variadic list of $outputs, and its body is a region with
// exactly one block.
//
// Verification, folding, the textual syntax and block-argument naming
// (getAsmBlockArgumentNames) are all implemented in C++.
class SpatComputeLikeBase<string mnemonic>
    : SpatOp<mnemonic, [
        SingleBlock,
        AttrSizedOperandSegments,
        DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>
      ]> {
  let summary = "Compute region with attached constant weights";

  let arguments = (ins Variadic<SpatTensor>:$weights,
                       Variadic<SpatTensor>:$inputs);
  let results = (outs Variadic<SpatTensor>:$outputs);
  let regions = (region SizedRegion<1>:$body);

  // C++ hooks: parse()/print(), fold(), verify().
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}
// `spat.graph_compute`: a concrete SpatComputeLikeBase op.
//
// It declares extra C++ helpers, all defined outside this file:
//  - block-argument lookup by weight/input index;
//  - insertion of new weights, inputs and outputs;
//  - getCrossbarWeights().
//
// insertOutput returns the op it produces (a SpatGraphCompute) together with
// the new result. NOTE(review): presumably because adding a result requires
// recreating the op — confirm in the implementation.
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
  let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatGraphCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
// `spat.scheduled_compute`: a concrete SpatComputeLikeBase op.
//
// It declares the same C++ helper set as spat.graph_compute, except that
// insertOutput yields a SpatScheduledCompute. All helpers are defined
// outside this file.
def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
  let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatScheduledCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
// Shared definition of the batched compute ops.
//
// Like SpatComputeLikeBase, it takes two variadic operand groups ($weights,
// then $inputs), sized via AttrSizedOperandSegments. It also carries an i32
// $laneCount attribute and has a one-block $body region.
//
// Unlike SpatComputeLikeBase, no folder is declared; verification, the
// textual syntax and block-argument naming are implemented in C++.
class SpatComputeBatchLikeBase<string mnemonic>
    : SpatOp<mnemonic, [
        SingleBlock,
        AttrSizedOperandSegments,
        DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>
      ]> {
  let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";

  let arguments = (ins I32Attr:$laneCount,
                       Variadic<SpatTensor>:$weights,
                       Variadic<SpatTensor>:$inputs);
  let results = (outs Variadic<SpatTensor>:$outputs);
  let regions = (region SizedRegion<1>:$body);

  // C++ hooks: parse()/print(), verify().
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
// `spat.graph_compute_batch`: a concrete SpatComputeBatchLikeBase op.
//
// Compared with the non-batched helpers, it additionally declares:
//  - a lane block-argument accessor;
//  - a per-output block-argument accessor.
//
// Its insertOutput also returns the new output block argument. All helpers
// are defined in C++ outside this file.
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
  let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatGraphComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
// `spat.scheduled_compute_batch`: a concrete SpatComputeBatchLikeBase op.
//
// It declares the same C++ helper set as spat.graph_compute_batch, except
// that insertOutput yields a SpatScheduledComputeBatch. All helpers are
// defined outside this file.
def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
  let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatScheduledComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
// `spat.in_parallel`: a pure terminator implementing InParallelOpInterface.
//
// Its single-block $region has no terminator of its own; the
// GraphRegionNoTerminator traits are appended to the trait list. The ops in
// that region are exposed through getYieldingOps(), and getParentResult()
// maps an index to a result of the enclosing op (both defined in C++).
//
// Default builders are suppressed; only the argument-less builder declared
// below exists, and its body is written in C++.
def SpatInParallelOp : SpatOp<"in_parallel",
    [Pure, Terminator, DeclareOpInterfaceMethods<InParallelOpInterface>]
    # GraphRegionNoTerminator.traits> {
  let summary = "Parallel combining terminator for resultful Spatial compute batches";

  let regions = (region SizedRegion<1>:$region);

  let skipDefaultBuilders = 1;
  let builders = [OpBuilder<(ins)>];

  // C++ hooks: verify(), parse()/print().
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;

  let extraClassDeclaration = [{
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
::mlir::OpResult getParentResult(int64_t idx);
}];
}
// `spat.yield`: terminator forwarding a variadic list of values out of a
// compute region. Its textual syntax is implemented in C++.
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
  let summary = "Yield results from a compute region";
  let hasCustomAssemblyFormat = 1;
  let arguments = (ins Variadic<SpatTensor>:$outputs);
}
// `spat.extract_rows`: splits a single $input into a variadic list of row
// values. The rank/shape requirements stated in the summary are left to the
// C++ verifier; the syntax is custom (C++).
def SpatExtractRowsOp : SpatOp<"extract_rows"> {
  let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";

  let arguments = (ins SpatTensor:$input);
  let results = (outs Variadic<SpatTensor>:$outputs);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
// `spat.concat`: joins a variadic list of $inputs along the i64 $axis
// attribute into one $output. Checked by a C++ verifier and printed with a
// custom (C++) syntax.
def SpatConcatOp : SpatOp<"concat"> {
  let summary = "Concatenate tensors with compact Spatial operand syntax";

  let arguments = (ins I64Attr:$axis,
                       Variadic<SpatTensor>:$inputs);
  let results = (outs SpatTensor:$output);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// Planning
//===----------------------------------------------------------------------===//
// `spat.conv2d_plan`: planning-stage Conv2D.
//
// Takes $input, $weight and an optional $bias. The convolution parameters
// (pads, strides, dilations as i64 arrays; group as i64) and a
// $logicalLayout string are attributes. Uses the generic syntax plus a
// C++ verifier.
def SpatConv2DPlanOp : SpatOp<"conv2d_plan"> {
  let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";

  let results = (outs SpatTensor:$output);
  let arguments = (ins
    SpatTensor:$input,
    SpatTensor:$weight,
    Optional<SpatTensor>:$bias,
    DenseI64ArrayAttr:$pads,
    DenseI64ArrayAttr:$strides,
    DenseI64ArrayAttr:$dilations,
    I64Attr:$group,
    StrAttr:$logicalLayout
  );

  let hasVerifier = 1;
}
// `spat.relu_plan`: planning-stage ReLU over $input, tagged with a
// $logicalLayout string. Uses the generic syntax plus a C++ verifier.
def SpatReluPlanOp : SpatOp<"relu_plan"> {
  let summary = "Layout-aware ReLU planning op";

  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$input,
                       StrAttr:$logicalLayout);

  let hasVerifier = 1;
}
// `spat.reconciliator`: records a logical/physical layout pairing for
// $input. The pairing is given as two layout strings, i64 fragment offset and
// size arrays, and an $indexMap string.
// NOTE(review): the exact encoding of these attributes is defined by the C++
// verifier and its consumers, not here — confirm there.
def SpatReconciliatorOp : SpatOp<"reconciliator"> {
  let summary = "Passive logical-to-physical layout selection record";

  let results = (outs SpatTensor:$output);
  let arguments = (ins
    SpatTensor:$input,
    StrAttr:$logicalLayout,
    StrAttr:$physicalLayout,
    DenseI64ArrayAttr:$fragmentOffsets,
    DenseI64ArrayAttr:$fragmentSizes,
    StrAttr:$indexMap
  );

  let hasVerifier = 1;
}
// `spat.materialize_layout`: moves $input from $sourcePhysicalLayout to
// $targetPhysicalLayout, both under the same $logicalLayout (all three are
// string attributes). Uses the generic syntax plus a C++ verifier.
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout"> {
  let summary = "Explicit layout conversion or materialization barrier";

  let results = (outs SpatTensor:$output);
  let arguments = (ins
    SpatTensor:$input,
    StrAttr:$logicalLayout,
    StrAttr:$sourcePhysicalLayout,
    StrAttr:$targetPhysicalLayout
  );

  let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
// `spat.channel_send`: sends $input over a logical channel. The channel id,
// source core and target core are SSA `index` operands.
//
// Textual form:
//   spat.channel_send %x channel %c from %src to %dst : <type of %x>
//
// No side-effect traits are declared, so MLIR treats the op conservatively
// (unknown effects).
def SpatChannelSendOp : SpatOp<"channel_send"> {
  let summary = "Send a tensor through a logical channel";
  let assemblyFormat = [{
$input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
}];
  let arguments = (ins
    Index:$channelId,
    Index:$sourceCoreId,
    Index:$targetCoreId,
    SpatTensor:$input
  );
}
// `spat.channel_receive`: produces $output from a logical channel identified
// by the `index` operands $channelId, $sourceCoreId and $targetCoreId.
//
// Textual form:
//   %y = spat.channel_receive channel %c from %src to %dst : <type of %y>
//
// No side-effect traits are declared, so the op is not treated as trivially
// dead even when $output is unused.
def SpatChannelReceiveOp : SpatOp<"channel_receive"> {
  let summary = "Receive a tensor from a logical channel";
  let assemblyFormat = [{
`channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
}];
  let results = (outs SpatTensor:$output);
  let arguments = (ins
    Index:$channelId,
    Index:$sourceCoreId,
    Index:$targetCoreId
  );
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
// `spat.wvmm`: vector-matrix product of $input with $weight, checked by a
// C++ verifier.
//
// Textual form:
//   %y = spat.wvmm [%w](%x) : (<type %w>, <type %x>) -> <type %y>
def SpatVMMOp : SpatOp<"wvmm"> {
  let summary = "Vector-matrix multiplication within a weighted compute operation";
  let assemblyFormat = [{
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
}];
  let hasVerifier = 1;
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$weight,
                       SpatTensor:$input);
}
// `spat.vvdmul`: dot product of $lhs and $rhs, checked by a C++ verifier.
//
// Textual form:
//   %y = spat.vvdmul %a, %b : (<type %a>, <type %b>) -> <type %y>
def SpatVVDMulOp : SpatOp<"vvdmul"> {
  let summary = "Dot product between two runtime vectors";
  let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
  let hasVerifier = 1;
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$lhs,
                       SpatTensor:$rhs);
}
// `spat.vadd`: element-wise $lhs + $rhs. A C++ verifier is declared; it
// presumably enforces the "match or 1x1" rule from the summary — confirm in
// verify().
//
// Textual form:
//   %y = spat.vadd %a, %b : (<type %a>, <type %b>) -> <type %y>
def SpatVAddOp : SpatOp<"vadd"> {
  let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
  let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
  let hasVerifier = 1;
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$lhs,
                       SpatTensor:$rhs);
}
// `spat.vsub`: element-wise $lhs - $rhs. A C++ verifier is declared; it
// presumably enforces the "match or 1x1" rule from the summary — confirm in
// verify().
//
// Textual form:
//   %y = spat.vsub %a, %b : (<type %a>, <type %b>) -> <type %y>
def SpatVSubOp : SpatOp<"vsub"> {
  let summary = "Element-wise subtraction between two tensors; rhs must match lhs or be 1x1";
  let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
  let hasVerifier = 1;
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$lhs,
                       SpatTensor:$rhs);
}
// `spat.vmul`: element-wise $lhs * $rhs.
//
// Textual form:
//   %y = spat.vmul %a, %b : (<type %a>, <type %b>) -> <type %y>
//
// NOTE(review): spat.vadd and spat.vsub state the same "rhs must match lhs
// or be 1x1" rule and declare hasVerifier = 1. This op declares no verifier,
// so nothing in this definition enforces the rule. Confirm whether a C++
// verify() (plus hasVerifier = 1) is intended here.
def SpatVMulOp : SpatOp<"vmul"> {
  let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
  let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$lhs,
                       SpatTensor:$rhs);
}
// `spat.vavg`: reduces $input to its mean, returned as a tensor-wrapped
// scalar. No verifier is declared.
//
// Textual form:
//   %y = spat.vavg(%x) : <type %x> -> <type %y>
def SpatVAvgOp : SpatOp<"vavg"> {
  let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
  let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$input);
}
// `spat.sigmoid`: element-wise sigmoid of $input. No verifier is declared.
//
// Textual form:
//   %y = spat.sigmoid(%x) : <type %x> -> <type %y>
def SpatSigmoidOp : SpatOp<"sigmoid"> {
  let summary = "Element-wise sigmoid activation";
  let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$input);
}
// `spat.softmax`: softmax taken over the whole of $input (no axis
// attribute). No verifier is declared.
//
// Textual form:
//   %y = spat.softmax(%x) : <type %x> -> <type %y>
def SpatSoftmaxOp : SpatOp<"softmax"> {
  let summary = "Softmax over the full input tensor slice";
  let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$input);
}
// `spat.relu`: element-wise ReLU of $input. No verifier is declared.
//
// Textual form:
//   %y = spat.relu(%x) : <type %x> -> <type %y>
def SpatReluOp : SpatOp<"relu"> {
  let summary = "Element-wise ReLU activation";
  let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$input);
}
// `spat.vmax`: element-wise maximum of $lhs and $rhs, checked by a C++
// verifier.
//
// Textual form:
//   %y = spat.vmax %a, %b : (<type %a>, <type %b>) -> <type %y>
def SpatVMaxOp : SpatOp<"vmax"> {
  let summary = "Element-wise max between two tensors";
  let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
  let hasVerifier = 1;
  let results = (outs SpatTensor:$output);
  let arguments = (ins SpatTensor:$lhs,
                       SpatTensor:$rhs);
}
#endif // SPATIAL_DIALECT_H