add PIM accelerator
This commit is contained in:
2
src/PIM/Dialect/CMakeLists.txt
Normal file
2
src/PIM/Dialect/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# Sub-dialects of the PIM accelerator.
add_subdirectory(PIM)
add_subdirectory(Spatial)
|
||||
15
src/PIM/Dialect/PIM/CMakeLists.txt
Normal file
15
src/PIM/Dialect/PIM/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# Generate the PIM dialect from TableGen, plus its reference documentation.
add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td)

# C++ library with the op implementations and the bufferization models.
add_onnx_mlir_library(PimOps
  PimOps.cpp
  Transforms/PimBufferizableOpInterface.cpp

  DEPENDS
  OMPimIncGen

  LINK_LIBS PUBLIC
  OMMlirDialects
  MLIRIR
  )
|
||||
345
src/PIM/Dialect/PIM/Pim.td
Normal file
345
src/PIM/Dialect/PIM/Pim.td
Normal file
@@ -0,0 +1,345 @@
|
||||
//===--- Pim.td - PIM dialect op definitions ------------------------------===//
//
// ODS definitions for the `pim` dialect: low-level operations targeting the
// PIM coprocessors on ReRAM crossbars.
//
//===----------------------------------------------------------------------===//
#ifndef PIM_DIALECT_H
#define PIM_DIALECT_H

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"

def PimDialect : Dialect {
  let name = "pim";
  let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars";
  let cppNamespace = "::onnx_mlir::pim";
}

// Base class for Pim dialect operations. This operation inherits from the
// base `Op` class in OpBase.td, and provides:
//   * The parent dialect of the operation.
//   * The mnemonic for the operation, or the name without the dialect prefix.
//   * A list of traits for the operation.
class PimOp<string mnemonic, list<Trait> traits = []> :
    Op<PimDialect, mnemonic, traits>;

// Ops accept either tensors (pre-bufferization) or memrefs (post-bufferization).
def PimTensor :
    AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

//===----------------------------------------------------------------------===//
// Communication operations
//===----------------------------------------------------------------------===//

// Send `$size` elements of `$src` to the core identified by `$targetCoreId`.
def PimSendOp: PimOp<"send", []> {
  let arguments = (ins
    PimTensor: $src,
    I32Attr: $size,
    I32Attr: $targetCoreId
  );

  let assemblyFormat = [{
    `(` $src `)` attr-dict `:` type($src) `->` `(` `)`
  }];
}

// Receive `$size` elements from core `$srcCoreId` into `$dst` (the DPS init);
// `$out` is the result tied to that destination.
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
  let arguments = (ins
    PimTensor: $dst,
    I32Attr: $size,
    I32Attr: $srcCoreId
  );

  let results = (outs
    PimTensor: $out
  );

  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getDstMutable();
    }
  }];

  let assemblyFormat = [{
    `(` $dst `)` attr-dict `:` type($dst) `->` type($out)
  }];
}

//===----------------------------------------------------------------------===//
// Core operations
//===----------------------------------------------------------------------===//

// A single PIM core: holds its crossbar `$weights` and a one-block `$body`
// region with the program executed on core `$coreId`.
def PimCoreOp: PimOp<"core", [SingleBlock]> {

  let regions = (region SizedRegion<1>:$body);

  let arguments = (ins
    Variadic<PimTensor>:$weights,
    I32Attr: $coreId
  );

  let assemblyFormat = [{
    `(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
  }];
}

//===----------------------------------------------------------------------===//
// Memory Operations
//===----------------------------------------------------------------------===//

def PimConstantOp: PimOp<"constant", []> {
  let description = [{
    Allocate a constant value in global memory
  }];

  let arguments = (ins
    AnyAttr: $value,
    BoolAttr: $shouldAllocate
  );

  let results = (outs
    PimTensor: $out
  );
}

def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
  let description = [{
    Copy a memory region from host memory into device memory
  }];

  let arguments = (ins
    PimTensor: $deviceDst,
    PimTensor: $hostSrc,
    I32Attr: $deviceDstOffset,
    I32Attr: $hostSrcOffset,
    I32Attr: $size
  );

  let results = (outs
    PimTensor: $deviceDstOut
  );

  // DPS: the device destination is the init operand tied to $deviceDstOut.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getDeviceDstMutable();
    }
  }];

  let assemblyFormat = [{
    `(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
  }];
}

def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
  let description = [{
    Copy a memory region from device memory into host memory
  }];

  let arguments = (ins
    PimTensor: $hostDst,
    PimTensor: $deviceSrc,
    I32Attr: $hostDstOffset,
    I32Attr: $deviceSrcOffset,
    I32Attr: $size
  );

  let results = (outs
    PimTensor: $hostDstOut
  );

  // DPS: the host destination is the init operand tied to $hostDstOut.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getHostDstMutable();
    }
  }];

  let assemblyFormat = [{
    `(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
  }];
}

//===----------------------------------------------------------------------===//
// Core.Compute operations
//===----------------------------------------------------------------------===//

// The crossbar matrix operand of vmm/mvm is selected by $weightIndex into the
// enclosing core's weight list, not passed as an SSA value.
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
  let description = [{
    Vector-matrix multiplication: c = a * b
  }];

  let arguments = (ins
    I32Attr: $weightIndex,
    PimTensor: $vectorInput,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );

  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutBufMutable();
    }
  }];
}

def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
  let description = [{
    Matrix-vector multiplication: c = a * b
  }];

  let arguments = (ins
    I32Attr: $weightIndex,
    PimTensor: $vectorInput,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );

  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutBufMutable();
    }
  }];
}

def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
  let description = [{
    Element-wise addition: c = a + b
  }];

  let arguments = (ins
    PimTensor: $a,
    PimTensor: $b,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );

  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutBufMutable();
    }
  }];

  let assemblyFormat = [{
    `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
  }];
}

