standardize spatial and pim dialects

remove old unused stuff
This commit is contained in:
NiccoloN
2026-03-23 21:21:31 +01:00
parent 0478d979ff
commit 93e20c1dfc
18 changed files with 693 additions and 1519 deletions

View File

@@ -20,51 +20,18 @@ class PimOp<string mnemonic, list<Trait> traits = []> :
def PimTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
// Communication
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
def PimSendOp: PimOp<"send", []> {
let arguments = (ins
PimTensor: $src,
I32Attr: $size,
I32Attr: $targetCoreId
);
let assemblyFormat = [{
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
}];
}
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
let arguments = (ins
PimTensor: $dst,
I32Attr: $size,
I32Attr: $srcCoreId
);
let results = (outs
PimTensor: $out
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
}];
}
// Core
def PimCoreOp: PimOp<"core", [SingleBlock]> {
def PimCoreOp : PimOp<"core", [SingleBlock]> {
let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body);
let arguments = (ins
Variadic<PimTensor>:$weights,
I32Attr: $coreId
I32Attr:$coreId
);
let assemblyFormat = [{
@@ -72,412 +39,443 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
}];
}
// Memory
def PimConstantOp: PimOp<"constant", []> {
let description = [{
Allocate a constant value in global memory
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
PimTensor: $out
);
}
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from host memory into device memory
}];
let arguments = (ins
PimTensor: $deviceDst,
PimTensor: $hostSrc,
I32Attr: $deviceDstOffset,
I32Attr: $hostSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $deviceDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceDstMutable();
}
}];
let assemblyFormat = [{
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
}];
}
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from device memory into host memory
}];
let arguments = (ins
PimTensor: $hostDst,
PimTensor: $deviceSrc,
I32Attr: $hostDstOffset,
I32Attr: $deviceSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $hostDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getHostDstMutable();
}
}];
let assemblyFormat = [{
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
}];
}
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from and to the same memory
}];
let arguments = (ins
PimTensor: $dst,
PimTensor: $src,
I32Attr: $dstOffset,
I32Attr: $srcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $dstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
}];
}
// Algebra
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
let description = [{
Matrix transpose
}];
let arguments = (ins
PimTensor: $data,
I64ArrayAttr: $perms,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{
Vector-matrix multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
let description = [{
Matrix-vector multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
}
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
let description = [{
Element-wise addition: c = a + b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
let description = [{
Element-wise subtraction: c = a - b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
let description = [{
Element-wise multiplication: c = a * b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
let description = [{
Element-wise max: c = max(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Dot product: c = dot(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Apply filters to a tensor
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
PimTensor: $input,
PimTensor: $outBuf,
PimTensor: $accumBuf
);
let results = (outs
PimTensor: $outRes
);
let assemblyFormat = [{
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
}];
}
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Sum all elements into a single one
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Average all elements into a single one
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise ReLU: c = max(a, 0)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise tanh activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise sigmoid activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimHaltOp: PimOp<"halt", [Terminator]> {
let description = [{
Halts the execution of the core
}];
// Terminator of a pim.core region body: marks the end of execution on the core.
// Takes no operands and produces no results; prints as `pim.halt`.
def PimHaltOp : PimOp<"halt", [Terminator]> {
let summary = "Halt execution of the core";
let assemblyFormat = [{
attr-dict
}];
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
// Fire-and-forget transfer of `input` to the core identified by `targetCoreId`.
// Produces no results (the assembly prints an empty result list `( )`).
// NOTE(review): `size` is presumably the transfer size in bytes — confirm
// against the runtime that consumes this op.
def PimSendOp : PimOp<"send", []> {
let summary = "Send a tensor to another core";
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
I32Attr:$targetCoreId
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
}];
}
// Receives data from core `sourceCoreId` into `outputBuffer`.
// Destination-style op: `outputBuffer` is the DPS init operand
// (see getDpsInitsMutable) and `output` is the result tied to it.
// NOTE(review): `size` is presumably the transfer size in bytes — confirm.
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core";
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
I32Attr:$sourceCoreId
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
}];
}
// Copies `size` units from `hostSource` (at `hostSourceOffset`) into
// `deviceTarget` (at `deviceTargetOffset`).
// Destination-style op: `deviceTarget` is the DPS init aliased by `output`.
// NOTE(review): offset/size units (bytes vs. elements) are not stated here —
// confirm against the lowering that consumes these attributes.
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
let arguments = (ins
PimTensor:$deviceTarget,
PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceTargetMutable();
}
}];
let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
}];
}
// Mirror of memcp_hd: copies `size` units from `deviceSource` (at
// `deviceSourceOffset`) into `hostTarget` (at `hostTargetOffset`).
// Destination-style op: `hostTarget` is the DPS init aliased by `output`.
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory";
let arguments = (ins
PimTensor:$hostTarget,
PimTensor:$deviceSource,
I32Attr:$hostTargetOffset,
I32Attr:$deviceSourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getHostTargetMutable();
}
}];
let assemblyFormat = [{
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
}];
}
// Same-memory-space copy: `size` units from `source` (at `sourceOffset`)
// into `target` (at `targetOffset`).
// Destination-style op: `target` is the DPS init aliased by `output`.
def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region within the same memory space";
let arguments = (ins
PimTensor:$target,
PimTensor:$source,
I32Attr:$targetOffset,
I32Attr:$sourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getTargetMutable();
}
}];
let assemblyFormat = [{
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
}];
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
// Transposes `input` according to `permutation` and writes the result into
// `outputBuffer` (the DPS init aliased by `output`).
// NOTE(review): `permutation` is presumably an axis permutation in the usual
// transpose sense — confirm against the verifier/lowering.
def PimTransposeOp : PimOp<"transpose", [DestinationStyleOpInterface]> {
let summary = "Transpose a matrix";
let arguments = (ins
PimTensor:$input,
I64ArrayAttr:$permutation,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Vector-matrix product into `outputBuffer` (DPS init aliased by `output`).
// NOTE(review): `weightIndex` presumably selects a matrix from the enclosing
// pim.core op's `weights` operands — confirm; no verifier is declared here.
def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let summary = "Vector-matrix multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Matrix-vector product; identical structure to vmm, differing only in the
// operand order of the underlying multiplication.
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
let summary = "Matrix-vector multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// The five binary ops below share one shape: (lhs, rhs, outputBuffer) ->
// output, where `outputBuffer` is the DPS init aliased by `output`.
// Element-wise addition.
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
let summary = "Element-wise addition: c = a + b";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Element-wise subtraction.
def PimVVSubOp : PimOp<"vvsub", [DestinationStyleOpInterface]> {
let summary = "Element-wise subtraction: c = a - b";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Element-wise multiplication.
def PimVVMulOp : PimOp<"vvmul", [DestinationStyleOpInterface]> {
let summary = "Element-wise multiplication: c = a * b";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Element-wise maximum.
def PimVVMaxOp : PimOp<"vvmax", [DestinationStyleOpInterface]> {
let summary = "Element-wise max: c = max(a, b)";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Dot product; despite reducing to a scalar it follows the same DPS shape.
def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
let summary = "Dot product: c = dot(a, b)";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// The five ops below share one shape: (input, outputBuffer) -> output,
// where `outputBuffer` is the DPS init aliased by `output`.
// Sum-reduction of all elements.
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
let summary = "Reduce all elements to a single value";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Mean-reduction of all elements.
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
let summary = "Average all elements into a single value";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Element-wise rectified linear unit.
def PimVReluOp : PimOp<"vrelu", [DestinationStyleOpInterface]> {
let summary = "Element-wise ReLU: c = max(a, 0)";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Element-wise hyperbolic tangent.
def PimVTanhOp : PimOp<"vtanh", [DestinationStyleOpInterface]> {
let summary = "Element-wise tanh activation";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
// Element-wise logistic sigmoid.
def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
let summary = "Element-wise sigmoid activation";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
#endif // PIM_DIALECT_H

View File

@@ -25,19 +25,6 @@ void PimDialect::initialize() {
>();
}
#define POPULATE_DEPENDENCIES(OP_NAME) \
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
}
POPULATE_DEPENDENCIES(PimVVDMulOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVAvgOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVTanhOp)
POPULATE_DEPENDENCIES(PimVSigmOp)
} // namespace pim
} // namespace onnx_mlir

View File

@@ -30,7 +30,7 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
.getOutput();
}
struct MemCopyHostToDevOpInterface
@@ -40,26 +40,26 @@ struct MemCopyHostToDevOpInterface
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
auto hostSrc = memCopyHostToDevOp.getHostSrc();
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
auto hostSource = memCopyHostToDevOp.getHostSource();
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
if (failed(deviceDstOpt))
auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state);
if (failed(deviceTargetOpt))
return failure();
auto deviceDstMemRef = *deviceDstOpt;
auto deviceTargetMemRef = *deviceTargetOpt;
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
if (failed(hostSrcOpt))
auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state);
if (failed(hostSourceOpt))
return failure();
auto hostSrcMemRef = *hostSrcOpt;
auto hostSourceMemRef = *hostSourceOpt;
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
memCopyHostToDevOp,
deviceDstMemRef.getType(),
deviceDstMemRef,
hostSrcMemRef,
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
memCopyHostToDevOp.getHostSrcOffsetAttr(),
deviceTargetMemRef.getType(),
deviceTargetMemRef,
hostSourceMemRef,
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
memCopyHostToDevOp.getHostSourceOffsetAttr(),
memCopyHostToDevOp.getSizeAttr());
return success();
}
@@ -73,25 +73,25 @@ struct MemCopyDevToHostOpInterface
BufferizationState& state) const {
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
auto globalDst = memCopyDevToHostOp.getHostDst();
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
if (failed(globalDstOpt))
auto hostTarget = memCopyDevToHostOp.getHostTarget();
auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state);
if (failed(hostTargetOpt))
return failure();
auto globalDstMemRef = *globalDstOpt;
auto hostTargetMemRef = *hostTargetOpt;
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
if (failed(localSrcOpt))
auto deviceSource = memCopyDevToHostOp.getDeviceSource();
auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state);
if (failed(deviceSourceOpt))
return failure();
auto localSrcMemRef = *localSrcOpt;
auto deviceSourceMemRef = *deviceSourceOpt;
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
memCopyDevToHostOp,
globalDstMemRef.getType(),
globalDstMemRef,
localSrcMemRef,
memCopyDevToHostOp.getHostDstOffsetAttr(),
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
hostTargetMemRef.getType(),
hostTargetMemRef,
deviceSourceMemRef,
memCopyDevToHostOp.getHostTargetOffsetAttr(),
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
memCopyDevToHostOp.getSizeAttr());
return success();
}
@@ -109,16 +109,16 @@ struct TransposeOpBufferizeInterface
BufferizationState& state) const {
auto transposeOp = cast<PimTransposeOp>(op);
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
if (failed(dataOpt))
auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt);
return success();
}
};
@@ -132,9 +132,9 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
auto vmmOp = cast<PimVMMOp>(op);
Value readVal = uRead->get();
Value writeVal = uWrite->get();
if (writeVal != vmmOp.getOutBuf())
if (writeVal != vmmOp.getOutputBuffer())
return false;
if (readVal == vmmOp.getVectorInput())
if (readVal == vmmOp.getInput())
if (state.areEquivalentBufferizedValues(readVal, writeVal))
return true;
return false;
@@ -146,16 +146,16 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
if (failed(vectorInputOpt))
auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
return success();
}
};
@@ -171,16 +171,16 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
BufferizationState& state) const {
auto mvmOp = cast<PimMVMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
if (failed(vectorInputOpt))
auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
return success();
}
};
@@ -203,22 +203,23 @@ struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<B
BufferizationState& state) const {
auto binaryOp = cast<OpTy>(op);
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
if (failed(aOpt))
auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state);
if (failed(lhsOpt))
return failure();
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
if (failed(bOpt))
auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state);
if (failed(rhsOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
replaceOpWithNewBufferizedOp<OpTy>(
rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt);
return success();
}
};

