standardize spatial and pim dialects
remove old unused stuff
This commit is contained in:
@@ -20,51 +20,18 @@ class PimOp<string mnemonic, list<Trait> traits = []> :
|
||||
def PimTensor :
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimSendOp: PimOp<"send", []> {
|
||||
let arguments = (ins
|
||||
PimTensor: $src,
|
||||
I32Attr: $size,
|
||||
I32Attr: $targetCoreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let arguments = (ins
|
||||
PimTensor: $dst,
|
||||
I32Attr: $size,
|
||||
I32Attr: $srcCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $out
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
|
||||
}];
|
||||
}
|
||||
|
||||
// Core
|
||||
|
||||
def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
||||
def PimCoreOp : PimOp<"core", [SingleBlock]> {
|
||||
let summary = "Execute a block on a PIM core";
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<PimTensor>:$weights,
|
||||
I32Attr: $coreId
|
||||
I32Attr:$coreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -72,412 +39,443 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Memory
|
||||
|
||||
def PimConstantOp: PimOp<"constant", []> {
|
||||
let description = [{
|
||||
Allocate a constant value in global memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyAttr: $value,
|
||||
BoolAttr: $shouldAllocate
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $out
|
||||
);
|
||||
}
|
||||
|
||||
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Copy a memory region from host memory into device memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $deviceDst,
|
||||
PimTensor: $hostSrc,
|
||||
I32Attr: $deviceDstOffset,
|
||||
I32Attr: $hostSrcOffset,
|
||||
I32Attr: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $deviceDstOut
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDeviceDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Copy a memory region from device memory into host memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $hostDst,
|
||||
PimTensor: $deviceSrc,
|
||||
I32Attr: $hostDstOffset,
|
||||
I32Attr: $deviceSrcOffset,
|
||||
I32Attr: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $hostDstOut
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getHostDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Copy a memory region from and to the same memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $dst,
|
||||
PimTensor: $src,
|
||||
I32Attr: $dstOffset,
|
||||
I32Attr: $srcOffset,
|
||||
I32Attr: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $dstOut
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
|
||||
}];
|
||||
}
|
||||
|
||||
// Algebra
|
||||
|
||||
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Matrix transpose
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $data,
|
||||
I64ArrayAttr: $perms,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Vector-matrix multiplication: c = a * b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
PimTensor: $vectorInput,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Matrix-vector multiplication: c = a * b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
PimTensor: $vectorInput,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise addition: c = a + b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise subtraction: c = a - b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise multiplication: c = a * b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise max: c = max(a, b)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Dot product: c = dot(a, b)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Apply filters to a tensor
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I64ArrayAttr: $weightIndices,
|
||||
I64ArrayAttr: $xKernelPositions,
|
||||
I64ArrayAttr: $yKernelPositions,
|
||||
PimTensor: $input,
|
||||
PimTensor: $outBuf,
|
||||
PimTensor: $accumBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
|
||||
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Sum all elements into a single one
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Average all elements into a single one
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise ReLU: c = max(a, 0)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise tanh activation
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise sigmoid activation
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimHaltOp: PimOp<"halt", [Terminator]> {
|
||||
let description = [{
|
||||
Halts the execution of the core
|
||||
}];
|
||||
def PimHaltOp : PimOp<"halt", [Terminator]> {
|
||||
let summary = "Halt execution of the core";
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimSendOp : PimOp<"send", []> {
|
||||
let summary = "Send a tensor to another core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
I32Attr:$size,
|
||||
I32Attr:$targetCoreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive a tensor from another core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
I32Attr:$size,
|
||||
I32Attr:$sourceCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$deviceTarget,
|
||||
PimTensor:$hostSource,
|
||||
I32Attr:$deviceTargetOffset,
|
||||
I32Attr:$hostSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDeviceTargetMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from device memory into host memory";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$hostTarget,
|
||||
PimTensor:$deviceSource,
|
||||
I32Attr:$hostTargetOffset,
|
||||
I32Attr:$deviceSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getHostTargetMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region within the same memory space";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$target,
|
||||
PimTensor:$source,
|
||||
I32Attr:$targetOffset,
|
||||
I32Attr:$sourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getTargetMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimTransposeOp : PimOp<"transpose", [DestinationStyleOpInterface]> {
|
||||
let summary = "Transpose a matrix";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
I64ArrayAttr:$permutation,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let summary = "Vector-matrix multiplication: c = a * b";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
let summary = "Matrix-vector multiplication: c = a * b";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise addition: c = a + b";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$lhs,
|
||||
PimTensor:$rhs,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVSubOp : PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise subtraction: c = a - b";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$lhs,
|
||||
PimTensor:$rhs,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMulOp : PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise multiplication: c = a * b";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$lhs,
|
||||
PimTensor:$rhs,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMaxOp : PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise max: c = max(a, b)";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$lhs,
|
||||
PimTensor:$rhs,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
|
||||
let summary = "Dot product: c = dot(a, b)";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$lhs,
|
||||
PimTensor:$rhs,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
|
||||
let summary = "Reduce all elements to a single value";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
|
||||
let summary = "Average all elements into a single value";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVReluOp : PimOp<"vrelu", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise ReLU: c = max(a, 0)";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVTanhOp : PimOp<"vtanh", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise tanh activation";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise sigmoid activation";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // PIM_DIALECT_H
|
||||
|
||||
@@ -25,19 +25,6 @@ void PimDialect::initialize() {
|
||||
>();
|
||||
}
|
||||
|
||||
#define POPULATE_DEPENDENCIES(OP_NAME) \
|
||||
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
|
||||
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
||||
}
|
||||
|
||||
POPULATE_DEPENDENCIES(PimVVDMulOp)
|
||||
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
||||
POPULATE_DEPENDENCIES(PimSumOp)
|
||||
POPULATE_DEPENDENCIES(PimVAvgOp)
|
||||
POPULATE_DEPENDENCIES(PimVReluOp)
|
||||
POPULATE_DEPENDENCIES(PimVTanhOp)
|
||||
POPULATE_DEPENDENCIES(PimVSigmOp)
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
struct MemCopyHostToDevOpInterface
|
||||
@@ -40,26 +40,26 @@ struct MemCopyHostToDevOpInterface
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
||||
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
|
||||
auto hostSrc = memCopyHostToDevOp.getHostSrc();
|
||||
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
|
||||
auto hostSource = memCopyHostToDevOp.getHostSource();
|
||||
|
||||
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
|
||||
if (failed(deviceDstOpt))
|
||||
auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state);
|
||||
if (failed(deviceTargetOpt))
|
||||
return failure();
|
||||
auto deviceDstMemRef = *deviceDstOpt;
|
||||
auto deviceTargetMemRef = *deviceTargetOpt;
|
||||
|
||||
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
|
||||
if (failed(hostSrcOpt))
|
||||
auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state);
|
||||
if (failed(hostSourceOpt))
|
||||
return failure();
|
||||
auto hostSrcMemRef = *hostSrcOpt;
|
||||
auto hostSourceMemRef = *hostSourceOpt;
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceDstMemRef.getType(),
|
||||
deviceDstMemRef,
|
||||
hostSrcMemRef,
|
||||
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSrcOffsetAttr(),
|
||||
deviceTargetMemRef.getType(),
|
||||
deviceTargetMemRef,
|
||||
hostSourceMemRef,
|
||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -73,25 +73,25 @@ struct MemCopyDevToHostOpInterface
|
||||
BufferizationState& state) const {
|
||||
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
||||
|
||||
auto globalDst = memCopyDevToHostOp.getHostDst();
|
||||
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
|
||||
if (failed(globalDstOpt))
|
||||
auto hostTarget = memCopyDevToHostOp.getHostTarget();
|
||||
auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state);
|
||||
if (failed(hostTargetOpt))
|
||||
return failure();
|
||||
auto globalDstMemRef = *globalDstOpt;
|
||||
auto hostTargetMemRef = *hostTargetOpt;
|
||||
|
||||
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
|
||||
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
|
||||
if (failed(localSrcOpt))
|
||||
auto deviceSource = memCopyDevToHostOp.getDeviceSource();
|
||||
auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state);
|
||||
if (failed(deviceSourceOpt))
|
||||
return failure();
|
||||
auto localSrcMemRef = *localSrcOpt;
|
||||
auto deviceSourceMemRef = *deviceSourceOpt;
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||
memCopyDevToHostOp,
|
||||
globalDstMemRef.getType(),
|
||||
globalDstMemRef,
|
||||
localSrcMemRef,
|
||||
memCopyDevToHostOp.getHostDstOffsetAttr(),
|
||||
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
|
||||
hostTargetMemRef.getType(),
|
||||
hostTargetMemRef,
|
||||
deviceSourceMemRef,
|
||||
memCopyDevToHostOp.getHostTargetOffsetAttr(),
|
||||
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
|
||||
memCopyDevToHostOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -109,16 +109,16 @@ struct TransposeOpBufferizeInterface
|
||||
BufferizationState& state) const {
|
||||
auto transposeOp = cast<PimTransposeOp>(op);
|
||||
|
||||
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
|
||||
if (failed(dataOpt))
|
||||
auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -132,9 +132,9 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
Value readVal = uRead->get();
|
||||
Value writeVal = uWrite->get();
|
||||
if (writeVal != vmmOp.getOutBuf())
|
||||
if (writeVal != vmmOp.getOutputBuffer())
|
||||
return false;
|
||||
if (readVal == vmmOp.getVectorInput())
|
||||
if (readVal == vmmOp.getInput())
|
||||
if (state.areEquivalentBufferizedValues(readVal, writeVal))
|
||||
return true;
|
||||
return false;
|
||||
@@ -146,16 +146,16 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
||||
BufferizationState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
|
||||
if (failed(vectorInputOpt))
|
||||
auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -171,16 +171,16 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
||||
BufferizationState& state) const {
|
||||
auto mvmOp = cast<PimMVMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
|
||||
if (failed(vectorInputOpt))
|
||||
auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -203,22 +203,23 @@ struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<B
|
||||
BufferizationState& state) const {
|
||||
auto binaryOp = cast<OpTy>(op);
|
||||
|
||||
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
|
||||
if (failed(aOpt))
|
||||
auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state);
|
||||
if (failed(lhsOpt))
|
||||
return failure();
|
||||
|
||||
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
|
||||
if (failed(bOpt))
|
||||
auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state);
|
||||
if (failed(rhsOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
|
||||
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
|
||||
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
|
||||
replaceOpWithNewBufferizedOp<OpTy>(
|
||||
rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -16,4 +16,5 @@ def memrefCopyToPimMemCopyOp : Pat<
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
|
||||
@@ -16,7 +16,7 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
|
||||
Op<SpatialDialect, mnemonic, traits>;
|
||||
|
||||
// TODO maybe remove and use AnyRankedTensor directly
|
||||
def SpatTensor:
|
||||
def SpatTensor :
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||
|
||||
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
|
||||
@@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
|
||||
let summary = "Virtual channel type";
|
||||
}
|
||||
|
||||
def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compute operation, with constant weights already attached";
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compute region with attached constant weights";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$weights,
|
||||
@@ -49,7 +53,9 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatYieldOp: SpatOp<"yield", [Terminator]> {
|
||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||
let summary = "Yield results from a compute region";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$outputs
|
||||
);
|
||||
@@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Data movement operations
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatChannelNewOp: SpatOp<"channel_new", []> {
|
||||
def SpatChannelNewOp : SpatOp<"channel_new", []> {
|
||||
let summary = "Create a new virtual channel";
|
||||
|
||||
let results = (outs
|
||||
SpatChannelType:$new_channel
|
||||
SpatChannelType:$channel
|
||||
);
|
||||
|
||||
let builders = [
|
||||
@@ -79,108 +87,74 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendOp: SpatOp<"channel_send", []> {
|
||||
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||
let summary = "Send a tensor through a channel";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel,
|
||||
SpatTensor: $data
|
||||
SpatChannelType:$channel,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
|
||||
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveOp: SpatOp<"channel_receive", []> {
|
||||
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
let summary = "Receive a tensor from a channel";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel
|
||||
SpatChannelType:$channel
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor: $data
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
|
||||
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
|
||||
let summary = "Broadcast a tensor through a shared channel buffer";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel,
|
||||
SpatTensor: $data
|
||||
SpatChannelType:$channel,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
|
||||
let summary = "Receive a tensor from a shared channel buffer";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel
|
||||
SpatChannelType:$channel
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor: $data
|
||||
SpatTensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatConstantOp: SpatOp<"constant", []> {
|
||||
let description = [{
|
||||
"Constant value, should be used for weights and biases"
|
||||
let assemblyFormat = [{
|
||||
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyAttr: $value,
|
||||
BoolAttr: $shouldAllocate
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor: $out
|
||||
);
|
||||
}
|
||||
|
||||
def SpatWeightedVMMOp: SpatOp<"Wvmm", []> {
|
||||
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
|
||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
SpatTensor:$vector
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
SpatTensor:$vector
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
def SpatVAddOp: SpatOp<"vadd", []> {
|
||||
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $a,
|
||||
SpatTensor: $b
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -190,76 +164,68 @@ def SpatVAddOp: SpatOp<"vadd", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVMulOp: SpatOp<"vmul", []> {
|
||||
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $a,
|
||||
SpatTensor: $b
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
//let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVDivOp: SpatOp<"vdiv", []> {
|
||||
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
||||
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$a,
|
||||
SpatTensor:$b
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
//let hasVerifier = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//TODO: remove
|
||||
def SpatVSDivOp: SpatOp<"vsdiv", []> {
|
||||
|
||||
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
|
||||
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$dividend,
|
||||
SpatTensor:$divisor
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatSumOp: SpatOp<"sum", []> {
|
||||
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
|
||||
def SpatVMulOp : SpatOp<"vmul", []> {
|
||||
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $input
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatSigmoidOp: SpatOp<"sigmoid", []> {
|
||||
def SpatSumOp : SpatOp<"sum", []> {
|
||||
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input
|
||||
);
|
||||
@@ -267,9 +233,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> {
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatReluOp: SpatOp<"relu", []> {
|
||||
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||
let summary = "Element-wise sigmoid activation";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input
|
||||
);
|
||||
@@ -277,68 +249,34 @@ def SpatReluOp: SpatOp<"relu", []> {
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
def SpatVMaxOp: SpatOp<"vmax", []> {
|
||||
|
||||
let summary = "Element-wise max function";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $a,
|
||||
SpatTensor: $b
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
|
||||
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
|
||||
let description = [{
|
||||
Applies a variable number of crossbar weights to a single large image tensor tile,
|
||||
producing a corresponding output tile. This essentially encapsulates a big for loop
|
||||
over all pixels in the input tile, where each pixel is multiplied by all the weights
|
||||
in the operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I64ArrayAttr: $weightIndices,
|
||||
I64ArrayAttr: $xKernelPositions,
|
||||
I64ArrayAttr: $yKernelPositions,
|
||||
SpatTensor: $input
|
||||
);
|
||||
let results = (outs SpatTensor);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input) `->` type(results)
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Other operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatImgConcatOp: SpatOp<"img_concat", []> {
|
||||
|
||||
let summary = "Concatenate pixel tiles into a single image";
|
||||
|
||||
let description = [{
|
||||
Concatenate pixel tiles into a single image:
|
||||
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
|
||||
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
|
||||
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
|
||||
|
||||
The input tiles should be provided in a specific order:
|
||||
start from the top left pixel,
|
||||
then continue with the pixel on its right,
|
||||
and once you finish the first row of pixels, go to the next row.
|
||||
}];
|
||||
def SpatReluOp : SpatOp<"relu", []> {
|
||||
let summary = "Element-wise ReLU activation";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$inputs
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVMaxOp : SpatOp<"vmax", []> {
|
||||
let summary = "Element-wise max between two tensors";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -347,9 +285,9 @@ def SpatImgConcatOp: SpatOp<"img_concat", []> {
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // SPATIAL_DIALECT_H
|
||||
#endif // SPATIAL_DIALECT_H
|
||||
|
||||
@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getVector().getType().getShape();
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
/* Two possible accepted shapes:
|
||||
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getVector().getType().getShape();
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
/* Accepted shape:
|
||||
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatImgConcatOp::verify() {
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
size_t channelTileRest = img_c % crossbarSize;
|
||||
|
||||
auto operands = getOperands();
|
||||
|
||||
// Check number of operands
|
||||
if (img_w * img_h * channelTiles != operands.size())
|
||||
return emitError("Number of operands does not match output image size");
|
||||
|
||||
// For each output pixel, check that the inputTiles have a correct shape
|
||||
for (size_t x = 0; x < img_w; x++) {
|
||||
for (size_t y = 0; y < img_h; y++) {
|
||||
size_t channel_counts = 0;
|
||||
for (size_t t = 0; t < channelTiles; t++) {
|
||||
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
|
||||
if (!inputShape)
|
||||
return emitError("Invalid input type, must be ShapedType");
|
||||
|
||||
// N == W == H == 1
|
||||
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
|
||||
return emitError("Invalid input shape: N,W,H must all be 1");
|
||||
|
||||
size_t inputChannels = getImageChannel(inputShape);
|
||||
|
||||
// Check the number of channels in this tile are correct:
|
||||
// - CASE1: last tile of pixel, if there is some rest it must match that
|
||||
// - CASE2: common case, the channel count is exactly the crossbarSize
|
||||
if (t == channelTiles - 1 && channelTileRest != 0) {
|
||||
if (inputChannels != channelTileRest)
|
||||
return emitError("Invalid channel count for last tile of pixel");
|
||||
}
|
||||
else {
|
||||
if (inputChannels != crossbarSize)
|
||||
return emitError("Invalid channel count for some pixel tile");
|
||||
}
|
||||
|
||||
channel_counts += inputChannels;
|
||||
}
|
||||
|
||||
if (channel_counts != img_c)
|
||||
emitError("Invalid number of channels for some pixel");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::verify() {
|
||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||
// operand with the same type as the result
|
||||
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
|
||||
auto operands = getOperands();
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
|
||||
assert(tile < channelTiles);
|
||||
assert(x < img_w);
|
||||
assert(y < img_h);
|
||||
|
||||
return operands[tile + x * channelTiles + y * img_w * channelTiles];
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
@@ -202,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
||||
|
||||
memrefOperands.push_back(outputTensor);
|
||||
|
||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -243,7 +243,7 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
|
||||
cast<OpTy>(op).getWeightIndexAttr(),
|
||||
memrefOperand,
|
||||
outputTensor)
|
||||
.getOutRes();
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -289,7 +289,7 @@ struct ChannelReceiveOpInterface
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOut();
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -403,7 +403,7 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(outputSize));
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
|
||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -490,84 +490,6 @@ struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, S
|
||||
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
// Create a new bufferizable op interface for the apply filters operation.
|
||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
||||
|
||||
// One operand ($input) is read from. All other inputs are only written to.
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
|
||||
// Operand 0: $input
|
||||
// Operand 1: $outBuf
|
||||
// Operand 2: $accumBuf
|
||||
return opOperand.getOperandNumber() == 0;
|
||||
}
|
||||
|
||||
// One input ($accumBuf) is written to. All other inputs are only read.
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
|
||||
// Operand 0: $input
|
||||
// Operand 1: $outBuf
|
||||
// Operand 2: $accumBuf
|
||||
return opOperand.getOperandNumber() == 2;
|
||||
}
|
||||
|
||||
// No operands are aliased with any other operands.
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Bufferize the operation.
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
// Get the input tensor buffer.
|
||||
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
|
||||
|
||||
if (failed(inputBuffer))
|
||||
return failure();
|
||||
|
||||
// Create a new buffer for the output tensor.
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
// Create a new buffer for the accumulation buffer.
|
||||
// To do this, create a new allocation operation. Size must be axbx1x1,
|
||||
// where axbxcxd is the size of the output tensor. Since the shape is
|
||||
// different, we can't immediately use createEmptyFromType, we first need to
|
||||
// create the shape of the accumulation buffer.
|
||||
auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
|
||||
|
||||
// Set the last two dimensions to 1.
|
||||
accumShape[accumShape.size() - 1] = 1;
|
||||
accumShape[accumShape.size() - 2] = 1;
|
||||
|
||||
auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
|
||||
|
||||
auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
|
||||
|
||||
// Bufferize the operation.
|
||||
auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
|
||||
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
|
||||
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
|
||||
|
||||
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
weightIndices,
|
||||
xKernelPositions,
|
||||
yKernelPositions,
|
||||
*inputBuffer,
|
||||
outputTensor,
|
||||
accumBuffer);
|
||||
|
||||
// Replace the operation with the bufferized value.
|
||||
replaceOpWithBufferizedValues(rewriter, op, bufferized);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
|
||||
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
|
||||
@@ -580,7 +502,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
|
||||
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
|
||||
SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user