// The ops below implement BufferViewFlowOpInterface instead of DPS; their
// populateDependencies definitions live in PimOps.cpp.
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
  let description = [{
    Element-wise max: c = max(a, b)
  }];

  let arguments = (ins
    PimTensor: $a,
    PimTensor: $b,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );
}

def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
  let description = [{
    Apply filters to a tensor
  }];

  let arguments = (ins
    I64ArrayAttr: $weightIndices,
    I64ArrayAttr: $xKernelPositions,
    I64ArrayAttr: $yKernelPositions,
    PimTensor: $input,
    PimTensor: $outBuf,
    PimTensor: $accumBuf
  );

  let results = (outs
    PimTensor: $outRes
  );

  let assemblyFormat = [{
    `(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
    type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
  }];
}

def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
  let description = [{
    Sum all elements into a single one
  }];

  let arguments = (ins
    PimTensor: $a,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );
}

def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
  let description = [{
    Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
  }];

  let arguments = (ins
    PimTensor: $dividend,
    PimTensor: $divisor,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );
}

def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
  let description = [{
    Element-wise ReLU: c = max(a, 0)
  }];

  let arguments = (ins
    PimTensor: $a,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );
}

def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
  let description = [{
    Element-wise exp: c = exp(a)
  }];

  let arguments = (ins
    PimTensor: $a,
    PimTensor: $outBuf
  );

  let results = (outs
    PimTensor: $outRes
  );
}

// Terminator of a pim.core body.
def PimHaltOp: PimOp<"halt", [Terminator]> {
  let description = [{
    Halts the execution of the core
  }];

  let assemblyFormat = [{
    attr-dict
  }];
}

#endif // PIM_DIALECT_H
|
||||
49
src/PIM/Dialect/PIM/PimOps.cpp
Normal file
49
src/PIM/Dialect/PIM/PimOps.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
|
||||
using namespace mlir;

namespace onnx_mlir {
namespace pim {

/// Register all ops generated from Pim.td with the dialect.
void PimDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
      >();
}

// BufferViewFlowOpInterface hook shared by all ops whose single result
// ($outRes) flows from their $outBuf operand: register the outBuf -> result
// dependency with the buffer-view-flow analysis.
#define POPULATE_DEPENDENCIES(OP_NAME)                                         \
  void OP_NAME::populateDependencies(                                          \
      bufferization::RegisterDependenciesFn registerDependenciesFn) {          \
    registerDependenciesFn(this->getOutBuf(), this->getResult());              \
  }

POPULATE_DEPENDENCIES(PimVMaxOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp)

// Macro hygiene: do not let the helper leak past its uses.
#undef POPULATE_DEPENDENCIES

} // namespace pim
} // namespace onnx_mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
|
||||
18
src/PIM/Dialect/PIM/PimOps.hpp
Normal file
18
src/PIM/Dialect/PIM/PimOps.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
// Public header for the PIM dialect: pulls in the MLIR interfaces the
// generated op classes depend on, then the TableGen'd declarations.
#pragma once

// NOTE(review): <map> and <string> are not referenced by anything visible in
// this header — confirm downstream users before removing.
#include <map>
#include <string>

#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"

/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"

#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"
|
||||
172
src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp
Normal file
172
src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp
Normal file
@@ -0,0 +1,172 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
struct MemCopyHostToDevOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
||||
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
|
||||
auto hostSrc = memCopyHostToDevOp.getHostSrc();
|
||||
|
||||
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
|
||||
if (failed(deviceDstOpt))
|
||||
return failure();
|
||||
auto deviceDstMemRef = *deviceDstOpt;
|
||||
|
||||
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
|
||||
if (failed(hostSrcOpt))
|
||||
return failure();
|
||||
auto hostSrcMemRef = *hostSrcOpt;
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceDstMemRef.getType(),
|
||||
deviceDstMemRef,
|
||||
hostSrcMemRef,
|
||||
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSrcOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyDevToHostOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
||||
|
||||
auto globalDst = memCopyDevToHostOp.getHostDst();
|
||||
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
|
||||
if (failed(globalDstOpt))
|
||||
return failure();
|
||||
auto globalDstMemRef = *globalDstOpt;
|
||||
|
||||
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
|
||||
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
|
||||
if (failed(localSrcOpt))
|
||||
return failure();
|
||||
auto localSrcMemRef = *localSrcOpt;
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||
memCopyDevToHostOp,
|
||||
globalDstMemRef.getType(),
|
||||
globalDstMemRef,
|
||||
localSrcMemRef,
|
||||
memCopyDevToHostOp.getHostDstOffsetAttr(),
|
||||
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
|
||||
memCopyDevToHostOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
Value readVal = uRead->get();
|
||||
Value writeVal = uWrite->get();
|
||||
if (writeVal != vmmOp.getOutBuf())
|
||||
return false;
|
||||
if (readVal == vmmOp.getVectorInput())
|
||||
if (state.areEquivalentBufferizedValues(readVal, writeVal))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
|
||||
if (failed(vectorInputOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
auto mvmOp = cast<PimMVMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
|
||||
if (failed(vectorInputOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
||||
const AnalysisState& state,
|
||||
ArrayRef<OpOperand*> opOperands) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
auto vaddOp = cast<PimVAddOp>(op);
|
||||
|
||||
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
|
||||
if (failed(aOpt))
|
||||
return failure();
|
||||
|
||||
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
|
||||
if (failed(bOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,14 @@
|
||||
#pragma once

#include "mlir/IR/DialectRegistry.h"

#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"

// Fix: the original header had `using namespace mlir;` at header scope, which
// injects the entire mlir namespace into every includer. Qualify instead.
namespace onnx_mlir {
namespace pim {

/// Attach the external bufferization models for the PIM ops to `registry`.
void registerBufferizableOpInterfaceExternalModels(
    mlir::DialectRegistry &registry);

} // namespace pim
} // namespace onnx_mlir
|
||||
15
src/PIM/Dialect/Spatial/CMakeLists.txt
Normal file
15
src/PIM/Dialect/Spatial/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# Generate the Spatial dialect from TableGen, plus its reference documentation.
add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)

# C++ library with the op implementations and the bufferization models.
add_onnx_mlir_library(SpatialOps
  SpatialOps.cpp
  Transforms/SpatialBufferizableOpInterface.cpp

  DEPENDS
  OMSpatialIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  OMMlirDialects
  )
|
||||
355
src/PIM/Dialect/Spatial/Spatial.td
Normal file
355
src/PIM/Dialect/Spatial/Spatial.td
Normal file
@@ -0,0 +1,355 @@
|
||||
//===--- Spatial.td - Spatial dialect op definitions ----------------------===//
//
// ODS definitions for the `spat` dialect: deep-learning computation on a
// spatial architecture.
//
//===----------------------------------------------------------------------===//
#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"

def SpatialDialect : Dialect {
  let name = "spat";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
  let cppNamespace = "::onnx_mlir::spatial";
  let useDefaultTypePrinterParser = 1;
}

class SpatOp<string mnemonic, list<Trait> traits = []> :
    Op<SpatialDialect, mnemonic, traits>;

// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor:
    AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
    : TypeDef<SpatialDialect, name, traits> {
  let mnemonic = typeMnemonic;
}

def SpatChannelType : SpatType<"SpatChannel", "ch"> {
  let summary = "Virtual channel type";
}

def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compute operation, with constant weights already attached";

  let arguments = (ins
    Variadic<SpatTensor>:$weights,
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let regions = (region SizedRegion<1>:$body);

  let hasVerifier = 1;

  let assemblyFormat = [{
    `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
  }];
}

