add constant folding and verification pass for pim host operations

better validation scripts output
big refactors
This commit is contained in:
NiccoloN
2026-03-20 12:08:12 +01:00
parent 4e50e056e3
commit 6e1de865bb
64 changed files with 1364 additions and 2265 deletions

View File

@@ -0,0 +1,16 @@
# Build rules for the Pim dialect.
# NOTE(review): add_onnx_mlir_dialect / add_onnx_mlir_dialect_doc are project
# macros not shown here; presumably they drive TableGen over Pim.td to produce
# the dialect/op .inc files and the dialect documentation — confirm.
add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td)
# The bufferization transforms are built from their own subdirectory.
add_subdirectory(Transforms/Bufferization)
# Library made of PimOps.hpp / PimOps.cpp. It depends on the OMPimIncGen target
# and publicly links OMMlirDialects and MLIRIR.
add_onnx_mlir_library(PimOps
PimOps.hpp
PimOps.cpp
DEPENDS
OMPimIncGen
LINK_LIBS PUBLIC
OMMlirDialects
MLIRIR
)

391
src/PIM/Dialect/Pim/Pim.td Normal file
View File

@@ -0,0 +1,391 @@
#ifndef PIM_DIALECT_H
#define PIM_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
// Dialect record for the `pim` dialect. The generated C++ classes are placed in
// the ::onnx_mlir::pim namespace.
def PimDialect : Dialect {
let name = "pim";
let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars";
let cppNamespace = "::onnx_mlir::pim";
}
// Base class of every op in this file: binds the op to PimDialect and forwards
// the mnemonic and the (optional, empty by default) trait list to Op.
class PimOp<string mnemonic, list<Trait> traits = []> :
Op<PimDialect, mnemonic, traits>;
// Type constraint used for all operands/results below: accepts any memref or
// any ranked tensor, so the same op definitions can be used both before and
// after bufferization. The constraint carries an empty summary and names
// ::mlir::ShapedType as its C++ type.
def PimTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
// Communication
// `pim.send`: takes one PimTensor operand `src` plus the i32 attributes `size`
// and `targetCoreId`; it has no results. The attributes are printed through
// attr-dict.
// NOTE(review): presumably transmits `size` units of `src` to the core
// identified by `targetCoreId`; the unit of `size` (bytes?) is not stated in
// this file — confirm.
def PimSendOp: PimOp<"send", []> {
let arguments = (ins
PimTensor: $src,
I32Attr: $size,
I32Attr: $targetCoreId
);
let assemblyFormat = [{
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
}];
}
// `pim.receive`: destination-style op whose only DPS init is `dst` (see
// getDpsInitsMutable below); it yields a single result `out`. The i32
// attributes `size` and `srcCoreId` are printed through attr-dict.
// NOTE(review): presumably the counterpart of `pim.send`, receiving from core
// `srcCoreId` into `dst` — confirm against the code generator.
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
let arguments = (ins
PimTensor: $dst,
I32Attr: $size,
I32Attr: $srcCoreId
);
let results = (outs
PimTensor: $out
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
}];
}
// Core
// `pim.core`: op with one region `body` constrained to exactly one block
// (SizedRegion<1> plus the SingleBlock trait). It takes a variadic list of
// PimTensor `weights` and an i32 attribute `coreId`, and has no results.
// NOTE(review): presumably the body holds the instructions executed by the
// core `coreId` and is terminated by `pim.halt` — confirm.
def PimCoreOp: PimOp<"core", [SingleBlock]> {
let regions = (region SizedRegion<1>:$body);
let arguments = (ins
Variadic<PimTensor>:$weights,
I32Attr: $coreId
);
let assemblyFormat = [{
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
}];
}
// Memory
// `pim.constant`: carries an arbitrary attribute `value` and a boolean
// attribute `shouldAllocate`, and produces one PimTensor result `out`.
// No assemblyFormat is declared for this op.
// NOTE(review): the exact meaning of `shouldAllocate` is not visible in this
// file — confirm with the code that consumes it.
def PimConstantOp: PimOp<"constant", []> {
let description = [{
Allocate a constant value in global memory
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
PimTensor: $out
);
}
// `pim.memcp_hd`: destination-style copy whose DPS init is `deviceDst`; the
// result `deviceDstOut` is the value tied to that destination. `hostSrc` is
// the source operand. The two offsets and `size` are i32 attributes printed
// through attr-dict.
// NOTE(review): offsets and size are presumably expressed in bytes — confirm.
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from host memory into device memory
}];
let arguments = (ins
PimTensor: $deviceDst,
PimTensor: $hostSrc,
I32Attr: $deviceDstOffset,
I32Attr: $hostSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $deviceDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceDstMutable();
}
}];
let assemblyFormat = [{
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
}];
}
// `pim.memcp_dh`: destination-style copy whose DPS init is `hostDst`; the
// result `hostDstOut` is the value tied to that destination. `deviceSrc` is
// the source operand. The two offsets and `size` are i32 attributes printed
// through attr-dict.
// NOTE(review): offsets and size are presumably expressed in bytes — confirm.
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from device memory into host memory
}];
let arguments = (ins
PimTensor: $hostDst,
PimTensor: $deviceSrc,
I32Attr: $hostDstOffset,
I32Attr: $deviceSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $hostDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getHostDstMutable();
}
}];
let assemblyFormat = [{
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
}];
}
// `pim.memcp`: destination-style copy whose DPS init is `dst`; `src` is the
// source operand and `dstOut` the result tied to the destination. Note the
// argument order is (dst, src, dstOffset, srcOffset, size): the generated
// builders take the destination first.
// NOTE(review): offsets and size are presumably expressed in bytes — confirm.
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from and to the same memory
}];
let arguments = (ins
PimTensor: $dst,
PimTensor: $src,
I32Attr: $dstOffset,
I32Attr: $srcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $dstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
}];
}
// Algebra
// `pim.transpose`: destination-style op reading `data` and writing into the
// DPS init `outBuf`; `outRes` is the result tied to that buffer. `perms` is an
// array-of-i64 attribute printed through attr-dict.
// The attribute is declared between the two operands, so the argument order
// seen by the generated builders is (data, perms, outBuf).
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
let description = [{
Matrix transpose
}];
let arguments = (ins
PimTensor: $data,
I64ArrayAttr: $perms,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
}];
}
// `pim.vmm`: destination-style op reading `vectorInput` and writing into the
// DPS init `outBuf`; `outRes` is the result tied to that buffer. The matrix
// operand is not an SSA value: it is referenced through the i32 attribute
// `weightIndex`.
// NOTE(review): `weightIndex` presumably selects one of the enclosing
// `pim.core`'s `weights` — confirm.
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{
Vector-matrix multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
}];
}
// `pim.mvm`: same operand/attribute/result layout as `pim.vmm`
// (weightIndex, vectorInput, outBuf -> outRes) with `outBuf` as the DPS init.
// Unlike `pim.vmm`, no assemblyFormat is declared for this op.
// NOTE(review): `weightIndex` presumably selects one of the enclosing
// `pim.core`'s `weights` — confirm.
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
let description = [{
Matrix-vector multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
}
// `pim.vadd`: destination-style op reading `a` and `b` and writing into the
// DPS init `outBuf`; `outRes` is the result tied to that buffer.
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
let description = [{
Element-wise addition: c = a + b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
// `pim.vmax`: takes `a`, `b` and an output buffer `outBuf`, producing `outRes`.
// This op does not use DestinationStyleOpInterface; instead it declares the
// BufferViewFlowOpInterface methods, which must be defined in C++.
// No assemblyFormat is declared for this op.
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise max: c = max(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// `pim.apply_filters`: takes three array-of-i64 attributes (`weightIndices`,
// `xKernelPositions`, `yKernelPositions`) and three PimTensor operands
// (`input`, `outBuf`, `accumBuf`), producing `outRes`. Operands are printed
// with explicit `name = value` keywords; the attributes go through attr-dict.
// It declares the BufferViewFlowOpInterface methods, which must be defined in
// C++.
// NOTE(review): the relationship between the three index arrays (e.g. equal
// lengths) is not expressed here — confirm whether a verifier enforces it.
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Apply filters to a tensor
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
PimTensor: $input,
PimTensor: $outBuf,
PimTensor: $accumBuf
);
let results = (outs
PimTensor: $outRes
);
let assemblyFormat = [{
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
}];
}
// `pim.sum`: takes an input `a` and an output buffer `outBuf`, producing
// `outRes`. It declares the BufferViewFlowOpInterface methods, which must be
// defined in C++. No assemblyFormat is declared for this op.
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Sum all elements into a single one
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// `pim.vsdiv`: takes `dividend`, `divisor` and an output buffer `outBuf`,
// producing `outRes`. Per the description, `divisor` is a scalar wrapped in a
// tensor. It declares the BufferViewFlowOpInterface methods, which must be
// defined in C++. No assemblyFormat is declared for this op.
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
}];
let arguments = (ins
PimTensor: $dividend,
PimTensor: $divisor,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// `pim.vrelu`: takes an input `a` and an output buffer `outBuf`, producing
// `outRes`. It declares the BufferViewFlowOpInterface methods, which must be
// defined in C++. No assemblyFormat is declared for this op.
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise ReLU: c = max(a, 0)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// `pim.vexp`: takes an input `a` and an output buffer `outBuf`, producing
// `outRes`. It declares the BufferViewFlowOpInterface methods, which must be
// defined in C++. No assemblyFormat is declared for this op.
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise exp: c = exp(a)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
// `pim.halt`: terminator op with no operands and no results; its textual form
// is just the mnemonic followed by an optional attribute dictionary.
def PimHaltOp: PimOp<"halt", [Terminator]> {
let description = [{
Halts the execution of the core
}];
let assemblyFormat = [{
attr-dict
}];
}
#endif // PIM_DIALECT_H

View File

@@ -0,0 +1,49 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
/// Dialect hook that registers the dialect's operations. The operation list
/// is not spelled out by hand: it comes from the generated PimOps.cpp.inc.
void PimDialect::initialize() {
addOperations<
// With GET_OP_LIST defined, the generated file is included here as the
// template argument list of addOperations.
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
>();
}
// Defines OP_NAME::populateDependencies (the BufferViewFlowOpInterface method
// declared for these ops in Pim.td). Every op listed below reports exactly one
// dependency: its `outBuf` operand paired with its single result.
#define POPULATE_DEPENDENCIES(OP_NAME)                                                                 \
  void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) {  \
    registerDependenciesFn(this->getOutBuf(), this->getResult());                                      \
  }

