add constant folding and verification pass for pim host operations
better validation scripts output big refactors
This commit is contained in:
16
src/PIM/Dialect/Pim/CMakeLists.txt
Normal file
16
src/PIM/Dialect/Pim/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_onnx_mlir_dialect(Pim pim)
|
||||
add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||
|
||||
add_subdirectory(Transforms/Bufferization)
|
||||
|
||||
add_onnx_mlir_library(PimOps
|
||||
PimOps.hpp
|
||||
PimOps.cpp
|
||||
|
||||
DEPENDS
|
||||
OMPimIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMMlirDialects
|
||||
MLIRIR
|
||||
)
|
||||
391
src/PIM/Dialect/Pim/Pim.td
Normal file
391
src/PIM/Dialect/Pim/Pim.td
Normal file
@@ -0,0 +1,391 @@
|
||||
#ifndef PIM_DIALECT_H
|
||||
#define PIM_DIALECT_H
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/DestinationStyleOpInterface.td"
|
||||
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
|
||||
|
||||
def PimDialect : Dialect {
|
||||
let name = "pim";
|
||||
let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars";
|
||||
let cppNamespace = "::onnx_mlir::pim";
|
||||
}
|
||||
|
||||
class PimOp<string mnemonic, list<Trait> traits = []> :
|
||||
Op<PimDialect, mnemonic, traits>;
|
||||
|
||||
def PimTensor :
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||
|
||||
// Communication
|
||||
|
||||
def PimSendOp: PimOp<"send", []> {
|
||||
let arguments = (ins
|
||||
PimTensor: $src,
|
||||
I32Attr: $size,
|
||||
I32Attr: $targetCoreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let arguments = (ins
|
||||
PimTensor: $dst,
|
||||
I32Attr: $size,
|
||||
I32Attr: $srcCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $out
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
|
||||
}];
|
||||
}
|
||||
|
||||
// Core
|
||||
|
||||
def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<PimTensor>:$weights,
|
||||
I32Attr: $coreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
|
||||
}];
|
||||
}
|
||||
|
||||
// Memory
|
||||
|
||||
def PimConstantOp: PimOp<"constant", []> {
|
||||
let description = [{
|
||||
Allocate a constant value in global memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyAttr: $value,
|
||||
BoolAttr: $shouldAllocate
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $out
|
||||
);
|
||||
}
|
||||
|
||||
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Copy a memory region from host memory into device memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $deviceDst,
|
||||
PimTensor: $hostSrc,
|
||||
I32Attr: $deviceDstOffset,
|
||||
I32Attr: $hostSrcOffset,
|
||||
I32Attr: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $deviceDstOut
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDeviceDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Copy a memory region from device memory into host memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $hostDst,
|
||||
PimTensor: $deviceSrc,
|
||||
I32Attr: $hostDstOffset,
|
||||
I32Attr: $deviceSrcOffset,
|
||||
I32Attr: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $hostDstOut
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getHostDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Copy a memory region from and to the same memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $dst,
|
||||
PimTensor: $src,
|
||||
I32Attr: $dstOffset,
|
||||
I32Attr: $srcOffset,
|
||||
I32Attr: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $dstOut
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
|
||||
}];
|
||||
}
|
||||
|
||||
// Algebra
|
||||
|
||||
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Matrix transpose
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $data,
|
||||
I64ArrayAttr: $perms,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Vector-matrix multiplication: c = a * b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
PimTensor: $vectorInput,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Matrix-vector multiplication: c = a * b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
PimTensor: $vectorInput,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise addition: c = a + b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise max: c = max(a, b)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Apply filters to a tensor
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I64ArrayAttr: $weightIndices,
|
||||
I64ArrayAttr: $xKernelPositions,
|
||||
I64ArrayAttr: $yKernelPositions,
|
||||
PimTensor: $input,
|
||||
PimTensor: $outBuf,
|
||||
PimTensor: $accumBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
|
||||
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Sum all elements into a single one
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $dividend,
|
||||
PimTensor: $divisor,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise ReLU: c = max(a, 0)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise exp: c = exp(a)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimHaltOp: PimOp<"halt", [Terminator]> {
|
||||
let description = [{
|
||||
Halts the execution of the core
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // PIM_DIALECT_H
|
||||
49
src/PIM/Dialect/Pim/PimOps.cpp
Normal file
49
src/PIM/Dialect/Pim/PimOps.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void PimDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||
|
||||
>();
|
||||
}
|
||||
|
||||
#define POPULATE_DEPENDENCIES(OP_NAME) \
|
||||
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
|
||||
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
||||
}
|
||||
|
||||
POPULATE_DEPENDENCIES(PimVMaxOp)
|
||||
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
||||
POPULATE_DEPENDENCIES(PimSumOp)
|
||||
POPULATE_DEPENDENCIES(PimVSDivOp)
|
||||
POPULATE_DEPENDENCIES(PimVReluOp)
|
||||
POPULATE_DEPENDENCIES(PimVExpOp)
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||
18
src/PIM/Dialect/Pim/PimOps.hpp
Normal file
18
src/PIM/Dialect/Pim/PimOps.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"
|
||||
21
src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt
Normal file
21
src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
|
||||
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(PimBufferizationIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMPimBufferization
|
||||
PimBufferizationPass.cpp
|
||||
OpBufferizationInterfaces.hpp
|
||||
OpBufferizationInterfaces.cpp
|
||||
Common.hpp
|
||||
Common.cpp
|
||||
|
||||
DEPENDS
|
||||
PimBufferizationIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPIMCommon
|
||||
PimOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
)
|
||||
9
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp
Normal file
9
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp
Normal file
@@ -0,0 +1,9 @@
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||
return builder.getI32IntegerAttr(sizeInBytes);
|
||||
}
|
||||
11
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp
Normal file
11
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,213 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
struct MemCopyHostToDevOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
||||
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
|
||||
auto hostSrc = memCopyHostToDevOp.getHostSrc();
|
||||
|
||||
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
|
||||
if (failed(deviceDstOpt))
|
||||
return failure();
|
||||
auto deviceDstMemRef = *deviceDstOpt;
|
||||
|
||||
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
|
||||
if (failed(hostSrcOpt))
|
||||
return failure();
|
||||
auto hostSrcMemRef = *hostSrcOpt;
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceDstMemRef.getType(),
|
||||
deviceDstMemRef,
|
||||
hostSrcMemRef,
|
||||
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSrcOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyDevToHostOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
||||
|
||||
auto globalDst = memCopyDevToHostOp.getHostDst();
|
||||
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
|
||||
if (failed(globalDstOpt))
|
||||
return failure();
|
||||
auto globalDstMemRef = *globalDstOpt;
|
||||
|
||||
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
|
||||
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
|
||||
if (failed(localSrcOpt))
|
||||
return failure();
|
||||
auto localSrcMemRef = *localSrcOpt;
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||
memCopyDevToHostOp,
|
||||
globalDstMemRef.getType(),
|
||||
globalDstMemRef,
|
||||
localSrcMemRef,
|
||||
memCopyDevToHostOp.getHostDstOffsetAttr(),
|
||||
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
|
||||
memCopyDevToHostOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TransposeOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto transposeOp = cast<PimTransposeOp>(op);
|
||||
|
||||
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
|
||||
if (failed(dataOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
Value readVal = uRead->get();
|
||||
Value writeVal = uWrite->get();
|
||||
if (writeVal != vmmOp.getOutBuf())
|
||||
return false;
|
||||
if (readVal == vmmOp.getVectorInput())
|
||||
if (state.areEquivalentBufferizedValues(readVal, writeVal))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
|
||||
if (failed(vectorInputOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto mvmOp = cast<PimMVMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
|
||||
if (failed(vectorInputOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
||||
const AnalysisState& state,
|
||||
ArrayRef<OpOperand*> opOperands) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto vaddOp = cast<PimVAddOp>(op);
|
||||
|
||||
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
|
||||
if (failed(aOpt))
|
||||
return failure();
|
||||
|
||||
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
|
||||
if (failed(bOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,19 @@
|
||||
#ifndef PIM_BUFFERIZATION
|
||||
#define PIM_BUFFERIZATION
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def memrefCopyToPimMemCopyOp : Pat<
|
||||
(CopyOp $src, $dst),
|
||||
(PimMemCopyOp $dst, $src,
|
||||
ConstantAttr<I32Attr, "0">,
|
||||
ConstantAttr<I32Attr, "0">,
|
||||
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
@@ -0,0 +1,110 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
|
||||
|
||||
PimBufferizationPass() = default;
|
||||
PimBufferizationPass(const PimBufferizationPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimBufferizationPass::runOnOperation() {
|
||||
auto moduleOp = getOperation();
|
||||
|
||||
// One-Shot-Bufferization
|
||||
bufferization::OneShotBufferizationOptions options;
|
||||
options.allowUnknownOps = true;
|
||||
bufferization::BufferizationState state;
|
||||
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
||||
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove toTensor operations: leave memrefs instead
|
||||
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
|
||||
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
|
||||
toTensorOp.erase();
|
||||
});
|
||||
|
||||
// Change main function return types from tensors to memrefs
|
||||
func::FuncOp funcOp;
|
||||
for (Operation& op : moduleOp.getBody()->getOperations())
|
||||
if ((funcOp = dyn_cast<func::FuncOp>(&op)))
|
||||
break;
|
||||
auto oldFuncType = funcOp.getFunctionType();
|
||||
SmallVector<Type> newResults;
|
||||
bool changed = false;
|
||||
for (Type type : oldFuncType.getResults())
|
||||
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
|
||||
newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
|
||||
changed = true;
|
||||
}
|
||||
else
|
||||
newResults.push_back(type);
|
||||
if (changed)
|
||||
funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults));
|
||||
|
||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "pim_buf");
|
||||
}
|
||||
|
||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
||||
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||
markWeightAlways(getGlobalOp);
|
||||
markWeightAlways(globalMemrefOp);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user