// Terminator of a spat.compute body; forwards the computed values.
def SpatYieldOp: SpatOp<"yield", [Terminator]> {
  let arguments = (ins
    Variadic<SpatTensor>:$outputs
  );

  let assemblyFormat = [{
    $outputs attr-dict `:` type($outputs)
  }];
}

//===----------------------------------------------------------------------===//
// Data movement operations
//===----------------------------------------------------------------------===//

def SpatChannelNewOp: SpatOp<"channel_new", []> {
  let results = (outs
    SpatChannelType:$new_channel
  );

  let builders = [
    OpBuilder<(ins ), [{
      // Fix: `SpatChannelType()` default-constructs a *null* type; the
      // uniqued instance must be obtained through get().
      $_state.addTypes(SpatChannelType::get($_builder.getContext()));
    }]>
  ];

  let assemblyFormat = [{
    attr-dict
  }];
}

def SpatChannelSendOp: SpatOp<"channel_send", []> {
  let arguments = (ins
    SpatChannelType: $channel,
    SpatTensor: $data
  );

  let assemblyFormat = [{
    $data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
  }];
}

def SpatChannelReceiveOp: SpatOp<"channel_receive", []> {
  let arguments = (ins
    SpatChannelType: $channel
  );

  let results = (outs
    SpatTensor: $data
  );

  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($data) `)`
  }];
}

def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
  let arguments = (ins
    SpatChannelType: $channel,
    SpatTensor: $data
  );
}

def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
  let arguments = (ins
    SpatChannelType: $channel
  );

  let results = (outs
    SpatTensor: $data
  );
}

//===----------------------------------------------------------------------===//
// Math operations
//===----------------------------------------------------------------------===//

def SpatConstantOp: SpatOp<"constant", []> {
  let description = [{
    "Constant value, should be used for weights and biases"
  }];

  let arguments = (ins
    AnyAttr: $value,
    BoolAttr: $shouldAllocate
  );

  let results = (outs
    SpatTensor: $out
  );
}

def SpatWeightedVMMOp: SpatOp<"Wvmm", []> {
  let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";

  let arguments = (ins
    I32Attr: $weightIndex,
    SpatTensor:$vector
  );

  let results = (outs
    SpatTensor:$output
  );

  // TODO: Verifier that checks it is within a WeightedCompute operation,
  // that the weightIndex is valid, and that the matrix is of the right size.
  let hasVerifier = 1;
}

def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
  let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";

  let arguments = (ins
    I32Attr: $weightIndex,
    SpatTensor:$vector
  );

  let results = (outs
    SpatTensor:$output
  );

  // TODO: Verifier that checks it is within a WeightedCompute operation,
  // that the weightIndex is valid, and that the matrix is of the right size.
  let hasVerifier = 1;
}

// Typo fix in the summaries below: the size constraint compares b against a.
def SpatVAddOp: SpatOp<"vadd", []> {
  let summary = "Element-wise add between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1";

  let arguments = (ins
    SpatTensor: $a,
    SpatTensor: $b
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let assemblyFormat = [{
    $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
  }];
}

def SpatVMulOp: SpatOp<"vmul", []> {
  let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1";

  let arguments = (ins
    SpatTensor: $a,
    SpatTensor: $b
  );

  let results = (outs
    SpatTensor:$output
  );

  //let hasVerifier = 1;

  let assemblyFormat = [{
    $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
  }];
}

def SpatVDivOp: SpatOp<"vdiv", []> {
  let summary = "Element-wise division between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1";

  let arguments = (ins
    SpatTensor:$a,
    SpatTensor:$b
  );

  let results = (outs
    SpatTensor:$output
  );

  //let hasVerifier = 1;

  let assemblyFormat = [{
    $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
  }];
}

//TODO: remove
def SpatVSDivOp: SpatOp<"vsdiv", []> {

  let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";

  let arguments = (ins
    SpatTensor:$dividend,
    SpatTensor:$divisor
  );

  let results = (outs
    SpatTensor:$output
  );
}

def SpatSumOp: SpatOp<"sum", []> {
  let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";

  let arguments = (ins
    SpatTensor: $input
  );

  let results = (outs
    SpatTensor:$output
  );
}

def SpatSigmoidOp: SpatOp<"sigmoid", []> {
  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );
}