POPULATE_DEPENDENCIES(PimVMaxOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp)

// The helper is only meant for the expansions above; do not let it stay
// defined for the rest of the file.
#undef POPULATE_DEPENDENCIES
} // namespace pim
} // namespace onnx_mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"

View File

@@ -0,0 +1,18 @@
#pragma once
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include <map>
#include <string>
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"

View File

@@ -0,0 +1,21 @@
# Run TableGen's rewriter generator over PimBufferization.td to produce
# PimBufferization.hpp.inc, and expose it as the PimBufferizationIncGen target.
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen)
# Library with the bufferization pass, the op bufferization interfaces and the
# shared helpers. It depends on the generated rewriters above and publicly
# links OMPIMCommon and PimOps.
add_onnx_mlir_library(OMPimBufferization
PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp
Common.hpp
Common.cpp
DEPENDS
PimBufferizationIncGen
LINK_LIBS PUBLIC
OMPIMCommon
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -0,0 +1,9 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir;
/// Builds an i32 integer attribute holding the size, in bytes, of the buffer
/// described by the type of `memref`.
///
/// @param builder Builder used to create the attribute.
/// @param memref  Value whose type must be a MemRefType (enforced by the cast).
/// @return An I32 IntegerAttr equal to ceil(numElements * elementBitWidth / 8).
///
/// The total bit count is rounded UP to whole bytes, so element types narrower
/// than a byte (e.g. i1) no longer produce a truncated size such as 0 bytes
/// for 3 x i1. For element widths that are multiples of 8 the result is the
/// same as a plain division.
///
/// NOTE(review): the byte count is narrowed to int32_t without a range check;
/// callers are presumably expected to pass statically-shaped memrefs smaller
/// than 2 GiB — confirm.
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
  auto type = mlir::cast<MemRefType>(memref.getType());
  const int64_t sizeInBits = type.getNumElements() * static_cast<int64_t>(type.getElementTypeBitWidth());
  const auto sizeInBytes = static_cast<int32_t>((sizeInBits + 7) / 8);
  return builder.getI32IntegerAttr(sizeInBytes);
}

