add PIM accelerator

This commit is contained in:
NiccoloN
2026-02-24 15:09:18 +01:00
parent b24a0df8d7
commit a6e928bdd7
67 changed files with 9109 additions and 1 deletions

View File

@@ -0,0 +1,2 @@
# Build both accelerator dialect subtrees: the low-level PIM dialect and the
# higher-level Spatial dialect.
add_subdirectory(PIM)
add_subdirectory(Spatial)

View File

@@ -0,0 +1,15 @@
# Generate the PIM dialect sources and docs from Pim.td, then build the op
# library (dialect registration + bufferization interface models).
add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td)
add_onnx_mlir_library(PimOps
PimOps.cpp
Transforms/PimBufferizableOpInterface.cpp
DEPENDS
OMPimIncGen
LINK_LIBS PUBLIC
OMMlirDialects
MLIRIR
)

345
src/PIM/Dialect/PIM/Pim.td Normal file
View File

@@ -0,0 +1,345 @@
#ifndef PIM_DIALECT_H
#define PIM_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
// Dialect declaration: every op below uses the `pim.` mnemonic prefix and
// lives in the ::onnx_mlir::pim C++ namespace.
def PimDialect : Dialect {
let name = "pim";
let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars";
let cppNamespace = "::onnx_mlir::pim";
}
// Base class for Pim dialect operations. This operation inherits from the
// base `Op` class in OpBase.td, and provides:
// * The parent dialect of the operation.
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class PimOp<string mnemonic, list<Trait> traits = []> :
Op<PimDialect, mnemonic, traits>;
// Operand/result constraint shared by all PIM ops: a memref (after
// bufferization) or a ranked tensor (before), exposed as ::mlir::ShapedType.
def PimTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
//===----------------------------------------------------------------------===//
// Communication operations
//===----------------------------------------------------------------------===//
// Transfer $src to another core. NOTE(review): the meaning of $size (elements
// vs bytes) and $targetCoreId is inferred from the names — confirm.
def PimSendOp: PimOp<"send", []> {
let arguments = (ins
PimTensor: $src,
I32Attr: $size,
I32Attr: $targetCoreId
);
let assemblyFormat = [{
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
}];
}
// Receive into a destination buffer. Destination-style op: the result $out
// is tied to the $dst init operand via getDpsInitsMutable().
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
let arguments = (ins
PimTensor: $dst,
I32Attr: $size,
I32Attr: $srcCoreId
);
let results = (outs
PimTensor: $out
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
}];
}
//===----------------------------------------------------------------------===//
// Core operations
//===----------------------------------------------------------------------===//
// A single PIM core: carries its crossbar $weights and a single-block $body
// region holding the core's program (pim.halt is the dialect's terminator).
def PimCoreOp: PimOp<"core", [SingleBlock]> {
let regions = (region SizedRegion<1>:$body);
let arguments = (ins
Variadic<PimTensor>:$weights,
I32Attr: $coreId
);
let assemblyFormat = [{
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
}];
}
//===----------------------------------------------------------------------===//
// Memory Operations
//===----------------------------------------------------------------------===//
// Constant materialization in global memory. NOTE(review): $shouldAllocate
// presumably decides whether backing storage is allocated — confirm.
def PimConstantOp: PimOp<"constant", []> {
let description = [{
Allocate a constant value in global memory
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
PimTensor: $out
);
}
// Host -> device copy with offsets into both buffers (units unspecified —
// TODO confirm). DPS: the result is tied to $deviceDst.
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from host memory into device memory
}];
let arguments = (ins
PimTensor: $deviceDst,
PimTensor: $hostSrc,
I32Attr: $deviceDstOffset,
I32Attr: $hostSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $deviceDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceDstMutable();
}
}];
let assemblyFormat = [{
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
}];
}
// Device -> host copy, mirror of memcp_hd. DPS: result tied to $hostDst.
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from device memory into host memory
}];
let arguments = (ins
PimTensor: $hostDst,
PimTensor: $deviceSrc,
I32Attr: $hostDstOffset,
I32Attr: $deviceSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $hostDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getHostDstMutable();
}
}];
let assemblyFormat = [{
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
}];
}
//===----------------------------------------------------------------------===//
// Core.Compute operations
//===----------------------------------------------------------------------===//
// Vector-matrix multiply against a preloaded crossbar weight selected by
// $weightIndex. DPS: result tied to $outBuf.
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{
Vector-matrix multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
}
// Matrix-vector multiply; same operand structure as pim.vmm.
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
let description = [{
Matrix-vector multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
}
// Element-wise add. DPS: result tied to $outBuf.
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
let description = [{
Element-wise addition: c = a + b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
// The ops below use BufferViewFlowOpInterface (populateDependencies is
// defined in PimOps.cpp) rather than DPS like the ops above.
// NOTE(review): the split between the two styles looks intentional but is
// undocumented — confirm.
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise max: c = max(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// Apply a set of convolution filters (selected by $weightIndices and the
// kernel position arrays) to an input tile, using $accumBuf as scratch.
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Apply filters to a tensor
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
PimTensor: $input,
PimTensor: $outBuf,
PimTensor: $accumBuf
);
let results = (outs
PimTensor: $outRes
);
let assemblyFormat = [{
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
}];
}
// Full reduction of $a into a single element written to $outBuf.
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Sum all elements into a single one
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// Vector / scalar division; the scalar divisor is wrapped in a tensor.
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
}];
let arguments = (ins
PimTensor: $dividend,
PimTensor: $divisor,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// Element-wise ReLU.
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise ReLU: c = max(a, 0)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// Element-wise exponential.
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise exp: c = exp(a)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// Terminator for a pim.core body.
def PimHaltOp: PimOp<"halt", [Terminator]> {
let description = [{
Halts the execution of the core
}];
let assemblyFormat = [{
attr-dict
}];
}
#endif // PIM_DIALECT_H

View File

@@ -0,0 +1,49 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
// Registers every PIM op (from the TableGen-generated op list) with the
// dialect when it is loaded into an MLIRContext.
void PimDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
>();
}
// Implements BufferViewFlowOpInterface::populateDependencies for ops whose
// result is derived from their $outBuf operand: the result "view" flows from
// the output buffer. getResult() is valid because all these ops are
// single-result.
#define POPULATE_DEPENDENCIES(OP_NAME) \
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
}
// Exactly the ops that declare BufferViewFlowOpInterface in Pim.td.
POPULATE_DEPENDENCIES(PimVMaxOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp)
} // namespace pim
} // namespace onnx_mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"