def SpatReluOp: SpatOp<"relu", []> {
  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    SpatTensor:$output
  );
}

def SpatVMaxOp: SpatOp<"vmax", []> {

  let summary = "Element-wise max function";

  let arguments = (ins
    SpatTensor: $a,
    SpatTensor: $b
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
}

def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
  let summary = "Apply multiple crossbar weights to a convolutional input tile.";
  let description = [{
    Applies a variable number of crossbar weights to a single large image tensor tile,
    producing a corresponding output tile. This essentially encapsulates a big for loop
    over all pixels in the input tile, where each pixel is multiplied by all the weights
    in the operation.
  }];

  let arguments = (ins
    I64ArrayAttr: $weightIndices,
    I64ArrayAttr: $xKernelPositions,
    I64ArrayAttr: $yKernelPositions,
    SpatTensor: $input
  );
  // Named for a generated accessor; printed syntax is unchanged.
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    $input attr-dict `:` type($input) `->` type(results)
  }];
}

//===----------------------------------------------------------------------===//
// Other operations
//===----------------------------------------------------------------------===//

def SpatImgConcatOp: SpatOp<"img_concat", []> {

  let summary = "Concatenate pixel tiles into a single image";

  let description = [{
    Concatenate pixel tiles into a single image:
    1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
    2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
    3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).

    The input tiles should be provided in a specific order:
    start from the top left pixel,
    then continue with the pixel on its right,
    and once you finish the first row of pixels, go to the next row.
  }];

  let arguments = (ins
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    mlir::Value getInputTile(size_t x, size_t y, size_t tile);
  }];
}

