implement mem copy codgen (lmv)
add more gemv/gemm tests refactor
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
add_onnx_mlir_dialect(Pim pim)
|
||||
add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||
|
||||
add_subdirectory(Transforms/Bufferization)
|
||||
|
||||
add_onnx_mlir_library(PimOps
|
||||
PimOps.hpp
|
||||
PimOps.cpp
|
||||
Transforms/PimBufferizableOpInterface.cpp
|
||||
|
||||
DEPENDS
|
||||
OMPimIncGen
|
||||
|
||||
@@ -14,20 +14,13 @@ def PimDialect : Dialect {
|
||||
let cppNamespace = "::onnx_mlir::pim";
|
||||
}
|
||||
|
||||
// Base class for Pim dialect operations. This operation inherits from the
|
||||
// base `Op` class in OpBase.td, and provides:
|
||||
// * The parent dialect of the operation.
|
||||
// * The mnemonic for the operation, or the name without the dialect prefix.
|
||||
// * A list of traits for the operation.
|
||||
class PimOp<string mnemonic, list<Trait> traits = []> :
|
||||
Op<PimDialect, mnemonic, traits>;
|
||||
|
||||
def PimTensor :
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication
|
||||
|
||||
def PimSendOp: PimOp<"send", []> {
|
||||
let arguments = (ins
|
||||
@@ -63,9 +56,7 @@ def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Core operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Core
|
||||
|
||||
def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
||||
|
||||
@@ -81,9 +72,7 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Memory Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Memory
|
||||
|
||||
def PimConstantOp: PimOp<"constant", []> {
|
||||
let description = [{
|
||||
@@ -157,9 +146,36 @@ def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Core.Compute operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Copy a memory region from and to the same memory
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $dst,
|
||||
PimTensor: $src,
|
||||
I32Attr: $dstOffset,
|
||||
I32Attr: $srcOffset,
|
||||
I32Attr: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $dstOut
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDstMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
|
||||
}];
|
||||
}
|
||||
|
||||
// Computation
|
||||
|
||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
|
||||
22
src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt
Normal file
22
src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
|
||||
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(PimBufferizationIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMPimBufferization
|
||||
PimBufferizationPass.hpp
|
||||
PimBufferizationPass.cpp
|
||||
OpBufferizationInterfaces.hpp
|
||||
OpBufferizationInterfaces.cpp
|
||||
Common.hpp
|
||||
Common.cpp
|
||||
|
||||
DEPENDS
|
||||
PimBufferizationIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPIMCommon
|
||||
PimOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
)
|
||||
9
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp
Normal file
9
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp
Normal file
@@ -0,0 +1,9 @@
|
||||
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||
return builder.getI32IntegerAttr(sizeInBytes);
|
||||
}
|
||||
13
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp
Normal file
13
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,11 +1,10 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
@@ -173,7 +172,7 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
|
||||
}
|
||||
};
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
@@ -9,7 +9,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,19 @@
|
||||
#ifndef PIM_BUFFERIZATION
|
||||
#define PIM_BUFFERIZATION
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
||||
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def memrefCopyToPimMemCopyOp : Pat<
|
||||
(CopyOp $src, $dst),
|
||||
(PimMemCopyOp $dst, $src,
|
||||
ConstantAttr<I32Attr, "0">,
|
||||
ConstantAttr<I32Attr, "0">,
|
||||
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
@@ -0,0 +1,82 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
#include "PimBufferizationPass.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
void PimBufferizationPass::runOnOperation() {
|
||||
auto moduleOp = getOperation();
|
||||
|
||||
// One-Shot-Bufferization
|
||||
bufferization::OneShotBufferizationOptions options;
|
||||
options.allowUnknownOps = true;
|
||||
bufferization::BufferizationState state;
|
||||
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
||||
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove toTensor operations: leave memrefs instead
|
||||
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
|
||||
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
|
||||
toTensorOp.erase();
|
||||
});
|
||||
|
||||
// Change main function return types from tensors to memrefs
|
||||
func::FuncOp funcOp;
|
||||
for (Operation& op : moduleOp.getBody()->getOperations())
|
||||
if ((funcOp = dyn_cast<func::FuncOp>(&op)))
|
||||
break;
|
||||
auto oldFuncType = funcOp.getFunctionType();
|
||||
SmallVector<Type> newResults;
|
||||
bool changed = false;
|
||||
for (Type type : oldFuncType.getResults())
|
||||
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
|
||||
newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
|
||||
changed = true;
|
||||
}
|
||||
else
|
||||
newResults.push_back(type);
|
||||
if (changed)
|
||||
funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults));
|
||||
|
||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "pim_buf");
|
||||
}
|
||||
|
||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||
MLIRContext* ctx = funcOp.getContext();
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
||||
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
||||
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Dialect/PIM/PimOps.hpp"
|
||||
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace pim {
|
||||
|
||||
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
|
||||
|
||||
PimBufferizationPass() = default;
|
||||
PimBufferizationPass(const PimBufferizationPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace pim
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<pim::PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user