View File

@@ -0,0 +1,18 @@
#pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"

View File

@@ -0,0 +1,172 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir {
namespace pim {
// Bufferization model for pim.memcp_hd: swap the tensor operands for their
// backing memrefs and recreate the op with the destination memref's type as
// the result type, carrying over the offset/size attributes.
struct MemCopyHostToDevOpInterface
    : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto copyOp = cast<PimMemCopyHostToDevOp>(op);
    // Resolve the device-side destination buffer.
    auto dstBufferOr = getBuffer(rewriter, copyOp.getDeviceDst(), options, state);
    if (failed(dstBufferOr))
      return failure();
    Value dstBuffer = *dstBufferOr;
    // Resolve the host-side source buffer.
    auto srcBufferOr = getBuffer(rewriter, copyOp.getHostSrc(), options, state);
    if (failed(srcBufferOr))
      return failure();
    Value srcBuffer = *srcBufferOr;
    // Rebuild the op on memrefs.
    replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter, copyOp,
        dstBuffer.getType(), dstBuffer, srcBuffer,
        copyOp.getDeviceDstOffsetAttr(),
        copyOp.getHostSrcOffsetAttr(),
        copyOp.getSizeAttr());
    return success();
  }
};
// Bufferization model for pim.memcp_dh — mirror image of the host-to-device
// model: result type becomes the host destination memref's type.
struct MemCopyDevToHostOpInterface
    : DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto copyOp = cast<PimMemCopyDevToHostOp>(op);
    // Resolve the host-side destination buffer.
    auto dstBufferOr = getBuffer(rewriter, copyOp.getHostDst(), options, state);
    if (failed(dstBufferOr))
      return failure();
    Value dstBuffer = *dstBufferOr;
    // Resolve the device-side source buffer.
    auto srcBufferOr = getBuffer(rewriter, copyOp.getDeviceSrc(), options, state);
    if (failed(srcBufferOr))
      return failure();
    Value srcBuffer = *srcBufferOr;
    // Rebuild the op on memrefs, keeping the offset/size attributes.
    replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter, copyOp,
        dstBuffer.getType(), dstBuffer, srcBuffer,
        copyOp.getHostDstOffsetAttr(),
        copyOp.getDeviceSrcOffsetAttr(),
        copyOp.getSizeAttr());
    return success();
  }
};
// Bufferization model for pim.vmm.
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
// All operands except the DPS init ($outBuf) are read.
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
// Declares a read/write pair non-conflicting only when the write targets
// $outBuf and the read is $vectorInput on an equivalent buffer, i.e. the op
// may consume its input in place.
// NOTE(review): returning true asserts the hardware vmm can safely read and
// overwrite the same buffer — confirm against the accelerator semantics.
bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const {
auto vmmOp = cast<PimVMMOp>(op);
Value readVal = uRead->get();
Value writeVal = uWrite->get();
if (writeVal != vmmOp.getOutBuf())
return false;
if (readVal == vmmOp.getVectorInput())
if (state.areEquivalentBufferizedValues(readVal, writeVal))
return true;
return false;
}
// Replace the tensor op with its memref form.
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto vmmOp = cast<PimVMMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
if (failed(vectorInputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
return success();
}
};
// Bufferization model for pim.mvm.
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
  // Every operand except the DPS init is a read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return !dpsOp.isDpsInit(&opOperand);
  }
  // Replace the tensor op with its memref form: resolve buffers for the
  // vector input and the output buffer, then rebuild the op.
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto mvmOp = cast<PimMVMOp>(op);
    auto inputBufferOr = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
    if (failed(inputBufferOr))
      return failure();
    auto outBufferOr = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
    if (failed(outBufferOr))
      return failure();
    replaceOpWithNewBufferizedOp<PimMVMOp>(rewriter, op,
        outBufferOr->getType(), mvmOp.getWeightIndexAttr(),
        *inputBufferOr, *outBufferOr);
    return success();
  }
};
// Bufferization model for pim.vadd.
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
  // Inputs ($a, $b) are read; the DPS init ($outBuf) is only written.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  // vadd accesses corresponding elements only, which lets the analysis
  // bufferize equivalent operands in place.
  // Fixed: parameter was still named `tablegen_opaque_val` (a leftover of
  // generated code); renamed to `op` for consistency with the other hooks.
  bool bufferizesToElementwiseAccess(Operation* op,
      const AnalysisState& state,
      ArrayRef<OpOperand*> opOperands) const {
    return true;
  }
  // Replace the tensor op with its memref form: resolve buffers for $a, $b
  // and $outBuf, and rebuild the op with the out-buffer type as result type.
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto vaddOp = cast<PimVAddOp>(op);
    auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
    if (failed(aOpt))
      return failure();
    auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
    if (failed(bOpt))
      return failure();
    auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
    if (failed(outBufOpt))
      return failure();
    replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
    return success();
  }
};
// Attaches the external bufferization models defined above to their ops when
// the PIM dialect is loaded.
// NOTE(review): only memcp_hd/memcp_dh/vmm/mvm/vadd get a model here; the
// remaining PIM ops rely on BufferViewFlowOpInterface instead — confirm the
// rest are handled by the one-shot-bufferize pipeline elsewhere.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
});
}
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,14 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,15 @@
# Generate the Spatial dialect sources and docs from Spatial.td, then build
# the op library (dialect registration + bufferization interface models).
add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)
add_onnx_mlir_library(SpatialOps
SpatialOps.cpp
Transforms/SpatialBufferizableOpInterface.cpp
DEPENDS
OMSpatialIncGen
LINK_LIBS PUBLIC
MLIRIR
OMMlirDialects
)