#endif // SPATIAL_DIALECT_H
|
||||
339
src/PIM/Dialect/Spatial/SpatialOps.cpp
Normal file
339
src/PIM/Dialect/Spatial/SpatialOps.cpp
Normal file
@@ -0,0 +1,339 @@
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
/// Registers the Spatial dialect's generated types and operations. Both lists
/// are emitted by TableGen into the included `.cpp.inc` files.
void SpatialDialect::initialize() {
  // Register all types generated from the dialect's TableGen definitions.
  addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"

      >();
  // Register all operations generated from SpatialOps.td.
  addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"

      >();
}
|
||||
|
||||
/// Verifies the rank-2 layout of a weighted matrix-vector multiply:
/// matrix (N, M), vector (M, 1), output (N, 1). On any violation a
/// diagnostic is emitted on `emitter` and failure is returned.
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
    ArrayRef<int64_t>& matrixShape,
    ArrayRef<int64_t>& vectorShape,
    ArrayRef<int64_t>& outputShape) {

  // All three shapes must be rank 2.
  const bool allRank2 =
      matrixShape.size() == 2 && vectorShape.size() == 2 && outputShape.size() == 2;
  if (!allRank2)
    return emitter->emitError("matrix, vector and output must have rank 2");

  // Matrix must be (N, M) with strictly positive dimensions.
  const int64_t rows = matrixShape[0];
  const int64_t cols = matrixShape[1];
  if (rows <= 0 || cols <= 0)
    return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // Vector must be a column vector (M, 1) matching the matrix columns.
  if (vectorShape[0] != cols || vectorShape[1] != 1)
    return emitter->emitError("vector shape must be (M, 1)");

  // Output must be a column vector (N, 1) matching the matrix rows.
  if (outputShape[0] != rows || outputShape[1] != 1)
    return emitter->emitError("output shape must be (N, 1)");

  return success();
}
|
||||
|
||||
/// Verifies the rank-4 layout of a weighted matrix-vector multiply:
/// matrix (N, M, 1, 1), vector (1, M, 1, 1), output (1, N, 1, 1). Emits a
/// diagnostic on `emitter` and returns failure on any violation.
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
    ArrayRef<int64_t>& matrixShape,
    ArrayRef<int64_t>& vectorShape,
    ArrayRef<int64_t>& outputShape) {

  // Verify that the matrix, vector and output shapes have rank 4
  if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
    return emitter->emitError("matrix, vector and output must have rank 4");

  // Verify that the matrix shape is (N, M, 1, 1)
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  int64_t matrix1First = matrixShape[2];
  int64_t matrix1Second = matrixShape[3];
  if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
    return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  // Verify that the vector shape is (1, M, 1, 1)
  int64_t vector1First = vectorShape[0];
  int64_t vectorM = vectorShape[1];
  int64_t vector1Second = vectorShape[2];
  int64_t vector1Third = vectorShape[3];
  if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
    // Escape hatch: when only the M dimension mismatches (all unit dims are
    // correct) and the global `ignoreConcatError` flag is set, the mismatch is
    // tolerated. NOTE(review): this deliberately accepts vectorM != M —
    // confirm the concat-simplification pass guarantees this is safe.
    if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
      // This is ok, it was caused by the simplification of the concat error
    }
    else {
      return emitter->emitError("vector shape must be (1, M, 1, 1)");
    }
  }

  // Verify that the output shape is (1, N, 1, 1)
  int64_t output1First = outputShape[0];
  int64_t outputN = outputShape[1];
  int64_t output1Second = outputShape[2];
  int64_t output1Third = outputShape[3];
  if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
    return emitter->emitError("output shape must be (1, N, 1, 1)");

  return success();
}
|
||||
|
||||
/// Looks up the shape of the weight tensor at `weightIndex` on the op's
/// parent, which must be either a spatial weighted-compute op or a PIM core
/// op; returns failure when the parent is neither.
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
  Operation* parent = weightedOp->getParentOp();

  if (auto weightedCompute = dyn_cast<SpatWeightedCompute>(parent))
    return cast<ShapedType>(weightedCompute.getWeights()[weightIndex].getType()).getShape();

  if (auto core = dyn_cast<pim::PimCoreOp>(parent))
    return cast<ShapedType>(core.getWeights()[weightIndex].getType()).getShape();

  // Not nested in a weight-carrying parent: let the caller report the error.
  return failure();
}
|
||||
|
||||
/// Verifier for the weighted matrix-vector multiply. The weight matrix lives
/// on the enclosing weighted-compute/core op and is located via the op's
/// weight index; the vector operand and output are checked against it.
LogicalResult SpatWeightedMVMOp::verify() {
  auto matrixShapeOrFailure = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOrFailure))
    return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");

  auto matrixShape = *matrixShapeOrFailure;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  // Two accepted layouts, dispatched on the matrix rank:
  //   rank 2: matrix (N, M);       vector (M, 1);       output (N, 1)
  //   rank 4: matrix (N, M, 1, 1); vector (1, M, 1, 1); output (1, N, 1, 1)
  switch (matrixShape.size()) {
  case 2:
    return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
  case 4:
    return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
  default:
    return emitError("matrix rank must be 2 or 4");
  }
}
|
||||
|
||||
/// Verifier for the weighted vector-matrix multiply. The weight matrix lives
/// on the enclosing weighted-compute/core op and is located via the op's
/// weight index; the vector operand and output are checked against it.
LogicalResult SpatWeightedVMMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  /* Accepted shape:
     1. vector: (1, N); matrix: (N, M); output: (1, M)
  */
  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  // Matrix must be (N, M) with strictly positive dimensions.
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // Vector is a row vector (1, N): its second dim must match the matrix rows.
  // (Fixed: the diagnostic previously said "(N, 1)", contradicting the check.)
  int64_t vector1 = vectorShape[0];
  int64_t vectorN = vectorShape[1];
  if (vectorN != N || vector1 != 1)
    return emitError("vector shape must be (1, N)");

  // Output is a row vector (1, M): its second dim must match the matrix cols.
  // (Fixed: the diagnostic previously said "(M, 1)", contradicting the check.)
  int64_t output1 = outputShape[0];
  int64_t outputM = outputShape[1];
  if (outputM != M || output1 != 1)
    return emitError("output shape must be (1, M)");

  return success();
}
|
||||
|
||||
/// Verifier for the element-wise vector add: requires at least two operands,
/// all operands and the result sharing one common type. Both conditions are
/// delegated to the generic trait verifiers, which emit their own diagnostics.
LogicalResult SpatVAddOp::verify() {
  if (succeeded(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return OpTrait::impl::verifySameOperandsAndResultType(*this);
  return failure();
}
|
||||
|
||||
/// Verifier for the element-wise vector max: requires at least two operands,
/// all operands and the result sharing one common type. Both conditions are
/// delegated to the generic trait verifiers, which emit their own diagnostics.
LogicalResult SpatVMaxOp::verify() {
  if (succeeded(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return OpTrait::impl::verifySameOperandsAndResultType(*this);
  return failure();
}
|
||||
|
||||
/// Verifier for img_concat: checks that the operand count matches the output
/// image geometry and that every per-pixel channel tile has the expected
/// shape (N = W = H = 1, channel count equal to the crossbar size, except for
/// a possibly smaller final tile holding the channel remainder).
LogicalResult SpatImgConcatOp::verify() {
  auto imgShape = mlir::cast<ShapedType>(getType());
  size_t img_w = GET_IMAGE_WIDTH(imgShape);
  size_t img_h = GET_IMAGE_HEIGHT(imgShape);
  size_t img_c = GET_IMAGE_CHANNEL(imgShape);

  // Channels are split into tiles of at most `crossbarSize` channels; the
  // final tile of each pixel carries the remainder when it is non-zero.
  size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
  size_t channelTileRest = img_c % crossbarSize;

  auto operands = getOperands();

  // Check number of operands: one tile per (pixel, channel tile) pair.
  if (img_w * img_h * channelTiles != operands.size())
    return emitError("Number of operands does not match output image size");

  // For each output pixel, check that the inputTiles have a correct shape
  for (size_t x = 0; x < img_w; x++) {
    for (size_t y = 0; y < img_h; y++) {
      size_t channel_counts = 0;
      for (size_t t = 0; t < channelTiles; t++) {
        // Fixed: was `cast`, which asserts on a non-shaped type instead of
        // letting the verifier report a diagnostic; `dyn_cast` makes the
        // null check below meaningful.
        auto inputShape = mlir::dyn_cast<ShapedType>(getInputTile(x, y, t).getType());
        if (!inputShape)
          return emitError("Invalid input type, must be ShapedType");

        // Each tile is a single pixel: N == W == H == 1.
        if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
          return emitError("Invalid input shape: N,W,H must all be 1");

        size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);

        // Check the number of channels in this tile are correct:
        // - CASE1: last tile of pixel, if there is some rest it must match that
        // - CASE2: common case, the channel count is exactly the crossbarSize
        if (t == channelTiles - 1 && channelTileRest != 0) {
          if (inputChannels != channelTileRest)
            return emitError("Invalid channel count for last tile of pixel");
        } else {
          if (inputChannels != crossbarSize)
            return emitError("Invalid channel count for some pixel tile");
        }

        channel_counts += inputChannels;
      }

      // Fixed: the `return` was missing here, so the diagnostic was emitted
      // but verification still succeeded.
      if (channel_counts != img_c)
        return emitError("Invalid number of channels for some pixel");
    }
  }

  return success();
}
|
||||
|
||||
/// Verifier for the weighted-compute region op. Checks that the body block is
/// terminated by a spatial yield whose operand types match the op's results
/// one-to-one (same type, compatible shape, same tensor encoding), and that
/// every block argument is actually used inside the body.
LogicalResult SpatWeightedCompute::verify() {
  // Check that it has a terminator, it is a yieldOp, and it has a single
  // operand with the same type as the result
  auto& block = getBody().front();
  auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
  if (!yieldOp)
    return emitError("ComputeOp must have a single yield operation");

  // Results and yield operands must pair up one-to-one.
  auto resultTypes = getResultTypes();
  auto yieldTypes = yieldOp->getOperandTypes();
  if (resultTypes.size() != yieldTypes.size()) {
    return emitError("ComputeOp must have same number of results as yieldOp "
                     "operands");
  }

  // NOTE(review): iteration order (reverse) does not affect the outcome since
  // any mismatch fails the whole verification.
  for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
    auto resultType = std::get<0>(it);
    auto yieldType = std::get<1>(it);

    // Same type and compatible shape
    if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
      return emitError("ComputeOp output must be of the same type as yieldOp "
                       "operand");
    }

    // Same encoding. NOTE(review): the branches below actually distinguish
    // whether each side is a RankedTensorType, while the diagnostics speak of
    // "has an encoding" — confirm that conflation is intended.
    if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
      if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
        if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
          return emitError("ComputeOp output must have the same encoding as "
                           "yieldOp operand");
        }
      }
      else {
        return emitError("ComputeOp output has an encoding while yieldOp "
                         "operand does not have one");
      }
    }
    else {
      // If result does not have an encoding, yield shouldn't either
      if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
        return emitError("ComputeOp output must not have an encoding if "
                         "yieldOp operand has one");
      }
    }
  }

  // Check that each block argument is used
  for (auto arg : block.getArguments())
    if (arg.use_empty())
      return emitError("ComputeOp block argument is not used");

  return success();
}
|
||||
|
||||
/// Returns the operand holding channel tile `tile` of the pixel at (x, y).
/// Operands are laid out row-major over pixels, with all channel tiles of one
/// pixel stored consecutively: index = (y * width + x) * tilesPerPixel + tile.
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
  auto shape = mlir::cast<ShapedType>(getType());
  const size_t width = GET_IMAGE_WIDTH(shape);
  const size_t height = GET_IMAGE_HEIGHT(shape);
  const size_t channels = GET_IMAGE_CHANNEL(shape);
  const size_t tilesPerPixel = ceilIntegerDivide(channels, crossbarSize.getValue());

  // Bounds are a programmer contract, not user input — assert, don't diagnose.
  assert(x < width);
  assert(y < height);
  assert(tile < tilesPerPixel);

  const size_t flatIndex = (y * width + x) * tilesPerPixel + tile;
  return getOperands()[flatIndex];
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
|
||||
20
src/PIM/Dialect/Spatial/SpatialOps.hpp
Normal file
20
src/PIM/Dialect/Spatial/SpatialOps.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.hpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
|
||||
@@ -0,0 +1,493 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
/// Allocates a fresh memref whose shape and element type mirror `resultType`.
/// Despite the "Empty" name this emits a memref.alloc, not a tensor.empty.
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
  auto shaped = cast<ShapedType>(resultType);
  auto bufferType = MemRefType::get(shaped.getShape(), shaped.getElementType());
  return rewriter.create<memref::AllocOp>(loc, bufferType);
}
}
|
||||
|
||||
// Attribute under which a send/receive op caches the core id of its channel
// peer before that peer is erased by bufferization.
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");