View File

@@ -16,4 +16,5 @@ def memrefCopyToPimMemCopyOp : Pat<
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION

View File

@@ -16,7 +16,7 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
Op<SpatialDialect, mnemonic, traits>;
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor:
def SpatTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
@@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
let summary = "Virtual channel type";
}
def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute operation, with constant weights already attached";
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute region with attached constant weights";
let arguments = (ins
Variadic<SpatTensor>:$weights,
@@ -49,7 +53,9 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment
}];
}
def SpatYieldOp: SpatOp<"yield", [Terminator]> {
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
let summary = "Yield results from a compute region";
let arguments = (ins
Variadic<SpatTensor>:$outputs
);
@@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> {
}
//===----------------------------------------------------------------------===//
// Data movement operations
// Communication
//===----------------------------------------------------------------------===//
def SpatChannelNewOp: SpatOp<"channel_new", []> {
def SpatChannelNewOp : SpatOp<"channel_new", []> {
let summary = "Create a new virtual channel";
let results = (outs
SpatChannelType:$new_channel
SpatChannelType:$channel
);
let builders = [
@@ -79,108 +87,74 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> {
}];
}
def SpatChannelSendOp: SpatOp<"channel_send", []> {
def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a channel";
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
SpatChannelType:$channel,
SpatTensor:$input
);
let assemblyFormat = [{
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
}];
}
def SpatChannelReceiveOp: SpatOp<"channel_receive", []> {
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a channel";
let arguments = (ins
SpatChannelType: $channel
SpatChannelType:$channel
);
let results = (outs
SpatTensor: $data
SpatTensor:$output
);
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
}];
}
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
let summary = "Broadcast a tensor through a shared channel buffer";
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
SpatChannelType:$channel,
SpatTensor:$input
);
let assemblyFormat = [{
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
}];
}
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
let summary = "Receive a tensor from a shared channel buffer";
let arguments = (ins
SpatChannelType: $channel
SpatChannelType:$channel
);
let results = (outs
SpatTensor: $data
SpatTensor:$output
);
}
//===----------------------------------------------------------------------===//
// Math operations
//===----------------------------------------------------------------------===//
def SpatConstantOp: SpatOp<"constant", []> {
let description = [{
"Constant value, should be used for weights and biases"
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
SpatTensor: $out
);
}
def SpatWeightedVMMOp: SpatOp<"Wvmm", []> {
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatVAddOp: SpatOp<"vadd", []> {
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
I32Attr:$weightIndex,
SpatTensor:$input
);
let results = (outs
@@ -190,76 +164,68 @@ def SpatVAddOp: SpatOp<"vadd", []> {
let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVMulOp: SpatOp<"vmul", []> {
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
}];
}
def SpatVDivOp: SpatOp<"vdiv", []> {
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins
SpatTensor:$a,
SpatTensor:$b
I32Attr:$weightIndex,
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
//TODO: remove
def SpatVSDivOp: SpatOp<"vsdiv", []> {
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
def SpatVAddOp : SpatOp<"vadd", []> {
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
let arguments = (ins
SpatTensor:$dividend,
SpatTensor:$divisor
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatSumOp: SpatOp<"sum", []> {
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
def SpatVMulOp : SpatOp<"vmul", []> {
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
let arguments = (ins
SpatTensor: $input
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatSigmoidOp: SpatOp<"sigmoid", []> {
def SpatSumOp : SpatOp<"sum", []> {
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
let arguments = (ins
SpatTensor:$input
);
@@ -267,9 +233,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> {
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatReluOp: SpatOp<"relu", []> {
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
let summary = "Element-wise sigmoid activation";
let arguments = (ins
SpatTensor:$input
);
@@ -277,68 +249,34 @@ def SpatReluOp: SpatOp<"relu", []> {
let results = (outs
SpatTensor:$output
);
}
def SpatVMaxOp: SpatOp<"vmax", []> {
let summary = "Element-wise max function";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
let description = [{
Applies a variable number of crossbar weights to a single large image tensor tile,
producing a corresponding output tile. This essentially encapsulates a big for loop
over all pixels in the input tile, where each pixel is multiplied by all the weights
in the operation.
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
SpatTensor: $input
);
let results = (outs SpatTensor);
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type(results)
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
//===----------------------------------------------------------------------===//
// Other operations
//===----------------------------------------------------------------------===//
def SpatImgConcatOp: SpatOp<"img_concat", []> {
let summary = "Concatenate pixel tiles into a single image";
let description = [{
Concatenate pixel tiles into a single image:
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
The input tiles should be provided in a specific order:
start from the top left pixel,
then continue with the pixel on its right,
and once you finish the first row of pixels, go to the next row.
}];
def SpatReluOp : SpatOp<"relu", []> {
let summary = "Element-wise ReLU activation";
let arguments = (ins
Variadic<SpatTensor>:$inputs
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVMaxOp : SpatOp<"vmax", []> {
let summary = "Element-wise max between two tensors";
let arguments = (ins
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
@@ -347,9 +285,9 @@ def SpatImgConcatOp: SpatOp<"img_concat", []> {
let hasVerifier = 1;
let extraClassDeclaration = [{
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
#endif // SPATIAL_DIALECT_H
#endif // SPATIAL_DIALECT_H

View File

@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Two possible accepted shapes:
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Accepted shape:
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize;
auto operands = getOperands();
// Check number of operands
if (img_w * img_h * channelTiles != operands.size())
return emitError("Number of operands does not match output image size");
// For each output pixel, check that the inputTiles have a correct shape
for (size_t x = 0; x < img_w; x++) {
for (size_t y = 0; y < img_h; y++) {
size_t channel_counts = 0;
for (size_t t = 0; t < channelTiles; t++) {
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
if (!inputShape)
return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that
// - CASE2: common case, the channel count is exactly the crossbarSize
if (t == channelTiles - 1 && channelTileRest != 0) {
if (inputChannels != channelTileRest)
return emitError("Invalid channel count for last tile of pixel");
}
else {
if (inputChannels != crossbarSize)
return emitError("Invalid channel count for some pixel tile");
}
channel_counts += inputChannels;
}
if (channel_counts != img_c)
emitError("Invalid number of channels for some pixel");
}
}
return success();
}
LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
return success();
}
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
assert(tile < channelTiles);
assert(x < img_w);
assert(y < img_h);
return operands[tile + x * channelTiles + y * img_w * channelTiles];
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -53,7 +53,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
.getOutput();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
@@ -202,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
memrefOperands.push_back(outputTensor);
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -243,7 +243,7 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
cast<OpTy>(op).getWeightIndexAttr(),
memrefOperand,
outputTensor)
.getOutRes();
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -289,7 +289,7 @@ struct ChannelReceiveOpInterface
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOut();
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -403,7 +403,7 @@ struct ChannelBroadcastReceiveOpInterface
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(outputSize));
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
return success();
}
@@ -490,84 +490,6 @@ struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, S
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
// One operand ($input) is read from. All other inputs are only written to.
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// Operand 0: $input
// Operand 1: $outBuf
// Operand 2: $accumBuf
return opOperand.getOperandNumber() == 0;
}
// One input ($accumBuf) is written to. All other inputs are only read.
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// Operand 0: $input
// Operand 1: $outBuf
// Operand 2: $accumBuf
return opOperand.getOperandNumber() == 2;
}
// No operands are aliased with any other operands.
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Bufferize the operation.
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Get the input tensor buffer.
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(inputBuffer))
return failure();
// Create a new buffer for the output tensor.
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
// Create a new buffer for the accumulation buffer.
// To do this, create a new allocation operation. Size must be axbx1x1,
// where axbxcxd is the size of the output tensor. Since the shape is
// different, we can't immediately use createEmptyFromType, we first need to
// create the shape of the accumulation buffer.
auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
// Set the last two dimensions to 1.
accumShape[accumShape.size() - 1] = 1;
accumShape[accumShape.size() - 2] = 1;
auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
// Bufferize the operation.
auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
weightIndices,
xKernelPositions,
yKernelPositions,
*inputBuffer,
outputTensor,
accumBuffer);
// Replace the operation with the bufferized value.
replaceOpWithBufferizedValues(rewriter, op, bufferized);
return success();
}
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
@@ -580,7 +502,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
});
}