diff --git a/onnx-mlir b/onnx-mlir index 84cedd1..82018d7 160000 --- a/onnx-mlir +++ b/onnx-mlir @@ -1 +1 @@ -Subproject commit 84cedd1d690d1c04056caec7ba29be1abea91815 +Subproject commit 82018d7ce59c94bfbe9479b16538224969fa45a0 diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt index 6a704c7..d909e05 100644 --- a/src/PIM/CMakeLists.txt +++ b/src/PIM/CMakeLists.txt @@ -10,14 +10,13 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY}) set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT}) set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT}) -add_subdirectory(Dialect) +add_subdirectory(Common) add_subdirectory(Compiler) add_subdirectory(Conversion) -add_subdirectory(Common) +add_subdirectory(Dialect) add_onnx_mlir_library(OMPIMAccel PimAccelerator.cpp - Transforms/PimBufferizationPass.cpp Pass/CountInstructionPass.cpp Pass/EmitPimJsonPass.cpp Pass/MessagePass.cpp diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index fbbe8e0..1deb83f 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -200,6 +200,15 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const { storeOp.getSize()); } +void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const { + emitMemCopyOp("lmv", + memory.getValueAddress(lmvOp.getDst()), + lmvOp.getDstOffset(), + memory.getValueAddress(lmvOp.getSrc()), + lmvOp.getSrcOffset(), + lmvOp.getSize()); +} + void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const { emitCommunicationOp( "recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize()); @@ -343,7 +352,6 @@ std::string getMemorySizeAsString(size_t size) { /// Write global constant data into a binary memory image at their allocated addresses. static OnnxMlirCompilerErrorCodes writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { - auto memoryFilePath = (outputDirPath + "/memory.bin").str(); std::error_code errorCode; raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None); @@ -400,6 +408,12 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenLoadOp(loadOp); else if (auto storeOp = dyn_cast(op)) coreCodeGen.codeGenStoreOp(storeOp); + else if (auto lmvOp = dyn_cast(op)) + coreCodeGen.codeGenLmvOp(lmvOp); + else if (auto receiveOp = dyn_cast(op)) + coreCodeGen.codeGenReceiveOp(receiveOp); + else if (auto sendOp = dyn_cast(op)) + coreCodeGen.codeGenSendOp(sendOp); else if (auto vmmOp = dyn_cast(op)) coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true); else if (auto mvmOp = dyn_cast(op)) @@ -412,10 +426,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVMaxOp(vmaxOp); else if (auto vreluOp = dyn_cast(op)) coreCodeGen.codeGenVReluOp(vreluOp); - else if (auto receiveOp = dyn_cast(op)) - coreCodeGen.codeGenReceiveOp(receiveOp); - else if (auto sendOp = dyn_cast(op)) - coreCodeGen.codeGenSendOp(sendOp); else if (isa(op)) { // TODO: Implement somehow? op.emitWarning("Operation is not yet supported in code generation"); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 585099b..6a2effb 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -84,6 +84,7 @@ public: void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const; + void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const; void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const; void codeGenSendOp(pim::PimSendOp sendOp) const; diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp index c387ed8..cd63e75 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp @@ -25,9 +25,96 @@ namespace onnx_mlir { const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; -struct ONNXGemmOpTile : public OpConversionPattern { - ONNXGemmOpTile(MLIRContext* ctx) - : OpConversionPattern(ctx) {} +struct GemmToManyGemv : OpConversionPattern { + GemmToManyGemv(MLIRContext* ctx) + : OpConversionPattern(ctx, 2) {} + + LogicalResult + matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { + Location loc = gemmOp.getLoc(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + + assert("A should have been transposed already" && !adaptor.getTransA()); + + bool hasC = !isa(c.getDefiningOp()); + + auto aType = cast(a.getType()); + auto outType = cast(gemmOp.getY().getType()); + assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape()); + + const int64_t numOutRows = aType.getDimSize(0); + + // Only decompose when there are multiple rows to split + if (numOutRows <= 1) + return failure(); + + RankedTensorType cType = nullptr; + bool cHasNumOutRows = false; + if (hasC) { + cType = cast(c.getType()); + assert("Only support rank 2 tensor for C" && cType.getRank() == 2); + cHasNumOutRows = cType.getDimSize(0) == numOutRows; + } + + auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); + + SmallVector gemvOps; + gemvOps.reserve(numOutRows); + for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) { + SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); + auto aSlice = rewriter.create(loc, aSliceType, a, offsets, sizes, strides).getResult(); + + Value cSlice = c; + if (hasC) { + if (cHasNumOutRows) { + SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); + cSlice = rewriter.create(loc, cSliceType, c, offsets, sizes, strides).getResult(); + } + else + assert("C should be a vector" && isVectorShape(getTensorShape(c))); + } + + auto gemvOp = rewriter.create(loc, + outRowType, + aSlice, + b, + cSlice, + gemmOp.getAlphaAttr(), + gemmOp.getBetaAttr(), + gemmOp.getTransAAttr(), + gemmOp.getTransBAttr()); + gemvOps.push_back(gemvOp.getY()); + } + + auto concatComputeOp = + rewriter.create(loc, gemmOp.getType(), SmallVector(), gemvOps); + + auto* concatBlock = new Block(); + for (auto gemvOp : gemvOps) + concatBlock->addArgument(gemvOp.getType(), loc); + concatComputeOp.getBody().push_back(concatBlock); + rewriter.setInsertionPointToStart(concatBlock); + + auto blockArgs = concatBlock->getArguments(); + auto concatOp = rewriter.create(loc, /*axis=*/0, blockArgs); + rewriter.create(loc, concatOp.getResult()); + + rewriter.replaceOp(gemmOp, concatComputeOp); + return success(); + } +}; + +struct GemvToSpatialCompute : OpConversionPattern { + GemvToSpatialCompute(MLIRContext* ctx) + : OpConversionPattern(ctx, 1) {} LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { @@ -50,12 +137,16 @@ struct ONNXGemmOpTile : public OpConversionPattern { bool hasC = !isa(c.getDefiningOp()); if (hasC) { cType = cast(c.getType()); - assert("Only support 2 tensor for C" && cType.getRank() == 2); + assert("Only support rank 2 tensor for C" && cType.getRank() == 2); } assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); + if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape())) + // Not a gemv + return failure(); + if (transA) { auto aShape = aType.getShape(); auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); @@ -169,9 +260,20 @@ struct ONNXGemmOpTile : public OpConversionPattern { outHSlices.push_back(reduceComputeOp.getResult(0)); } - rewriter.setInsertionPoint(gemmOp); - auto concatOp = rewriter.create(gemmLoc, /*axis=*/1, outHSlices); - rewriter.replaceOp(gemmOp, concatOp); + auto concatComputeOp = + rewriter.create(gemmLoc, gemmOp.getType(), SmallVector(), outHSlices); + + auto* concatBlock = new Block(); + for (auto outHSlice : outHSlices) + concatBlock->addArgument(outHSlice.getType(), gemmLoc); + concatComputeOp.getBody().push_back(concatBlock); + rewriter.setInsertionPointToStart(concatBlock); + + auto blockArgs = concatBlock->getArguments(); + auto concatOp = rewriter.create(gemmLoc, /*axis=*/1, blockArgs); + rewriter.create(gemmLoc, concatOp.getResult()); + + rewriter.replaceOp(gemmOp, concatComputeOp); return success(); } @@ -310,8 +412,9 @@ private: } }; -void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); +void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert(ctx); + patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 64fdb5e..70a8fbe 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -71,7 +71,7 @@ void ONNXToSpatialPass::runOnOperation() { else { populateTilingConvOpPattern(patterns, ctx); populatePoolingTilingPattern(patterns, ctx); - populateTilingGemmOpPattern(patterns, ctx); + populateOnnxGemmOpPatterns(patterns, ctx); } populateONNXConcatToTensorConcatPattern(patterns, ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp index 014cbd5..b496dfd 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp @@ -5,7 +5,7 @@ namespace onnx_mlir { void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp index bbaeabd..5b590b5 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp @@ -19,10 +19,9 @@ #include "SpatialToPIMPass.hpp" using namespace mlir; - -namespace onnx_mlir { - -namespace pim { +using namespace onnx_mlir; +using namespace pim; +using namespace spat_to_pim; void SpatialToPIMPass::runOnOperation() { coreId = 1; @@ -409,6 +408,7 @@ void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp, if (!computeUser) { auto reshapeOp = dyn_cast(resultUse.getOwner()); if (!reshapeOp) { + channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump(); resultUse.getOwner()->dump(); llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp"); } @@ -479,7 +479,3 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I rewriter.replaceOpWithNewOp(sendOp, sendOp.getChannel(), sendOp.getData()); } } - -} // namespace pim - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp index 37f0826..31c88a2 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp @@ -10,7 +10,7 @@ namespace onnx_mlir { -namespace pim { +namespace spat_to_pim { #include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc" @@ -53,8 +53,8 @@ private: void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter); }; -} // namespace pim +} // namespace spat_to_pim -std::unique_ptr createSpatialToPIMPass() { return std::make_unique(); } +std::unique_ptr createSpatialToPIMPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/CMakeLists.txt b/src/PIM/Dialect/PIM/CMakeLists.txt index 29f995e..e46ed57 100644 --- a/src/PIM/Dialect/PIM/CMakeLists.txt +++ b/src/PIM/Dialect/PIM/CMakeLists.txt @@ -1,10 +1,11 @@ add_onnx_mlir_dialect(Pim pim) add_onnx_mlir_dialect_doc(pim Pim.td) +add_subdirectory(Transforms/Bufferization) add_onnx_mlir_library(PimOps + PimOps.hpp PimOps.cpp - Transforms/PimBufferizableOpInterface.cpp DEPENDS OMPimIncGen diff --git a/src/PIM/Dialect/PIM/Pim.td b/src/PIM/Dialect/PIM/Pim.td index fb9ac57..e7c4603 100644 --- a/src/PIM/Dialect/PIM/Pim.td +++ b/src/PIM/Dialect/PIM/Pim.td @@ -14,20 +14,13 @@ def PimDialect : Dialect { let cppNamespace = "::onnx_mlir::pim"; } -// Base class for Pim dialect operations. This operation inherits from the -// base `Op` class in OpBase.td, and provides: -// * The parent dialect of the operation. -// * The mnemonic for the operation, or the name without the dialect prefix. -// * A list of traits for the operation. class PimOp traits = []> : Op; def PimTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; -//===----------------------------------------------------------------------===// -// Communication operations -//===----------------------------------------------------------------------===// +// Communication def PimSendOp: PimOp<"send", []> { let arguments = (ins @@ -63,9 +56,7 @@ def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> { }]; } -//===----------------------------------------------------------------------===// -// Core operations -//===----------------------------------------------------------------------===// +// Core def PimCoreOp: PimOp<"core", [SingleBlock]> { @@ -81,9 +72,7 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> { }]; } -//===----------------------------------------------------------------------===// -// Memory Operations -//===----------------------------------------------------------------------===// +// Memory def PimConstantOp: PimOp<"constant", []> { let description = [{ @@ -157,9 +146,36 @@ def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> { }]; } -//===----------------------------------------------------------------------===// -// Core.Compute operations -//===----------------------------------------------------------------------===// +def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> { + let description = [{ + Copy a memory region from and to the same memory + }]; + + let arguments = (ins + PimTensor: $dst, + PimTensor: $src, + I32Attr: $dstOffset, + I32Attr: $srcOffset, + I32Attr: $size + ); + + let results = (outs + PimTensor: $dstOut + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getDstMutable(); + } + }]; + + + let assemblyFormat = [{ + `(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut) + }]; +} + +// Computation def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { let description = [{ diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt b/src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt new file mode 100644 index 0000000..eda1784 --- /dev/null +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt @@ -0,0 +1,22 @@ +set(LLVM_TARGET_DEFINITIONS PimBufferization.td) +mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") +add_public_tablegen_target(PimBufferizationIncGen) + +add_onnx_mlir_library(OMPimBufferization + PimBufferizationPass.hpp + PimBufferizationPass.cpp + OpBufferizationInterfaces.hpp + OpBufferizationInterfaces.cpp + Common.hpp + Common.cpp + + DEPENDS + PimBufferizationIncGen + + LINK_LIBS PUBLIC + OMPIMCommon + PimOps + + ACCEL_INCLUDE_DIRS PRIVATE + ${PIM_INCLUDE_PATH} +) diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp new file mode 100644 index 0000000..d6480fe --- /dev/null +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp @@ -0,0 +1,9 @@ +#include "Dialect/PIM/Transforms/Bufferization/Common.hpp" + +using namespace mlir; + +IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) { + auto type = mlir::cast(memref.getType()); + int32_t sizeInBytes = static_cast(type.getNumElements() * type.getElementTypeBitWidth() / 8); + return builder.getI32IntegerAttr(sizeInBytes); +} diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp new file mode 100644 index 0000000..5bbd3ba --- /dev/null +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +using namespace mlir; + +namespace onnx_mlir { +namespace pim { + +IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref); + +} // namespace pim +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.cpp similarity index 97% rename from src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp rename to src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 1a802b0..64109fd 100644 --- a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -1,11 +1,10 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp" using namespace mlir; using namespace bufferization; @@ -173,7 +172,7 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(*ctx); PimMemCopyDevToHostOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp similarity index 73% rename from src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp rename to src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp index f4872e0..16a3a42 100644 --- a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp @@ -9,7 +9,7 @@ using namespace mlir; namespace onnx_mlir { namespace pim { -void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry); +void registerOpBufferizationInterfaces(DialectRegistry& registry); } // namespace pim } // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferization.td b/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferization.td new file mode 100644 index 0000000..78a9c03 --- /dev/null +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferization.td @@ -0,0 +1,19 @@ +#ifndef PIM_BUFFERIZATION +#define PIM_BUFFERIZATION + +#ifndef OP_BASE +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/MemRef/IR/MemRefOps.td" +include "src/Accelerators/PIM/Dialect/PIM/Pim.td" +#endif // OP_BASE + +def memrefCopyToPimMemCopyOp : Pat< + (CopyOp $src, $dst), + (PimMemCopyOp $dst, $src, + ConstantAttr, + ConstantAttr, + (NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src), + (returnType $dst)) +>; + +#endif // PIM_BUFFERIZATION \ No newline at end of file diff --git a/src/PIM/Transforms/PimBufferizationPass.cpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.cpp similarity index 79% rename from src/PIM/Transforms/PimBufferizationPass.cpp rename to src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.cpp index 9173843..7f5eb82 100644 --- a/src/PIM/Transforms/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.cpp @@ -5,23 +5,18 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" -#include "llvm/Support/raw_os_ostream.h" - #include "Common/PIMCommon.hpp" #include "Compiler/PimCodeGen.hpp" #include "PimBufferizationPass.hpp" -#include "src/Compiler/CompilerOptions.hpp" using namespace mlir; - -namespace onnx_mlir { - -namespace pim { +using namespace onnx_mlir; +using namespace pim; void PimBufferizationPass::runOnOperation() { auto moduleOp = getOperation(); - // Do One-Shot-Bufferization + // One-Shot-Bufferization bufferization::OneShotBufferizationOptions options; options.allowUnknownOps = true; bufferization::BufferizationState state; @@ -30,7 +25,19 @@ void PimBufferizationPass::runOnOperation() { signalPassFailure(); } - // Remove toTensor operations + MLIRContext* ctx = moduleOp.getContext(); + ConversionTarget target(*ctx); + target.addLegalDialect(); + + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + + // Remove toTensor operations: leave memrefs instead moduleOp.walk([](bufferization::ToTensorOp toTensorOp) { toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer()); toTensorOp.erase(); @@ -63,8 +70,8 @@ void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { MLIRContext* ctx = funcOp.getContext(); funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { - bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa(user); }) - && !getGlobalOp->getUsers().empty(); + bool isAlwaysWeight = !getGlobalOp->getUsers().empty() + && all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa(user); }); if (isAlwaysWeight) { auto globalMemrefOp = moduleOp.lookupSymbol(getGlobalOp.getName()); assert("Weights must be constants" && globalMemrefOp.getConstant()); @@ -73,7 +80,3 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO } }); } - -} // namespace pim - -} // namespace onnx_mlir diff --git a/src/PIM/Transforms/PimBufferizationPass.hpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.hpp similarity index 83% rename from src/PIM/Transforms/PimBufferizationPass.hpp rename to src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.hpp index 3462b48..3e5e327 100644 --- a/src/PIM/Transforms/PimBufferizationPass.hpp +++ b/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.hpp @@ -2,6 +2,8 @@ #include "mlir/Pass/Pass.h" +#include "Dialect/PIM/PimOps.hpp" +#include "Dialect/PIM/Transforms/Bufferization/Common.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Compiler/CompilerOptions.hpp" @@ -9,6 +11,8 @@ namespace onnx_mlir { namespace pim { +#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc" + struct PimBufferizationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) StringRef getArgument() const override { return "bufferize-pim"; } diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 13ce3a6..d50dd30 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -13,7 +13,7 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" @@ -64,7 +64,7 @@ void PimAccelerator::registerDialects(DialectRegistry& registry) const { bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); - pim::registerBufferizableOpInterfaceExternalModels(registry); + pim::registerOpBufferizationInterfaces(registry); } void PimAccelerator::registerPasses(int optLevel) const { diff --git a/validation/operations/gemm/constant/GemmConst.onnx b/validation/operations/gemm/constant/GemmConst.onnx deleted file mode 100644 index a4d3176..0000000 Binary files a/validation/operations/gemm/constant/GemmConst.onnx and /dev/null differ diff --git a/validation/operations/gemm/gemm.onnx b/validation/operations/gemm/gemm.onnx new file mode 100644 index 0000000..01b21b6 Binary files /dev/null and b/validation/operations/gemm/gemm.onnx differ diff --git a/validation/operations/gemv/constant/gemv_constant.onnx b/validation/operations/gemv/constant/gemv_constant.onnx new file mode 100644 index 0000000..0938c08 Binary files /dev/null and b/validation/operations/gemv/constant/gemv_constant.onnx differ diff --git a/validation/operations/gemm/simple/gemm_simple.onnx b/validation/operations/gemv/simple/gemv_simple.onnx similarity index 100% rename from validation/operations/gemm/simple/gemm_simple.onnx rename to validation/operations/gemv/simple/gemv_simple.onnx diff --git a/validation/operations/gemm/with_heterogeneous_constant/gemm_with_heterogeneous_constant.onnx b/validation/operations/gemv/with_heterogeneous_constant/gemv_with_heterogeneous_constant.onnx similarity index 100% rename from validation/operations/gemm/with_heterogeneous_constant/gemm_with_heterogeneous_constant.onnx rename to validation/operations/gemv/with_heterogeneous_constant/gemv_with_heterogeneous_constant.onnx diff --git a/validation/operations/gemm/with_homogeneous_constant/gemm_with_homogeneous_constant.onnx b/validation/operations/gemv/with_homogeneous_constant/gemv_with_homogeneous_constant.onnx similarity index 100% rename from validation/operations/gemm/with_homogeneous_constant/gemm_with_homogeneous_constant.onnx rename to validation/operations/gemv/with_homogeneous_constant/gemv_with_homogeneous_constant.onnx diff --git a/validation/operations/gemm/with_scalar_constant/gemm_with_scalar_constant.onnx b/validation/operations/gemv/with_scalar_constant/gemv_with_scalar_constant.onnx similarity index 100% rename from validation/operations/gemm/with_scalar_constant/gemm_with_scalar_constant.onnx rename to validation/operations/gemv/with_scalar_constant/gemv_with_scalar_constant.onnx