View File

@@ -0,0 +1,11 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
namespace onnx_mlir {
namespace pim {
/// Returns an integer attribute describing the size in bytes of the buffer
/// behind `memref`; `builder` is used to create the attribute.
/// NOTE(review): `memref` is presumably required to have a MemRefType — see
/// the definition in Common.cpp to confirm.
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,213 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir {
namespace pim {
/// Bufferization model for `PimMemCopyHostToDevOp`: both operands are resolved
/// to buffers and the op is re-created on those buffers, keeping its offset
/// and size attributes unchanged.
struct MemCopyHostToDevOpInterface
    : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto copyOp = cast<PimMemCopyHostToDevOp>(op);

    // Destination first, then source: the lookup order matches the operand order.
    auto dstBuffer = getBuffer(rewriter, copyOp.getDeviceDst(), options, state);
    if (failed(dstBuffer))
      return failure();

    auto srcBuffer = getBuffer(rewriter, copyOp.getHostSrc(), options, state);
    if (failed(srcBuffer))
      return failure();

    // The new op's result type is the type of the destination buffer.
    replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
        copyOp,
        dstBuffer->getType(),
        *dstBuffer,
        *srcBuffer,
        copyOp.getDeviceDstOffsetAttr(),
        copyOp.getHostSrcOffsetAttr(),
        copyOp.getSizeAttr());
    return success();
  }
};
/// Bufferization model for `PimMemCopyDevToHostOp`: both operands are resolved
/// to buffers and the op is re-created on those buffers, keeping its offset
/// and size attributes unchanged.
struct MemCopyDevToHostOpInterface
    : DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto copyOp = cast<PimMemCopyDevToHostOp>(op);