/// Returns the core id of the op at the other end of this op's channel.
/// `opIsReceive` selects which direction to search. Fails when the channel or
/// its peer cannot be located and no cached id exists.
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {

  // This function requires the existence of ChannelNewOp and the other
  // Receive/Send operation. However, during bufferization, the first of the
  // Receive/Send operation that is processed gets removed. As such, we need to
  // "precompute" the coreId needed for the other op, and save it as attribute
  auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
  if (precomputedOtherCoreId)
    return cast<IntegerAttr>(precomputedOtherCoreId).getInt();

  // No cached value: walk the channel to find the peer op.
  auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter);
  if (failed(notOpUserOpt))
    return failure();
  Operation* notOpUser = *notOpUserOpt;

  // Save the coreId for this op into the other op as attribute
  // (so the peer can still resolve it after this op has been erased).
  auto opCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
  notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr);

  // The peer's core id comes from its own enclosing PIM core op.
  return cast<pim::PimCoreOp>(notOpUser->getParentOp()).getCoreId();
}
|
||||
|
||||
/// Bufferization model for the weighted-compute region op: the op itself owns
/// no buffers; bufferization only rewrites the body block's tensor signature.
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {

  // Input tensor to the compute OP are always read into its local memory
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }

  // Input tensor to the compute OP are _never_ written into its local memory
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // In general, no tensor is aliased with any other tensor in the compute OP
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  // The op's only bufferization work is converting the body block's tensor
  // arguments to memrefs; the nested ops bufferize through their own models.
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    // Bufferize its block

    auto& block = op->getRegion(0).front();

    return bufferizeBlockSignature(&block, rewriter, options, state);
  }
};
|
||||
|
||||
/*
|
||||
* This can be used for operation that have a single argument, which is a
|
||||
* variadic of tensors, and a single output with the same same shape
|
||||
* Example: VAdd, VSub, VExp
|
||||
*/
|
||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
||||
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
||||
|
||||
// Input tensors to the OP are always read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
// Input tensors to the OP are _never_ written
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// In general, no tensor is aliased with any other tensor in the OP
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Cast tensor values into memref values
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
|
||||
// Turn Tensor Operands into Memref Operands
|
||||
SmallVector<Value> memrefOperands;
|
||||
memrefOperands.reserve(op->getNumOperands());
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto memref = getBuffer(rewriter, operand, options, state);
|
||||
if (failed(memref))
|
||||
return failure();
|
||||
memrefOperands.push_back(*memref);
|
||||
}
|
||||
|
||||
// TODO: Support addiction with more than 2 operands
|
||||
if (memrefOperands.size() > 2) {
|
||||
op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
|
||||
"with 1 or 2 operands, for now.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Alloc an output memref
|
||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
memrefOperands.push_back(outputTensor);
|
||||
|
||||
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
||||
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
||||
|
||||
// Input tensors to the OP are always read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
// Input tensors to the OP are _never_ written
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// In general, no tensor is aliased with any other tensor in the OP
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Cast tensor value into memref value
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
|
||||
if (failed(memrefOperandOpt))
|
||||
return failure();
|
||||
auto memrefOperand = *memrefOperandOpt;
|
||||
|
||||
// Alloc an output memref
|
||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
Value newValue =
|
||||
rewriter
|
||||
.create<ToTy>(
|
||||
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor)
|
||||
.getOutRes();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization model for a point-to-point channel receive: lowered to a
/// pim.recv into a freshly allocated buffer, with the source core id resolved
/// from the other end of the channel.
struct ChannelReceiveOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {

  // Input value is the channel (not read/written, its more of an attribute)
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Turn the channel receive to pim.recv
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {

    // Destination buffer sized like the op's result.
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);

    // Transfer size in bytes.
    // NOTE(review): assumes the element bit width is a multiple of 8 —
    // sub-byte element types would yield size 0; confirm if those can occur.
    auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;

    // Resolve the sender's core id from the other end of the channel.
    auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
    if (failed(srcCoreId))
      return failure();

    Value newValue = rewriter
                         .create<pim::PimReceiveOp>(op->getLoc(),
                             outputTensor.getType(),
                             outputTensor,
                             rewriter.getI32IntegerAttr(numElements * elementSize),
                             rewriter.getI32IntegerAttr(srcCoreId.value()))
                         .getOut();

    replaceOpWithBufferizedValues(rewriter, op, newValue);

    return success();
  }
};
|
||||
|
||||
/// Bufferization model for a point-to-point channel send: lowered to a
/// pim.send of the operand buffer, with the destination core id resolved from
/// the other end of the channel.
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {

  // First input is the channel (not read/written); second input is the tensor
  // to send, which is read.
  // NOTE(review): this tests operand number 2, yet bufferize() below reads the
  // tensor from getOperand(1) — verify against the op's ODS whether the tensor
  // really sits at operand index 2 or this is an off-by-one.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return opOperand.getOperandNumber() == 2;
  }

  // See above (both non-written)
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Turn the channel send to pim.send
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    // The tensor being sent.
    auto srcTensor = op->getOperand(1);

    auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcTensorOpt))
      return failure();
    auto srcMemRef = *srcTensorOpt;

    // Transfer size in bytes.
    // NOTE(review): assumes the element bit width is a multiple of 8.
    auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;

    // Resolve the receiver's core id from the other end of the channel.
    auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
    if (failed(dstCoreId))
      return failure();

    replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
        op,
        srcMemRef,
        rewriter.getI32IntegerAttr(numElements * elementSize),
        rewriter.getI32IntegerAttr(dstCoreId.value()));

    return success();
  }
};
|
||||
|
||||
/// Bufferization model for a broadcast receive: lowered to a host-to-device
/// copy out of a shared host buffer that lives directly after the channel's
/// defining op (allocated by whichever broadcast endpoint bufferizes first).
struct ChannelBroadcastReceiveOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {

  // Input value is the channel (not read/written, its more of an attribute)
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Turn the channel receive to pim.load using by creating a new global buffer
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {

    // Device-side destination buffer sized like the op's result.
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);

    // NOTE(review): this size is an element count, while the send/receive
    // lowerings above pass sizes in bytes — confirm the unit expected by
    // PimMemCopyHostToDevOp's size attribute.
    auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();

    // The channel operand must come directly from a channel-new op so the
    // shared host buffer can be anchored next to it.
    auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
    if (!channelNewOp) {
      op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
      return failure();
    }

    // The first 'broadcast' operation creates the buffer just after the
    // channelNewOp, while the other 'broadcast' operation need to find this
    // buffer allocation just after the channelNewOp
    Value bufferAllocation;
    if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
      // Buffer already allocated, load from this buffer
      bufferAllocation = allocOpAfterChannel;
    }
    else {
      // Buffer was not allocated previously, allocate it after channelNewOp
      rewriter.setInsertionPointAfter(channelNewOp);
      bufferAllocation = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
    }

    // Emit the copy at the original receive's position.
    rewriter.setInsertionPoint(op);
    auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(),
        outputTensor.getType(),
        outputTensor,
        bufferAllocation,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(outputSize));

    replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());

    return success();
  }
};
|
||||
|
||||
/// Bufferization model for a broadcast send: the sent tensor's buffer is
/// paired with the shared host buffer anchored after the channel's defining
/// op (allocated by whichever broadcast endpoint bufferizes first).
struct ChannelBroadcastSendOpInterface
    : BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {

  // First input is the channel (not read/written); second input is the tensor
  // to send, which is read.
  // NOTE(review): this tests operand number 2, yet bufferize() below reads the
  // tensor from getOperand(1) — verify against the op's ODS whether the tensor
  // really sits at operand index 2 or this is an off-by-one.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return opOperand.getOperandNumber() == 2;
  }

  // See above (both non-written)
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // See above
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /*
   * Turn the channel send to pim.send
   */
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
    // The tensor being broadcast.
    auto srcTensor = op->getOperand(1);

    auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcTensorOpt))
      return failure();
    auto srcMemRef = *srcTensorOpt;

    // The channel operand must come directly from a channel-new op so the
    // shared host buffer can be anchored next to it.
    auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
    if (!channelNewOp) {
      op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
      return failure();
    }

    // The first 'broadcast' operation creates the buffer just after the
    // channelNewOp, while the other 'broadcast' operation need to find this
    // buffer allocation just after the channelNewOp
    Value bufferAllocation;
    if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
      // Buffer already allocated, load from this buffer
      bufferAllocation = allocOpAfterChannel;
    }
    else {
      // Buffer was not allocated previously, allocate it after channelNewOp
      rewriter.setInsertionPointAfter(channelNewOp);
      bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
    }

    // Replace the op's results with the shared buffer and the source memref;
    // no explicit copy op is emitted here (unlike the broadcast receive).
    rewriter.setInsertionPoint(op);
    replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
    return success();
  }
};
|
||||
|
||||
struct VAddOpInterfaceFromTemplate
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
|
||||
|
||||
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
||||
|
||||
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
||||
|
||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
||||
|
||||
// Vector-scalar division: bufferizes SpatVSDivOp to pim::PimVSDivOp through
// the shared variadic element-wise CRTP template.
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
|
||||
|
||||
// Element-wise vector max: bufferizes SpatVMaxOp to pim::PimVMaxOp through
// the shared variadic element-wise CRTP template.
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
|
||||
|
||||
// Bufferizable op interface (external model) for the apply-filters operation.
// Lowers SpatApplyFiltersOp (tensor-level) to pim::PimApplyFiltersOp
// (memref-level), creating the output buffer and a reduced-shape accumulation
// buffer as part of bufferization.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {

  // One operand ($input) is read from. All other inputs are only written to.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // Operand 0: $input
    // Operand 1: $outBuf
    // Operand 2: $accumBuf
    return opOperand.getOperandNumber() == 0;
  }

  // One input ($accumBuf) is written to. All other inputs are only read.
  // NOTE(review): this reports only operand 2 as written, while the comment
  // on bufferizesToMemoryRead above implies $outBuf (operand 1) is written
  // too — confirm which statement of the contract is intended.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // Operand 0: $input
    // Operand 1: $outBuf
    // Operand 2: $accumBuf
    return opOperand.getOperandNumber() == 2;
  }

  // No operands are aliased with any other operands (the op's results do not
  // alias any operand buffer, so return an empty list).
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }

  // Bufferize the operation: fetch the input buffer, allocate fresh output
  // and accumulation buffers, then rebuild the op in the PIM dialect.
  LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {

    // Get the input tensor buffer.
    auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);

    if (failed(inputBuffer))
      return failure();

    // Create a new buffer for the output tensor.
    auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);

    // Create a new buffer for the accumulation buffer.
    // To do this, create a new allocation operation. Size must be axbx1x1,
    // where axbxcxd is the size of the output tensor. Since the shape is
    // different, we can't immediately use createEmptyFromType, we first need to
    // create the shape of the accumulation buffer.
    auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());

    // Set the last two dimensions to 1 (collapse the spatial dims; assumes the
    // result is at least rank 2 — TODO confirm the op verifier guarantees this).
    accumShape[accumShape.size() - 1] = 1;
    accumShape[accumShape.size() - 2] = 1;

    auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());

    auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);

    // Bufferize the operation: carry the static filter attributes over to the
    // PIM-dialect op unchanged.
    auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
    auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
    auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();

    Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(),
                                                               outputTensor.getType(),
                                                               weightIndices,
                                                               xKernelPositions,
                                                               yKernelPositions,
                                                               *inputBuffer,
                                                               outputTensor,
                                                               accumBuffer);

    // Replace the operation with the bufferized value.
    replaceOpWithBufferizedValues(rewriter, op, bufferized);

    return success();
  }
};
|
||||
|
||||
// Attaches the bufferization external models defined in this file to their
// Spatial-dialect ops. The extension lambda runs lazily, when the Spatial
// dialect is loaded into a context, so the interface implementations stay
// decoupled from the dialect definition itself.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
    SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
    SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
    SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
    SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
    SpatSumOp::attachInterface<SumOpInterface>(*ctx);
    SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
    SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
    SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
    SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
    SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
    SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
    SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
  });
}
|
||||
|
||||
// Element-wise ReLU: bufferizes ONNXReluOp directly to pim::PimVReluOp
// through the shared variadic element-wise CRTP template.
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
||||
|
||||
// Element-wise exponential: bufferizes ONNXExpOp directly to pim::PimVExpOp
// through the shared variadic element-wise CRTP template.
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
|
||||
|
||||
// Attaches the bufferization external models for the ONNX ops that lower
// directly to PIM element-wise ops (Relu, Exp). Runs lazily when the ONNX
// dialect is loaded into a context.
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
    ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
    ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
  });
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,16 @@
|
||||
#pragma once

#include "mlir/IR/DialectRegistry.h"

#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

// NOTE: `using namespace mlir;` was removed here — a using-directive at
// header scope leaks into every translation unit that includes this file.
// The declarations below qualify mlir:: explicitly instead.

namespace onnx_mlir {
namespace spatial {

/// Registers the bufferizable-op-interface external models for the
/// Spatial-dialect ops (weighted compute, element-wise ops, channel
/// send/receive, apply-filters).
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);

/// Registers the bufferizable-op-interface external models for the ONNX ops
/// (Relu, Exp) that lower directly to PIM element-wise ops.
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);

} // namespace spatial
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user