View File

@@ -0,0 +1,355 @@
#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
// Dialect declaration: ops use the `spat.` prefix; types get default
// TableGen-generated printers/parsers.
def SpatialDialect : Dialect {
let name = "spat";
let summary = "Dialect designed for deep learning computation in a spatial architecture";
let cppNamespace = "::onnx_mlir::spatial";
let useDefaultTypePrinterParser = 1;
}
// Base class for all Spatial ops.
class SpatOp<string mnemonic, list<Trait> traits = []> :
Op<SpatialDialect, mnemonic, traits>;
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor:
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
// Base class for Spatial dialect types.
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<SpatialDialect, name, traits> {
let mnemonic = typeMnemonic;
}
// Opaque handle used by the channel_* ops to model point-to-point and
// broadcast communication.
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
let summary = "Virtual channel type";
}
// Compute region with crossbar $weights bound up front; the body must end in
// spat.yield, whose operand types must match the op results (see verifier in
// SpatialOps.cpp). AttrSizedOperandSegments separates $weights from $inputs.
def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute operation, with constant weights already attached";
let arguments = (ins
Variadic<SpatTensor>:$weights,
Variadic<SpatTensor>:$inputs
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let assemblyFormat = [{
`[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
}];
}
// Terminator of a spat.compute body; forwards the region results.
def SpatYieldOp: SpatOp<"yield", [Terminator]> {
let arguments = (ins
Variadic<SpatTensor>:$outputs
);
let assemblyFormat = [{
$outputs attr-dict `:` type($outputs)
}];
}
//===----------------------------------------------------------------------===//
// Data movement operations
//===----------------------------------------------------------------------===//
// Creates a fresh virtual channel; the builder constructs the result type
// directly so no type need be spelled in the assembly.
def SpatChannelNewOp: SpatOp<"channel_new", []> {
let results = (outs
SpatChannelType:$new_channel
);
let builders = [
OpBuilder<(ins ), [{
$_state.addTypes(SpatChannelType());
}]>
];
let assemblyFormat = [{
attr-dict
}];
}
// Point-to-point send of $data over $channel.
def SpatChannelSendOp: SpatOp<"channel_send", []> {
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
);
let assemblyFormat = [{
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
}];
}
// Point-to-point receive from $channel.
def SpatChannelReceiveOp: SpatOp<"channel_receive", []> {
let arguments = (ins
SpatChannelType: $channel
);
let results = (outs
SpatTensor: $data
);
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
}];
}
// Broadcast variants — no custom assembly format (generic form is used).
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
);
}
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
let arguments = (ins
SpatChannelType: $channel
);
let results = (outs
SpatTensor: $data
);
}
//===----------------------------------------------------------------------===//
// Math operations
//===----------------------------------------------------------------------===//
// Constant tensor; mirrors pim.constant. NOTE(review): $shouldAllocate
// semantics inferred from the PIM counterpart — confirm.
def SpatConstantOp: SpatOp<"constant", []> {
let description = [{
"Constant value, should be used for weights and biases"
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
SpatTensor: $out
);
}
// Weighted vector-matrix multiply; matrix comes from the enclosing
// spat.compute's (or pim.core's) weight list — see verifier.
def SpatWeightedVMMOp: SpatOp<"Wvmm", []> {
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
// Weighted matrix-vector multiply; accepts rank-2 or rank-4 shapes — see
// mvmOpVerifySize2/mvmOpVerifySize4 in SpatialOps.cpp.
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
// Element-wise add; operands and result must share a type (see verifier).
def SpatVAddOp: SpatOp<"vadd", []> {
  // Fixed summary typo: was "Tensor b must have the same size of tensor b".
  let summary = "Element-wise add between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1";
  let arguments = (ins
    SpatTensor: $a,
    SpatTensor: $b
  );
  let results = (outs
    SpatTensor:$output
  );
  let hasVerifier = 1;
  let assemblyFormat = [{
    $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
  }];
}
// Element-wise multiply.
def SpatVMulOp: SpatOp<"vmul", []> {
  // Fixed summary typo: was "Tensor b must have the same size of tensor b".
  let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1";
  let arguments = (ins
    SpatTensor: $a,
    SpatTensor: $b
  );
  let results = (outs
    SpatTensor:$output
  );
  //let hasVerifier = 1;
  let assemblyFormat = [{
    $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
  }];
}
// Element-wise divide.
def SpatVDivOp: SpatOp<"vdiv", []> {
  // Fixed summary typo: was "Tensor b must have the same size of tensor b".
  let summary = "Element-wise division between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1";
  let arguments = (ins
    SpatTensor:$a,
    SpatTensor:$b
  );
  let results = (outs
    SpatTensor:$output
  );
  //let hasVerifier = 1;
  let assemblyFormat = [{
    $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
  }];
}
//TODO: remove
// Vector / scalar division; the scalar divisor is wrapped in a tensor.
def SpatVSDivOp: SpatOp<"vsdiv", []> {
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
let arguments = (ins
SpatTensor:$dividend,
SpatTensor:$divisor
);
let results = (outs
SpatTensor:$output
);
}
// Full reduction to a single scalar (wrapped in a tensor).
def SpatSumOp: SpatOp<"sum", []> {
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
let arguments = (ins
SpatTensor: $input
);
let results = (outs
SpatTensor:$output
);
}
// Element-wise sigmoid activation.
def SpatSigmoidOp: SpatOp<"sigmoid", []> {
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
}
// Element-wise ReLU activation.
def SpatReluOp: SpatOp<"relu", []> {
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
}
// Element-wise max; operand/result types must match (see verifier).
def SpatVMaxOp: SpatOp<"vmax", []> {
let summary = "Element-wise max function";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
// Convolution-style op: applies the crossbar weights selected by
// $weightIndices (at the given kernel positions) to an input tile.
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
let description = [{
Applies a variable number of crossbar weights to a single large image tensor tile,
producing a corresponding output tile. This essentially encapsulates a big for loop
over all pixels in the input tile, where each pixel is multiplied by all the weights
in the operation.
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
SpatTensor: $input
);
// NOTE(review): the result is unnamed, unlike every other op in this file
// (which uses $output) — consider naming it for generated-accessor symmetry.
let results = (outs SpatTensor);
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type(results)
}];
}
//===----------------------------------------------------------------------===//
// Other operations
//===----------------------------------------------------------------------===//
// Reassembles per-pixel channel tiles into a full image; tile ordering and
// shape constraints are enforced by SpatImgConcatOp::verify, and
// getInputTile(x, y, tile) maps a pixel/tile coordinate to its operand.
def SpatImgConcatOp: SpatOp<"img_concat", []> {
let summary = "Concatenate pixel tiles into a single image";
let description = [{
Concatenate pixel tiles into a single image:
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
The input tiles should be provided in a specific order:
start from the top left pixel,
then continue with the pixel on its right,
and once you finish the first row of pixels, go to the next row.
}];
let arguments = (ins
Variadic<SpatTensor>:$inputs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let extraClassDeclaration = [{
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
}];
}
#endif // SPATIAL_DIALECT_H

View File

@@ -0,0 +1,339 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include <cstdint>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
// Registers the Spatial dialect's TableGen-generated types and ops when the
// dialect is loaded into an MLIRContext.
void SpatialDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
>();
}
// Shape checks for a rank-2 weighted MVM: matrix (N, M), vector (M, 1),
// output (N, 1). `emitter` is the op used to attach diagnostics.
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
// Verify that the matrix, vector and output shapes have rank 2
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitter->emitError("matrix, vector and output must have rank 2");
// Verify that the matrix shape is (N, M)
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
// Verify that the vector shape is (M, 1)
int64_t vectorM = vectorShape[0];
int64_t vector1 = vectorShape[1];
if (vectorM != M || vector1 != 1)
return emitter->emitError("vector shape must be (M, 1)");
// Verify that the output shape is (N, 1)
int64_t outputN = outputShape[0];
int64_t output1 = outputShape[1];
if (outputN != N || output1 != 1)
return emitter->emitError("output shape must be (N, 1)");
return success();
}
// Shape checks for a rank-4 weighted MVM: matrix (N, M, 1, 1),
// vector (1, M, 1, 1), output (1, N, 1, 1).
// NOTE(review): `ignoreConcatError` is a global flag (presumably from
// PimCompilerOptions.hpp, included above) that relaxes the vector-M check
// for shapes produced by a simplified concat — confirm its intended scope.
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
// Verify that the matrix, vector and output shapes have rank 4
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
return emitter->emitError("matrix, vector and output must have rank 4");
// Verify that the matrix shape is (N, M, 1, 1)
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
int64_t matrix1First = matrixShape[2];
int64_t matrix1Second = matrixShape[3];
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
// Verify that the vector shape is (1, M, 1, 1)
int64_t vector1First = vectorShape[0];
int64_t vectorM = vectorShape[1];
int64_t vector1Second = vectorShape[2];
int64_t vector1Third = vectorShape[3];
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
// Only the M dimension mismatched: tolerated when the concat
// simplification flag is on.
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
// This is ok, it was caused by the simplification of the concat error
}
else {
return emitter->emitError("vector shape must be (1, M, 1, 1)");
}
}
// Verify that the output shape is (1, N, 1, 1)
int64_t output1First = outputShape[0];
int64_t outputN = outputShape[1];
int64_t output1Second = outputShape[2];
int64_t output1Third = outputShape[3];
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
return emitter->emitError("output shape must be (1, N, 1, 1)");
return success();
}
/// Returns the shape of the `weightIndex`-th weight of the weighted op that
/// directly encloses `weightedOp` — either a spat.compute or a pim.core —
/// or failure when it is not nested in one.
/// Fixed: guard against a null parent (dyn_cast on a detached/top-level op
/// would assert) and corrected the `weigthedOp` parameter-name typo.
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
  Operation* parent = weightedOp->getParentOp();
  if (!parent)
    return failure();
  if (auto wcomputeOp = dyn_cast<SpatWeightedCompute>(parent))
    return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
  if (auto coreOp = dyn_cast<pim::PimCoreOp>(parent))
    return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
  return failure();
}
// Verifier for spat.Wmvm: resolves the weight matrix from the enclosing
// weighted op, then delegates to the rank-specific shape check.
LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Two possible accepted shapes:
1. matrix: (N, M); vector: (M, 1); output: (N, 1)
2. matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1)
*/
if (matrixShape.size() == 2)
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
else if (matrixShape.size() == 4)
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
else
return emitError("matrix rank must be 2 or 4");
}
// Verifier for spat.Wvmm: vector (1, N) x matrix (N, M) -> output (1, M).
LogicalResult SpatWeightedVMMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();
  /* Accepted shape:
  1. vector: (1, N); matrix: (N, M); output: (1, M)
  */
  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
  // Vector must be a row vector (1, N).
  int64_t vector1 = vectorShape[0];
  int64_t vectorN = vectorShape[1];
  if (vectorN != N || vector1 != 1)
    // Fixed: message said "(N, 1)" but the check requires shape (1, N).
    return emitError("vector shape must be (1, N)");
  // Output must be a row vector (1, M).
  int64_t output1 = outputShape[0];
  int64_t outputM = outputShape[1];
  if (outputM != M || output1 != 1)
    // Fixed: message said "(M, 1)" but the check requires shape (1, M).
    return emitError("output shape must be (1, M)");
  return success();
}
// Verifier for spat.vadd: requires >= 2 operands, all operands and the
// result sharing one type.
LogicalResult SpatVAddOp::verify() {
// At least two operands
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
// Verifier for spat.vmax: same contract as SpatVAddOp::verify.
LogicalResult SpatVMaxOp::verify() {
// At least two operands
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
// Verifier for spat.img_concat: checks operand count and per-tile shapes
// against the output image dimensions (GET_IMAGE_* macros and crossbarSize
// come from the project headers included above).
LogicalResult SpatImgConcatOp::verify() {
  auto imgShape = mlir::cast<ShapedType>(getType());
  size_t img_w = GET_IMAGE_WIDTH(imgShape);
  size_t img_h = GET_IMAGE_HEIGHT(imgShape);
  size_t img_c = GET_IMAGE_CHANNEL(imgShape);
  // Channels are split into tiles of at most crossbarSize channels each.
  size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
  size_t channelTileRest = img_c % crossbarSize;
  auto operands = getOperands();
  // One operand per (pixel, channel tile).
  if (img_w * img_h * channelTiles != operands.size())
    return emitError("Number of operands does not match output image size");
  // For each output pixel, check that the inputTiles have a correct shape
  for (size_t x = 0; x < img_w; x++) {
    for (size_t y = 0; y < img_h; y++) {
      size_t channel_counts = 0;
      for (size_t t = 0; t < channelTiles; t++) {
        // Fixed: was mlir::cast, which asserts on a non-shaped type and made
        // the error branch below unreachable; dyn_cast returns null instead.
        auto inputShape = mlir::dyn_cast<ShapedType>(getInputTile(x, y, t).getType());
        if (!inputShape)
          return emitError("Invalid input type, must be ShapedType");
        // N == W == H == 1
        if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
          return emitError("Invalid input shape: N,W,H must all be 1");
        size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
        // Check the number of channels in this tile are correct:
        // - CASE1: last tile of pixel, if there is some rest it must match that
        // - CASE2: common case, the channel count is exactly the crossbarSize
        if (t == channelTiles - 1 && channelTileRest != 0) {
          if (inputChannels != channelTileRest)
            return emitError("Invalid channel count for last tile of pixel");
        }
        else {
          if (inputChannels != crossbarSize)
            return emitError("Invalid channel count for some pixel tile");
        }
        channel_counts += inputChannels;
      }
      // Fixed: the diagnostic was emitted but not returned, so verification
      // still succeeded on a channel-count mismatch.
      if (channel_counts != img_c)
        return emitError("Invalid number of channels for some pixel");
    }
  }
  return success();
}
// Verifier for SpatWeightedCompute: the body region must terminate in a single
// SpatYieldOp whose operand types (and tensor encodings) match the op's
// results one-to-one, and every region block argument must have a use.
LogicalResult SpatWeightedCompute::verify() {
  // Check that it has a terminator, it is a yieldOp, and it has a single
  // operand with the same type as the result
  auto& block = getBody().front();
  auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
  if (!yieldOp)
    return emitError("ComputeOp must have a single yield operation");
  auto resultTypes = getResultTypes();
  auto yieldTypes = yieldOp->getOperandTypes();
  if (resultTypes.size() != yieldTypes.size()) {
    return emitError("ComputeOp must have same number of results as yieldOp "
                     "operands");
  }
  // NOTE(review): the pairs are walked in reverse, so the first diagnostic
  // reported is for the *last* result pair — presumably unintentional; the
  // checks themselves are order-independent. TODO confirm.
  for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
    auto resultType = std::get<0>(it);
    auto yieldType = std::get<1>(it);
    // Same type and compatible shape
    if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
      return emitError("ComputeOp output must be of the same type as yieldOp "
                       "operand");
    }
    // Same encoding
    // NOTE(review): resultType == yieldType is already guaranteed by the check
    // above, so the encoding mismatch branches below appear unreachable; kept
    // as defensive checks.
    if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
      if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
        if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
          return emitError("ComputeOp output must have the same encoding as "
                           "yieldOp operand");
        }
      }
      else {
        return emitError("ComputeOp output has an encoding while yieldOp "
                         "operand does not have one");
      }
    }
    else {
      // If result does not have an encoding, yield shouldn't either
      if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
        return emitError("ComputeOp output must not have an encoding if "
                         "yieldOp operand has one");
      }
    }
  }
  // Check that each block argument is used
  for (auto arg : block.getArguments())
    if (arg.use_empty())
      return emitError("ComputeOp block argument is not used");
  return success();
}
/// Returns the operand carrying channel-tile `tile` of the output pixel at
/// (x, y). Operands are laid out tile-major within a pixel, then by x, then
/// by y (row-major over the image).
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
  auto outType = mlir::cast<ShapedType>(getType());
  size_t width = GET_IMAGE_WIDTH(outType);
  size_t height = GET_IMAGE_HEIGHT(outType);
  size_t channels = GET_IMAGE_CHANNEL(outType);
  // Number of channel tiles per pixel, each at most crossbarSize wide.
  size_t tilesPerPixel = ceilIntegerDivide(channels, crossbarSize.getValue());
  assert(tile < tilesPerPixel);
  assert(x < width);
  assert(y < height);
  // Flatten (y, x, tile) into the operand list index.
  size_t flatIndex = ((y * width) + x) * tilesPerPixel + tile;
  return getOperands()[flatIndex];
}
} // namespace spatial
} // namespace onnx_mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"

View File

@@ -0,0 +1,20 @@
#pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.hpp.inc"
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"

View File

@@ -0,0 +1,493 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir {
namespace spatial {
/// Allocates a fresh, uninitialized memref whose shape and element type mirror
/// `resultType` (any ShapedType), at the rewriter's current insertion point.
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
  auto shapedType = cast<ShapedType>(resultType);
  auto allocType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
  // Materialize the buffer via a plain memref.alloc.
  return rewriter.create<memref::AllocOp>(loc, allocType);
}
// Attribute name under which the peer core id of a channel endpoint is cached
// on the Send/Receive op (see getCoreIdOfOtherEndOfChannel below).
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
/// Returns the core id of the pim.core op enclosing the operation at the other
/// end of the channel used by `op` (the matching send for a receive when
/// `opIsReceive` is true, and vice versa). Fails if the peer op cannot be
/// located and no cached value exists.
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  // This function requires the existence of ChannelNewOp and the other
  // Receive/Send operation. However, during bufferization, the first of the
  // Receive/Send operation that is processed gets removed. As such, we need to
  // "precompute" the coreId needed for the other op, and save it as attribute
  auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
  if (precomputedOtherCoreId)
    return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
  // No cached value yet: resolve the peer op through the channel (helper
  // defined elsewhere in this project).
  auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter);
  if (failed(notOpUserOpt))
    return failure();
  Operation* notOpUser = *notOpUserOpt;
  // Save the coreId for this op into the other op as attribute
  // NOTE(review): assumes both endpoints are directly nested in a
  // pim::PimCoreOp — the casts below assert otherwise. TODO confirm.
  auto opCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
  notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr);
  return cast<pim::PimCoreOp>(notOpUser->getParentOp()).getCoreId();
}
// Bufferization model for SpatWeightedCompute: all tensor operands are
// read-only inputs; bufferizing the op amounts to rewriting its region entry
// block signature (the ops inside are handled by their own models).
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {
  // Input tensor to the compute OP are always read into its local memory
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
  // Input tensor to the compute OP are _never_ written into its local memory
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // In general, no tensor is aliased with any other tensor in the compute OP
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }
  // Convert the entry block's tensor arguments to memrefs in place.
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    // Bufferize its block
    auto& block = op->getRegion(0).front();
    return bufferizeBlockSignature(&block, rewriter, options, state);
  }
};
/*
 * This can be used for operations that have a single argument, which is a
 * variadic of tensors, and a single output with the same shape.
 * Example: VAdd, VSub, VExp
 *
 * Bufferization allocates a fresh output memref and rebuilds the op as the
 * destination-style PIM op `ToTy` taking (inputs..., output).
 */
template <typename InterfaceName, typename OpTy, typename ToTy>
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
  // Input tensors to the OP are always read
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
  // Input tensors to the OP are _never_ written
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // In general, no tensor is aliased with any other tensor in the OP
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }
  // Cast tensor values into memref values
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    // Turn Tensor Operands into Memref Operands
    SmallVector<Value> memrefOperands;
    memrefOperands.reserve(op->getNumOperands());
    for (auto operand : op->getOperands()) {
      auto memref = getBuffer(rewriter, operand, options, state);
      if (failed(memref))
        return failure();
      memrefOperands.push_back(*memref);
    }
    // TODO: Support addition with more than 2 operands
    if (memrefOperands.size() > 2) {
      op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
                    "with 1 or 2 operands, for now.");
      return failure();
    }
    // Alloc an output memref
    Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
    // The output buffer is appended as the last (destination) operand.
    memrefOperands.push_back(outputTensor);
    Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
    replaceOpWithBufferizedValues(rewriter, op, newValue);
    return success();
  }
};
/*
 * Shared bufferization model for weighted multiplication ops (VMM / MVM): a
 * single tensor operand, the weight selected by a weightIndex attribute, and
 * a single result. Bufferization allocates an output memref and rebuilds the
 * op as the destination-style PIM op `ToTy`.
 */
template <typename InterfaceName, typename OpTy, typename ToTy>
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
  // Input tensors to the OP are always read
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
  // Input tensors to the OP are _never_ written
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // In general, no tensor is aliased with any other tensor in the OP
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }
  // Cast tensor value into memref value
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
    if (failed(memrefOperandOpt))
      return failure();
    auto memrefOperand = *memrefOperandOpt;
    // Alloc an output memref
    Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
    // Rebuild as the PIM op, forwarding the weight-index attribute.
    Value newValue =
        rewriter
            .create<ToTy>(
                op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor)
            .getOutRes();
    replaceOpWithBufferizedValues(rewriter, op, newValue);
    return success();
  }
};
// Bufferization model for SpatChannelReceiveOp: lowers it to pim.receive into
// a freshly allocated output buffer, targeting the core at the other end of
// the channel.
struct ChannelReceiveOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {
  // Input value is the channel (not read/written, its more of an attribute)
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // See above
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }
  /*
   * Turn the channel receive to pim.recv
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
    // Transfer size in bytes. NOTE(review): assumes the element bit width is a
    // multiple of 8 — sub-byte element types would truncate to 0 here.
    auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
    auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
    if (failed(srcCoreId))
      return failure();
    Value newValue = rewriter
                         .create<pim::PimReceiveOp>(op->getLoc(),
                             outputTensor.getType(),
                             outputTensor,
                             rewriter.getI32IntegerAttr(numElements * elementSize),
                             rewriter.getI32IntegerAttr(srcCoreId.value()))
                         .getOut();
    replaceOpWithBufferizedValues(rewriter, op, newValue);
    return success();
  }
};
// Bufferization model for SpatChannelSendOp: lowers it to pim.send of the
// bufferized source tensor towards the core at the other end of the channel.
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {
  // Operand 0 is the channel (not read/written, more of an attribute);
  // operand 1 is the tensor to send, which is read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // BUGFIX: bufferize() below reads the tensor from operand index 1 (and the
    // original comment says "second input"), but this predicate previously
    // tested index 2, so the sent tensor was never reported as read to the
    // bufferization analysis.
    return opOperand.getOperandNumber() == 1;
  }
  // Neither the channel nor the sent tensor is written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // No operand aliases any result.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }
  /*
   * Turn the channel send to pim.send
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto srcTensor = op->getOperand(1);
    auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcTensorOpt))
      return failure();
    auto srcMemRef = *srcTensorOpt;
    // Transfer size in bytes (assumes byte-aligned element bit width).
    auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;
    auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
    if (failed(dstCoreId))
      return failure();
    replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
        op,
        srcMemRef,
        rewriter.getI32IntegerAttr(numElements * elementSize),
        rewriter.getI32IntegerAttr(dstCoreId.value()));
    return success();
  }
};
// Bufferization model for SpatChannelBroadcastReceiveOp: lowers it to a
// host-to-device copy from a shared staging buffer allocated right after the
// channel-creating op (shared by all broadcast endpoints of that channel).
struct ChannelBroadcastReceiveOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {
  // Input value is the channel (not read/written, its more of an attribute)
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // See above
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }
  /*
   * Turn the channel receive to pim.load using by creating a new global buffer
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
    auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
    auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
    if (!channelNewOp) {
      op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
      return failure();
    }
    // The first 'broadcast' operation creates the buffer just after the
    // channelNewOp, while the other 'broadcast' operations need to find this
    // buffer allocation just after the channelNewOp
    Value bufferAllocation;
    // BUGFIX: getNextNode() returns null when channelNewOp is the last op in
    // its block, and dyn_cast asserts on a null pointer; use dyn_cast_or_null.
    if (auto allocOpAfterChannel = dyn_cast_or_null<memref::AllocOp>(channelNewOp->getNextNode())) {
      // Buffer already allocated, load from this buffer
      bufferAllocation = allocOpAfterChannel;
    }
    else {
      // Buffer was not allocated previously, allocate it after channelNewOp
      rewriter.setInsertionPointAfter(channelNewOp);
      bufferAllocation = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
    }
    rewriter.setInsertionPoint(op);
    auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(),
        outputTensor.getType(),
        outputTensor,
        bufferAllocation,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(outputSize));
    replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
    return success();
  }
};
// Bufferization model for SpatChannelBroadcastSendOp: materializes (or reuses)
// the shared staging buffer right after the channel-creating op and replaces
// the op's results with {staging buffer, bufferized source}.
struct ChannelBroadcastSendOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
  // Operand 0 is the channel (not read/written); operand 1 is the tensor to
  // send, which is read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // BUGFIX: bufferize() below reads the tensor from operand index 1 (and the
    // original comment says "second input"), but this predicate previously
    // tested index 2, so the sent tensor was never reported as read to the
    // bufferization analysis.
    return opOperand.getOperandNumber() == 1;
  }
  // Neither the channel nor the sent tensor is written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // No operand aliases any result.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }
  /*
   * Turn the channel send to pim.send
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    auto srcTensor = op->getOperand(1);
    auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcTensorOpt))
      return failure();
    auto srcMemRef = *srcTensorOpt;
    auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
    if (!channelNewOp) {
      op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
      return failure();
    }
    // The first 'broadcast' operation creates the buffer just after the
    // channelNewOp, while the other 'broadcast' operations need to find this
    // buffer allocation just after the channelNewOp
    Value bufferAllocation;
    // BUGFIX: getNextNode() returns null when channelNewOp is the last op in
    // its block, and dyn_cast asserts on a null pointer; use dyn_cast_or_null.
    if (auto allocOpAfterChannel = dyn_cast_or_null<memref::AllocOp>(channelNewOp->getNextNode())) {
      // Buffer already allocated, load from this buffer
      bufferAllocation = allocOpAfterChannel;
    }
    else {
      // Buffer was not allocated previously, allocate it after channelNewOp
      rewriter.setInsertionPointAfter(channelNewOp);
      bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
    }
    rewriter.setInsertionPoint(op);
    replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
    return success();
  }
};
// Concrete instantiations of the interface templates above, binding each
// Spatial op to the destination-style PIM op it bufferizes to.
struct VAddOpInterfaceFromTemplate
    : VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
  // Only operand 0 ($input) is reported as read; $outBuf and $accumBuf are
  // output buffers.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // Operand 0: $input
    // Operand 1: $outBuf
    // Operand 2: $accumBuf
    return opOperand.getOperandNumber() == 0;
  }
  // Only operand 2 ($accumBuf) is reported as written.
  // NOTE(review): $outBuf (operand 1) is reported as neither read nor written
  // by these predicates — confirm that is intended.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // Operand 0: $input
    // Operand 1: $outBuf
    // Operand 2: $accumBuf
    return opOperand.getOperandNumber() == 2;
  }
  // No operands are aliased with any other operands.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }
  // Bufferize the operation.
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    // Get the input tensor buffer.
    auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
    if (failed(inputBuffer))
      return failure();
    // Create a new buffer for the output tensor.
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
    // Create a new buffer for the accumulation buffer.
    // To do this, create a new allocation operation. Size must be axbx1x1,
    // where axbxcxd is the size of the output tensor. Since the shape is
    // different, we can't immediately use createEmptyFromType, we first need to
    // create the shape of the accumulation buffer.
    auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
    // Set the last two dimensions to 1.
    accumShape[accumShape.size() - 1] = 1;
    accumShape[accumShape.size() - 2] = 1;
    auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
    auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
    // Bufferize the operation: forward the filter description attributes.
    auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
    auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
    auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
    Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(),
        outputTensor.getType(),
        weightIndices,
        xKernelPositions,
        yKernelPositions,
        *inputBuffer,
        outputTensor,
        accumBuffer);
    // Replace the operation with the bufferized value.
    replaceOpWithBufferizedValues(rewriter, op, bufferized);
    return success();
  }
};
/// Attaches the external bufferization models defined above to the Spatial
/// dialect ops. Must be called on the DialectRegistry before bufferization.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
    SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
    SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
    SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
    SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
    SpatSumOp::attachInterface<SumOpInterface>(*ctx);
    SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
    SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
    SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
    SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
    SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
    SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
    SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
  });
}
// ONNX element-wise ops that bufferize directly to PIM ops, reusing the
// variadic element-wise template above.
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
/// Attaches the ONNX-dialect external bufferization models above to the
/// registry (kept separate from the Spatial-dialect registration).
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
    ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
    ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
  });
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,16 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
} // namespace spatial
} // namespace onnx_mlir