    // Host-side destination buffer.
    auto hostDstBuffer = getBuffer(rewriter, copyOp.getHostDst(), options, state);
    if (failed(hostDstBuffer))
      return failure();

    // Device-side source buffer.
    auto deviceSrcBuffer = getBuffer(rewriter, copyOp.getDeviceSrc(), options, state);
    if (failed(deviceSrcBuffer))
      return failure();

    // The new op's result type is the type of the destination buffer.
    replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
        copyOp,
        hostDstBuffer->getType(),
        *hostDstBuffer,
        *deviceSrcBuffer,
        copyOp.getHostDstOffsetAttr(),
        copyOp.getDeviceSrcOffsetAttr(),
        copyOp.getSizeAttr());
    return success();
  }
};
/// Bufferization model for `PimTransposeOp`.
struct TransposeOpBufferizeInterface
    : DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
  /// Only operands that are not DPS inits are reported as memory reads.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return !dpsOp.isDpsInit(&opOperand);
  }

  /// Re-creates the op on the buffers of `data` and `outBuf`; the result type
  /// is the type of the output buffer and `perms` is carried over.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto transpose = cast<PimTransposeOp>(op);

    auto inputBuffer = getBuffer(rewriter, transpose.getData(), options, state);
    if (failed(inputBuffer))
      return failure();

    auto outputBuffer = getBuffer(rewriter, transpose.getOutBuf(), options, state);
    if (failed(outputBuffer))
      return failure();

    replaceOpWithNewBufferizedOp<PimTransposeOp>(
        rewriter, op, outputBuffer->getType(), *inputBuffer, transpose.getPerms(), *outputBuffer);
    return success();
  }
};
/// Bufferization model for `PimVMMOp`.
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
  /// Only operands that are not DPS inits are reported as memory reads.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return !dpsOp.isDpsInit(&opOperand);
  }

  /// A read/write pair is declared non-conflicting only when the write targets
  /// this op's `outBuf`, the read is this op's `vectorInput`, and the analysis
  /// considers the two values equivalent buffers. The checks short-circuit in
  /// exactly that order.
  bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const {
    auto vmm = cast<PimVMMOp>(op);
    Value read = uRead->get();
    Value written = uWrite->get();
    return written == vmm.getOutBuf() && read == vmm.getVectorInput()
           && state.areEquivalentBufferizedValues(read, written);
  }

  /// Re-creates the op on the buffers of `vectorInput` and `outBuf`; the
  /// result type is the type of the output buffer and `weightIndex` is kept.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto vmm = cast<PimVMMOp>(op);

    auto inputBuffer = getBuffer(rewriter, vmm.getVectorInput(), options, state);
    if (failed(inputBuffer))
      return failure();

    auto outputBuffer = getBuffer(rewriter, vmm.getOutBuf(), options, state);
    if (failed(outputBuffer))
      return failure();

    replaceOpWithNewBufferizedOp<PimVMMOp>(
        rewriter, op, outputBuffer->getType(), vmm.getWeightIndexAttr(), *inputBuffer, *outputBuffer);
    return success();
  }
};
/// Bufferization model for `PimMVMOp`.
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
  /// Only operands that are not DPS inits are reported as memory reads.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return !dpsOp.isDpsInit(&opOperand);
  }

  /// Re-creates the op on the buffers of `vectorInput` and `outBuf`; the
  /// result type is the type of the output buffer and `weightIndex` is kept.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto mvm = cast<PimMVMOp>(op);

    auto inputBuffer = getBuffer(rewriter, mvm.getVectorInput(), options, state);
    if (failed(inputBuffer))
      return failure();

    auto outputBuffer = getBuffer(rewriter, mvm.getOutBuf(), options, state);
    if (failed(outputBuffer))
      return failure();

    replaceOpWithNewBufferizedOp<PimMVMOp>(
        rewriter, op, outputBuffer->getType(), mvm.getWeightIndexAttr(), *inputBuffer, *outputBuffer);
    return success();
  }
};
/// Bufferization model for `PimVAddOp`.
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
  /// Only operands that are not DPS inits are reported as memory reads.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return !dpsOp.isDpsInit(&opOperand);
  }

  /// The op is always reported as accessing its operands element-wise,
  /// regardless of which operands are queried.
  bool bufferizesToElementwiseAccess(Operation* op,
      const AnalysisState& state,
      ArrayRef<OpOperand*> opOperands) const {
    return true;
  }

  /// Re-creates the op on the buffers of `a`, `b` and `outBuf`; the result
  /// type is the type of the output buffer.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto vadd = cast<PimVAddOp>(op);

    auto lhsBuffer = getBuffer(rewriter, vadd.getA(), options, state);
    if (failed(lhsBuffer))
      return failure();

    auto rhsBuffer = getBuffer(rewriter, vadd.getB(), options, state);
    if (failed(rhsBuffer))
      return failure();

    auto outputBuffer = getBuffer(rewriter, vadd.getOutBuf(), options, state);
    if (failed(outputBuffer))
      return failure();

    replaceOpWithNewBufferizedOp<PimVAddOp>(
        rewriter, op, outputBuffer->getType(), *lhsBuffer, *rhsBuffer, *outputBuffer);
    return success();
  }
};
/// Adds a registry extension that attaches the bufferization external models
/// defined in this file to their Pim ops.
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* context, PimDialect* /*dialect*/) {
    // Memory copies between host and device.
    PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*context);
    PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*context);

    // Algebra ops.
    PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*context);
    PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*context);
    PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*context);
    PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*context);
  });
}
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
namespace pim {
/// Registers, on `registry`, the bufferization interface implementations for
/// the Pim dialect ops (defined in OpBufferizationInterfaces.cpp).
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,19 @@
#ifndef PIM_BUFFERIZATION
#define PIM_BUFFERIZATION
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
// Declarative rewrite: a CopyOp with operands ($src, $dst) is replaced by a
// PimMemCopyOp. Operand order is swapped because PimMemCopyOp takes the
// destination first. Both offset attributes are the i32 constant 0, the size
// attribute is computed from $src by pim::getMemRefSizeInBytesAttr, and the
// result type is taken from $dst.
// NOTE(review): CopyOp presumably resolves to memref.copy from the included
// MemRefOps.td — confirm.
def memrefCopyToPimMemCopyOp : Pat<
(CopyOp $src, $dst),
(PimMemCopyOp $dst, $src,
ConstantAttr<I32Attr, "0">,
ConstantAttr<I32Attr, "0">,
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION

View File

@@ -0,0 +1,110 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "Common/PIMCommon.hpp"
#include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace {
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
/// Module-level pass registered under the argument `bufferize-pim`.
/// The work is done in runOnOperation(); annotateWeightsMemrefs() is a private
/// helper invoked from there.
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)

