standardize spatial and pim dialects

remove old unused stuff
This commit is contained in:
NiccoloN
2026-03-23 21:21:31 +01:00
parent 0478d979ff
commit 93e20c1dfc
18 changed files with 693 additions and 1519 deletions

View File

@@ -16,7 +16,7 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
Op<SpatialDialect, mnemonic, traits>;
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor:
def SpatTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
@@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
let summary = "Virtual channel type";
}
def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute operation, with constant weights already attached";
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute region with attached constant weights";
let arguments = (ins
Variadic<SpatTensor>:$weights,
@@ -49,7 +53,9 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment
}];
}
def SpatYieldOp: SpatOp<"yield", [Terminator]> {
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
let summary = "Yield results from a compute region";
let arguments = (ins
Variadic<SpatTensor>:$outputs
);
@@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> {
}
//===----------------------------------------------------------------------===//
// Data movement operations
// Communication
//===----------------------------------------------------------------------===//
def SpatChannelNewOp: SpatOp<"channel_new", []> {
def SpatChannelNewOp : SpatOp<"channel_new", []> {
let summary = "Create a new virtual channel";
let results = (outs
SpatChannelType:$new_channel
SpatChannelType:$channel
);
let builders = [
@@ -79,108 +87,74 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> {
}];
}
def SpatChannelSendOp: SpatOp<"channel_send", []> {
def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a channel";
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
SpatChannelType:$channel,
SpatTensor:$input
);
let assemblyFormat = [{
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
}];
}
def SpatChannelReceiveOp: SpatOp<"channel_receive", []> {
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a channel";
let arguments = (ins
SpatChannelType: $channel
SpatChannelType:$channel
);
let results = (outs
SpatTensor: $data
SpatTensor:$output
);
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
}];
}
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
let summary = "Broadcast a tensor through a shared channel buffer";
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
SpatChannelType:$channel,
SpatTensor:$input
);
let assemblyFormat = [{
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
}];
}
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
let summary = "Receive a tensor from a shared channel buffer";
let arguments = (ins
SpatChannelType: $channel
SpatChannelType:$channel
);
let results = (outs
SpatTensor: $data
SpatTensor:$output
);
}
//===----------------------------------------------------------------------===//
// Math operations
//===----------------------------------------------------------------------===//
def SpatConstantOp: SpatOp<"constant", []> {
let description = [{
"Constant value, should be used for weights and biases"
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
SpatTensor: $out
);
}
def SpatWeightedVMMOp: SpatOp<"Wvmm", []> {
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatVAddOp: SpatOp<"vadd", []> {
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
I32Attr:$weightIndex,
SpatTensor:$input
);
let results = (outs
@@ -190,76 +164,68 @@ def SpatVAddOp: SpatOp<"vadd", []> {
let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVMulOp: SpatOp<"vmul", []> {
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
}];
}
def SpatVDivOp: SpatOp<"vdiv", []> {
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins
SpatTensor:$a,
SpatTensor:$b
I32Attr:$weightIndex,
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
//TODO: remove
def SpatVSDivOp: SpatOp<"vsdiv", []> {
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
def SpatVAddOp : SpatOp<"vadd", []> {
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
let arguments = (ins
SpatTensor:$dividend,
SpatTensor:$divisor
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatSumOp: SpatOp<"sum", []> {
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
def SpatVMulOp : SpatOp<"vmul", []> {
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
let arguments = (ins
SpatTensor: $input
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatSigmoidOp: SpatOp<"sigmoid", []> {
def SpatSumOp : SpatOp<"sum", []> {
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
let arguments = (ins
SpatTensor:$input
);
@@ -267,9 +233,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> {
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatReluOp: SpatOp<"relu", []> {
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
let summary = "Element-wise sigmoid activation";
let arguments = (ins
SpatTensor:$input
);
@@ -277,68 +249,34 @@ def SpatReluOp: SpatOp<"relu", []> {
let results = (outs
SpatTensor:$output
);
}
def SpatVMaxOp: SpatOp<"vmax", []> {
let summary = "Element-wise max function";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
let description = [{
Applies a variable number of crossbar weights to a single large image tensor tile,
producing a corresponding output tile. This essentially encapsulates a big for loop
over all pixels in the input tile, where each pixel is multiplied by all the weights
in the operation.
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
SpatTensor: $input
);
let results = (outs SpatTensor);
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type(results)
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
//===----------------------------------------------------------------------===//
// Other operations
//===----------------------------------------------------------------------===//
def SpatImgConcatOp: SpatOp<"img_concat", []> {
let summary = "Concatenate pixel tiles into a single image";
let description = [{
Concatenate pixel tiles into a single image:
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
The input tiles should be provided in a specific order:
start from the top left pixel,
then continue with the pixel on its right,
and once you finish the first row of pixels, go to the next row.
}];
def SpatReluOp : SpatOp<"relu", []> {
let summary = "Element-wise ReLU activation";
let arguments = (ins
Variadic<SpatTensor>:$inputs
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVMaxOp : SpatOp<"vmax", []> {
let summary = "Element-wise max between two tensors";
let arguments = (ins
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
@@ -347,9 +285,9 @@ def SpatImgConcatOp: SpatOp<"img_concat", []> {
let hasVerifier = 1;
let extraClassDeclaration = [{
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
#endif // SPATIAL_DIALECT_H
#endif // SPATIAL_DIALECT_H