  PimBufferizationPass() = default;
  // The pass carries no state of its own, so copying has nothing to transfer.
  PimBufferizationPass(const PimBufferizationPass& pass) {}

  StringRef getArgument() const override {
    return "bufferize-pim";
  }

  StringRef getDescription() const override {
    return "Bufferize PIM and Spatial ops.";
  }

  void runOnOperation() final;

private:
  /// Marks the get_global ops (and their globals) that are used only by
  /// PimCoreOp.
  void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace
void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation();
// One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
signalPassFailure();
}
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
// Remove toTensor operations: leave memrefs instead
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
toTensorOp.erase();
});
// Change main function return types from tensors to memrefs
func::FuncOp funcOp;
for (Operation& op : moduleOp.getBody()->getOperations())
if ((funcOp = dyn_cast<func::FuncOp>(&op)))
break;
auto oldFuncType = funcOp.getFunctionType();
SmallVector<Type> newResults;
bool changed = false;
for (Type type : oldFuncType.getResults())
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
changed = true;
}
else
newResults.push_back(type);
if (changed)
funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults));
annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug
dumpModule(moduleOp, "pim_buf");
}
/// Walks every memref.get_global in `funcOp`. When an op has at least one user
/// and all of its users are PimCoreOp, both the get_global and the global
/// found for it in `moduleOp` are marked through markWeightAlways.
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    // An unused get_global is never treated as a weight.
    if (getGlobalOp->getUsers().empty())
      return;

    const bool usedOnlyByCores =
        all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
    if (!usedOnlyByCores)
      return;

    auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    assert("Weights must be constants" && globalMemrefOp.getConstant());
    markWeightAlways(getGlobalOp);
    markWeightAlways(globalMemrefOp);
  });
}
/// Factory returning a fresh instance of the PIM bufferization pass.
std::unique_ptr<Pass> createBufferizePimPass() {
  auto pass = std::make_unique<PimBufferizationPass>();
  return pass;
}
} // namespace onnx_mlir