From a6e928bdd70e4ea04ca888ccd24a4385817541d7 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Tue, 24 Feb 2026 15:09:18 +0100 Subject: [PATCH] add PIM accelerator --- CMakeLists.txt | 27 + onnx-mlir | 2 +- src/PIM/CMakeLists.txt | 46 ++ src/PIM/Common/CMakeLists.txt | 19 + src/PIM/Common/PIMCommon.cpp | 67 ++ src/PIM/Common/PIMCommon.hpp | 16 + src/PIM/Common/ValueMap.hpp | 44 ++ src/PIM/Compiler/CMakeLists.txt | 44 ++ src/PIM/Compiler/PimCodeGen.cpp | 704 ++++++++++++++++++ src/PIM/Compiler/PimCodeGen.hpp | 97 +++ src/PIM/Compiler/PimCompilerOptions.cpp | 56 ++ src/PIM/Compiler/PimCompilerOptions.hpp | 42 ++ src/PIM/Compiler/PimCompilerUtils.cpp | 56 ++ src/PIM/Compiler/PimCompilerUtils.hpp | 19 + src/PIM/Conversion/CMakeLists.txt | 3 + .../Conversion/ONNXToSpatial/CMakeLists.txt | 34 + .../Conversion/ONNXToSpatial/Math/Conv.cpp | 624 ++++++++++++++++ .../ONNXToSpatial/Math/ExperimentalConv.cpp | 430 +++++++++++ .../ONNXToSpatial/Math/ExperimentalGemm.cpp | 400 ++++++++++ .../Conversion/ONNXToSpatial/Math/Gemm.cpp | 317 ++++++++ .../ONNXToSpatial/NN/ExperimentalPooling.cpp | 327 ++++++++ .../Conversion/ONNXToSpatial/NN/Pooling.cpp | 452 +++++++++++ .../ONNXToSpatial/NN/ReduceMean.cpp | 90 +++ .../Conversion/ONNXToSpatial/ONNXToSpatial.td | 79 ++ .../ONNXToSpatial/ONNXToSpatialCommon.cpp | 499 +++++++++++++ .../ONNXToSpatial/ONNXToSpatialCommon.hpp | 262 +++++++ .../ONNXToSpatial/ONNXToSpatialPass.cpp | 131 ++++ .../ONNXToSpatial/ONNXToSpatialPass.hpp | 34 + .../ONNXToSpatial/ONNXToSpatialPatterns.hpp | 40 + .../Tensor/ONNXConcatToTensorConcat.cpp | 31 + .../Tensor/RemoveUnusedHelperOps.cpp | 34 + .../Utils/AnnotateReplication.cpp | 119 +++ .../Utils/AnnotateReplication.hpp | 11 + .../ONNXToSpatial/Utils/SpatialReducer.cpp | 382 ++++++++++ .../ONNXToSpatial/Utils/SpatialReducer.hpp | 83 +++ .../ONNXToSpatial/Utils/WeightSubdivider.cpp | 53 ++ .../ONNXToSpatial/Utils/WeightSubdivider.hpp | 46 ++ .../SpatialToGraphviz/CMakeLists.txt | 14 + 
.../SpatialToGraphviz/SpatialToGraphviz.cpp | 283 +++++++ .../Conversion/SpatialToPIM/CMakeLists.txt | 21 + .../Conversion/SpatialToPIM/SpatialToPIM.td | 28 + .../SpatialToPIM/SpatialToPIMCommon.cpp | 97 +++ .../SpatialToPIM/SpatialToPIMCommon.hpp | 108 +++ .../SpatialToPIM/SpatialToPIMPass.cpp | 491 ++++++++++++ .../SpatialToPIM/SpatialToPIMPass.hpp | 60 ++ .../SpatialToPIM/SpatialToPIMPatterns.hpp | 12 + src/PIM/Dialect/CMakeLists.txt | 2 + src/PIM/Dialect/PIM/CMakeLists.txt | 15 + src/PIM/Dialect/PIM/Pim.td | 345 +++++++++ src/PIM/Dialect/PIM/PimOps.cpp | 49 ++ src/PIM/Dialect/PIM/PimOps.hpp | 18 + .../Transforms/PimBufferizableOpInterface.cpp | 172 +++++ .../Transforms/PimBufferizableOpInterface.hpp | 14 + src/PIM/Dialect/Spatial/CMakeLists.txt | 15 + src/PIM/Dialect/Spatial/Spatial.td | 355 +++++++++ src/PIM/Dialect/Spatial/SpatialOps.cpp | 339 +++++++++ src/PIM/Dialect/Spatial/SpatialOps.hpp | 20 + .../SpatialBufferizableOpInterface.cpp | 493 ++++++++++++ .../SpatialBufferizableOpInterface.hpp | 16 + src/PIM/Pass/CountInstructionPass.cpp | 67 ++ src/PIM/Pass/MessagePass.cpp | 37 + src/PIM/Pass/PimPasses.hpp | 22 + src/PIM/PimAccelerator.cpp | 110 +++ src/PIM/PimAccelerator.hpp | 70 ++ src/PIM/Transforms/PimBufferizationPass.cpp | 87 +++ src/PIM/Transforms/PimBufferizationPass.hpp | 30 + test/PIM/CMakeLists.txt | 0 67 files changed, 9109 insertions(+), 1 deletion(-) create mode 100644 src/PIM/CMakeLists.txt create mode 100644 src/PIM/Common/CMakeLists.txt create mode 100644 src/PIM/Common/PIMCommon.cpp create mode 100644 src/PIM/Common/PIMCommon.hpp create mode 100644 src/PIM/Common/ValueMap.hpp create mode 100644 src/PIM/Compiler/CMakeLists.txt create mode 100644 src/PIM/Compiler/PimCodeGen.cpp create mode 100644 src/PIM/Compiler/PimCodeGen.hpp create mode 100644 src/PIM/Compiler/PimCompilerOptions.cpp create mode 100644 src/PIM/Compiler/PimCompilerOptions.hpp create mode 100644 src/PIM/Compiler/PimCompilerUtils.cpp create mode 100644 
src/PIM/Compiler/PimCompilerUtils.hpp create mode 100644 src/PIM/Conversion/CMakeLists.txt create mode 100644 src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt create mode 100644 src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td create mode 100644 src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp create mode 100644 src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt create mode 100644 src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp create mode 100644 src/PIM/Conversion/SpatialToPIM/CMakeLists.txt create mode 100644 
src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td create mode 100644 src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp create mode 100644 src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp create mode 100644 src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp create mode 100644 src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp create mode 100644 src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp create mode 100644 src/PIM/Dialect/CMakeLists.txt create mode 100644 src/PIM/Dialect/PIM/CMakeLists.txt create mode 100644 src/PIM/Dialect/PIM/Pim.td create mode 100644 src/PIM/Dialect/PIM/PimOps.cpp create mode 100644 src/PIM/Dialect/PIM/PimOps.hpp create mode 100644 src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp create mode 100644 src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp create mode 100644 src/PIM/Dialect/Spatial/CMakeLists.txt create mode 100644 src/PIM/Dialect/Spatial/Spatial.td create mode 100644 src/PIM/Dialect/Spatial/SpatialOps.cpp create mode 100644 src/PIM/Dialect/Spatial/SpatialOps.hpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp create mode 100644 src/PIM/Pass/CountInstructionPass.cpp create mode 100644 src/PIM/Pass/MessagePass.cpp create mode 100644 src/PIM/Pass/PimPasses.hpp create mode 100644 src/PIM/PimAccelerator.cpp create mode 100644 src/PIM/PimAccelerator.hpp create mode 100644 src/PIM/Transforms/PimBufferizationPass.cpp create mode 100644 src/PIM/Transforms/PimBufferizationPass.hpp create mode 100644 test/PIM/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index bdccafc..3a5c39e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,4 +3,31 @@ cmake_minimum_required(VERSION 3.20.0) project(raptor) +# Add symlink to PIM as accelerator in onnx-mlir +function(raptor_ensure_symlink link_path target_path) + get_filename_component(link_parent "${link_path}" 
DIRECTORY) + + if(NOT EXISTS "${link_parent}") + message(FATAL_ERROR "Directory not found: ${link_parent}") + endif() + + if(NOT EXISTS "${link_path}") + message(STATUS "Creating symlink ${link_path} -> ${target_path}") + file(CREATE_LINK + "${target_path}" + "${link_path}" + SYMBOLIC + ) + endif() +endfunction() + +raptor_ensure_symlink( + "${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM" + "${CMAKE_CURRENT_SOURCE_DIR}/src/PIM" +) +raptor_ensure_symlink( + "${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM" + "${CMAKE_CURRENT_SOURCE_DIR}/test/PIM" +) + add_subdirectory(onnx-mlir) diff --git a/onnx-mlir b/onnx-mlir index f7897a0..840d057 160000 --- a/onnx-mlir +++ b/onnx-mlir @@ -1 +1 @@ -Subproject commit f7897a0cab2a077be1bb9bd2fb87bbacf4ca391b +Subproject commit 840d05752072bf9d0488c00c4389c43a7e5dc6df diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt new file mode 100644 index 0000000..03d8cc9 --- /dev/null +++ b/src/PIM/CMakeLists.txt @@ -0,0 +1,46 @@ +set(PIM_ENABLED 1 BOOL PARENT_SCOPE) + +set(PIM_SRC_ROOT "${CMAKE_CURRENT_SOURCE_DIR}") +set(PIM_BIN_ROOT "${CMAKE_CURRENT_BINARY_DIR}") + +set(PIM_LIBRARY_PATH ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +set(PIM_RUNTIME_PATH ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY}) + +set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT}) +set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT}) + +add_subdirectory(Dialect) +add_subdirectory(Compiler) +add_subdirectory(Conversion) +add_subdirectory(Common) + +add_onnx_mlir_library(OMPIMAccel + PimAccelerator.cpp + Transforms/PimBufferizationPass.cpp + Pass/MessagePass.cpp + Pass/CountInstructionPass.cpp + + EXCLUDE_FROM_OM_LIBS + + INCLUDE_DIRS PUBLIC + ${ONNX_MLIR_SRC_ROOT}/include + ${ONNX_MLIR_SRC_ROOT} + ${PIM_ONNX_MLIR_SRC_ROOT} + ${PIM_SRC_ROOT} + ${PIM_BIN_ROOT} + ${PIM_INCLUDE_PATH} + + LINK_LIBS PUBLIC + onnx + OMAccelerator + OMPimCompilerUtils + OMCompilerUtils + OMONNXOps + SpatialOps + PimOps + 
OMONNXToSpatial + OMSpatialToGraphviz + OMSpatialToPIM + OMPIMCommon +) \ No newline at end of file diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt new file mode 100644 index 0000000..5ca75f5 --- /dev/null +++ b/src/PIM/Common/CMakeLists.txt @@ -0,0 +1,19 @@ +add_onnx_mlir_library(OMPIMCommon + PIMCommon.cpp + + EXCLUDE_FROM_OM_LIBS + + INCLUDE_DIRS PUBLIC + ${ONNX_MLIR_SRC_ROOT}/include + ${ONNX_MLIR_SRC_ROOT} + ${PIM_ONNX_MLIR_SRC_ROOT} + ${PIM_SRC_ROOT} + ${PIM_BIN_ROOT} + ${PIM_INCLUDE_PATH} + + LINK_LIBS PUBLIC + onnx + OMPimCompilerUtils + SpatialOps + PimOps +) \ No newline at end of file diff --git a/src/PIM/Common/PIMCommon.cpp b/src/PIM/Common/PIMCommon.cpp new file mode 100644 index 0000000..f858003 --- /dev/null +++ b/src/PIM/Common/PIMCommon.cpp @@ -0,0 +1,67 @@ +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +llvm::FailureOr getOtherEndOfChannel( + Operation *op, bool opIsReceive, RewriterBase &rewriter) { + + auto channelNewOp = op->getOperand(0).getDefiningOp(); + if (!channelNewOp) { + op->emitError( + "User of Channel must have the first operand created by ChannelNewOp."); + return failure(); + } + // channelNewOp should have two users: `op` and a + // `ChannelSendOp`/`ChannelReceiveOp` + auto channelUsers = channelNewOp->getUsers(); + auto usersIterator = channelUsers.begin(); + auto firstUser = *usersIterator; + usersIterator++; + if (usersIterator == channelUsers.end()) { + op->emitError("Operand generated by ChannelNewOp must have two users, " + "only one found."); + channelNewOp->dump(); + op->dump(); + channelNewOp->getParentOp()->dump(); + return failure(); + } + auto secondUser = *usersIterator; + usersIterator++; + if (usersIterator != channelUsers.end()) { + op->emitError("Operand generated by ChannelNewOp must have two users, " + "more than two found."); + return failure(); + } + Operation 
*notOpUser; + if (firstUser == op) { + notOpUser = secondUser; + } else if (secondUser == op) { + notOpUser = firstUser; + } else { + op->emitError("Operand generated by ChannelNewOp must have two users, " + "and one of them must be me, but" + "none of them is actually me."); + return failure(); + } + + if (opIsReceive) { + if (!isa(notOpUser)) { + op->emitError("Operand generated by ChannelNewOp has two user, one is " + "me, the other is not a ChannelSendOp."); + return failure(); + } + return notOpUser; + } else { + if (!isa(notOpUser)) { + op->emitError("Operand generated by ChannelNewOp has two user, one is " + "me, the other is not a ChannelReceiveOp."); + return failure(); + } + return notOpUser; + } +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Common/PIMCommon.hpp b/src/PIM/Common/PIMCommon.hpp new file mode 100644 index 0000000..2ae0e30 --- /dev/null +++ b/src/PIM/Common/PIMCommon.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/StringRef.h" + +const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = + "pim.constant.should_allocate"; + +namespace onnx_mlir { + +llvm::FailureOr getOtherEndOfChannel( + mlir::Operation *op, bool opIsReceive, mlir::RewriterBase &rewriter); + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Common/ValueMap.hpp b/src/PIM/Common/ValueMap.hpp new file mode 100644 index 0000000..c8e249a --- /dev/null +++ b/src/PIM/Common/ValueMap.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/DenseMap.h" + +template +class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener { +public: + llvm::DenseMap map; + + AutoCleaningValueMap(mlir::OpBuilder::Listener listener) + : ForwardingListener(&listener) {} + + void notifyOperationErased(mlir::Operation* op) override { + for (mlir::Value result : 
op->getResults()) + map.erase(result); + } + + void notifyBlockErased(mlir::Block* block) override { + for (mlir::BlockArgument arg : block->getArguments()) + map.erase(arg); + } +}; + +template +class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener { +public: + llvm::DenseMap map; + + NotErasableValueMap(mlir::OpBuilder::Listener listener) + : ForwardingListener(&listener) {} + + void notifyOperationErased(mlir::Operation* op) override { + for (mlir::Value result : op->getResults()) + assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result)); + } + + void notifyBlockErased(mlir::Block* block) override { + for (mlir::BlockArgument arg : block->getArguments()) + assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg)); + } +}; diff --git a/src/PIM/Compiler/CMakeLists.txt b/src/PIM/Compiler/CMakeLists.txt new file mode 100644 index 0000000..fe4c290 --- /dev/null +++ b/src/PIM/Compiler/CMakeLists.txt @@ -0,0 +1,44 @@ +get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS) + +add_onnx_mlir_library(OMPimCompilerOptions + PimCompilerOptions.cpp + + EXCLUDE_FROM_OM_LIBS + + INCLUDE_DIRS PRIVATE + ${PIM_SRC_ROOT} + ${PIM_BIN_ROOT} + ${PIM_ONNX_MLIR_SRC_ROOT} + ${PIM_ONNX_MLIR_BIN_ROOT} + + LINK_LIBS PUBLIC + ${OMLibs} + OMCompilerOptions + + ACCEL_INCLUDE_DIRS PRIVATE + ${PIM_ONNX_MLIR_SRC_ROOT} + ${PIM_ONNX_MLIR_BIN_ROOT} +) + +add_onnx_mlir_library(OMPimCompilerUtils + PimCompilerUtils.cpp + PimCodeGen.cpp + + EXCLUDE_FROM_OM_LIBS + + INCLUDE_DIRS PRIVATE + ${PIM_SRC_ROOT} + ${PIM_BIN_ROOT} + ${PIM_ONNX_MLIR_SRC_ROOT} + ${PIM_ONNX_MLIR_BIN_ROOT} + + LINK_LIBS PUBLIC + ${OMLibs} + OMCompilerUtils + OMPimCompilerOptions + OMCompilerPasses + + ACCEL_INCLUDE_DIRS PRIVATE + ${PIM_ONNX_MLIR_SRC_ROOT} + ${PIM_ONNX_MLIR_BIN_ROOT} +) diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp new file mode 100644 index 0000000..16438d7 --- /dev/null +++ 
b/src/PIM/Compiler/PimCodeGen.cpp @@ -0,0 +1,704 @@ +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" +#include "Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerPasses.hpp" +#include "src/Compiler/CompilerUtils.hpp" + +namespace onnx_mlir { + +MemEntry* PimMemory::gatherMemEntry(Value value) { + auto type = cast(value.getType()); + assert("Only static shape is supported" && type.hasStaticShape()); + size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8; + MemEntry memEntry = {0, allocSize}; + return &memEntries.emplace_back(memEntry, value).first; +} + +void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) { + memEntry.address = firstAvailableAddress; + firstAvailableAddress += memEntry.size; + // Alignment + if (size_t remainder = firstAvailableAddress % minAlignment) + firstAvailableAddress += minAlignment - remainder; + + globalMemEntriesMap[value] = memEntry; +} + +void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { + // More than one SSA value per single global constant: + // Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times + // Thus, 
call gatherMemEntry only for the first SSA value and assign the same memEntry to all others + llvm::SmallDenseMap globalConstants; + funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { + if (!getGlobalOp->hasAttr("weightAlways")) { + auto globalMemrefOp = moduleOp.lookupSymbol(getGlobalOp.getName()); + auto iter = globalConstants.find(globalMemrefOp); + if (iter == globalConstants.end()) + globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp); + else { + MemEntry memEntry = *iter->second; + globalMemEntriesMap[getGlobalOp] = memEntry; + } + } + }); + + for (Value arg : funcOp.getArguments()) + gatherMemEntry(arg); + + allocateCore(funcOp); +} + +void PimMemory::allocateCore(Operation* op) { + op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); + + llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); + for (auto& [memEntry, value] : memEntries) + allocateMemoryForValue(value, memEntry); +} + +MemEntry PimMemory::getMemEntry(Value value) const { + auto iter = globalMemEntriesMap.find(value); + assert("Missing memEntry for value" && iter != globalMemEntriesMap.end()); + return iter->second; +} + +PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { + return deviceMem.try_emplace(id, memEntriesMap).first->second; +} + +size_t PimAcceleratorMemory::getValueAddress(Value value) const { + while (true) { + auto definingOp = value.getDefiningOp(); + if (!definingOp) + break; + if (auto dpsDefiningOp = dyn_cast(definingOp)) { + OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast(value)); + if (!tiedOperand) + break; + value = tiedOperand->get(); + } + else if (auto subviewDefiningOp = dyn_cast(definingOp)) { + auto source = subviewDefiningOp.getSource(); + auto srcShape = source.getType().getShape(); + auto subviewOffsets = subviewDefiningOp.getStaticOffsets(); + auto subviewSizes = subviewDefiningOp.getStaticSizes(); + auto subviewStrides = subviewDefiningOp.getStaticStrides(); + 
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides)); + value = source; + } + else + break; + } + return memEntriesMap.at(value).address; +} + +llvm::json::Object PimCodeGen::createSetImmediate(size_t targetRegister, size_t immediate) { + llvm::json::Object returnValue; + returnValue["op"] = "sldi"; + returnValue["rd"] = targetRegister; + returnValue["imm"] = immediate; + return returnValue; +} + +llvm::json::Object PimCodeGen::createEmptyOffset() { + llvm::json::Object returnValue; + returnValue["offset_select"] = 0; + returnValue["offset_value"] = 0; + return returnValue; +} + +void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) { + llvm::json::Object setRegisterJson = createSetImmediate(registerNumber, immediate); + coreFileStream << llvm::json::Value(std::move(setRegisterJson)) << ','; +} + +void PimCodeGen::createRd(size_t rdAddress, size_t rdOffset) { + // rd on register 0 + genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); +} + +void PimCodeGen::createRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) { + // rd on register 0 + genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); + // rs1 on register 1 + genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset); +} + +void PimCodeGen::createRdRs1Rs2( + size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) { + // rd on register 0 + genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); + // rs1 on register 1 + genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset); + // rs2 on register 2 + genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset); +} + +void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) { + auto deviceDst = loadOp.getDeviceDst(); + auto hostSrc = loadOp.getHostSrc(); + auto deviceDstOffset = loadOp.getDeviceDstOffset(); + auto hostSrcOffset = loadOp.getHostSrcOffset(); + auto size = loadOp.getSize(); + + auto 
deviceDstAlloc = memory.getValueAddress(deviceDst); + auto hostSrcAlloc = memory.getValueAddress(hostSrc); + + // Set load rd register (reg 0) + createRdRs1(deviceDstAlloc, deviceDstOffset, hostSrcAlloc, hostSrcOffset); + + llvm::json::Object loadOpJson; + loadOpJson["op"] = "ld"; + loadOpJson["rd"] = 0; + loadOpJson["rs1"] = 1; + loadOpJson["size"] = size; + loadOpJson["offset"] = createEmptyOffset(); + + coreFileStream << llvm::json::Value(std::move(loadOpJson)) << ','; +} + +void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) { + auto hostDst = storeOp.getHostDst(); + auto deviceSrc = storeOp.getDeviceSrc(); + auto hostDstOffset = storeOp.getHostDstOffset(); + auto deviceSrcOffset = storeOp.getDeviceSrcOffset(); + auto size = storeOp.getSize(); + + auto deviceSrcAlloc = memory.getValueAddress(deviceSrc); + auto hostDstAlloc = memory.getValueAddress(hostDst); + + // Set load rd register (reg 0) + createRdRs1(hostDstAlloc, hostDstOffset, deviceSrcAlloc, deviceSrcOffset); + + llvm::json::Object storeOpJson; + storeOpJson["op"] = "st"; + storeOpJson["rd"] = 0; + storeOpJson["rs1"] = 1; + storeOpJson["size"] = size; + storeOpJson["offset"] = createEmptyOffset(); + + coreFileStream << llvm::json::Value(std::move(storeOpJson)) << ','; +} + +template +void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) { + auto outBufAlloc = memory.getValueAddress(mvmLikeOp.getOutBuf()); + auto vectorAlloc = memory.getValueAddress(mvmLikeOp.getVectorInput()); + + createRdRs1(outBufAlloc, 0, vectorAlloc, 0); + + llvm::json::Object mvmOpJson; + mvmOpJson["op"] = "mvmul"; + mvmOpJson["rd"] = 0; + mvmOpJson["rs1"] = 1; + mvmOpJson["group"] = mvmId; + mvmOpJson["relu"] = 0; + mvmOpJson["mbiw"] = 8; + + coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ','; + + // TODO: save weights somewhere (if transposeMatrix=true, then transpose the + // weight matrix) +} + +void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp 
applyFiltersOp) { + + auto outBuff = memory.getValueAddress(applyFiltersOp.getOutBuf()); + auto inBuff = memory.getValueAddress(applyFiltersOp.getInput()); + auto accumBuff = memory.getValueAddress(applyFiltersOp.getAccumBuf()); + + // Get weight indices from the operation attribute. + auto weightIndices = applyFiltersOp.getWeightIndices(); + + // Get shape of the input tensor. + auto inputType = cast(applyFiltersOp.getInput().getType()); + auto outputType = cast(applyFiltersOp.getOutBuf().getType()); + auto in_shape = inputType.getShape(); + auto out_shape = outputType.getShape(); + + // Extract the relevant dimensions. + size_t in_channels = in_shape[1]; // Number of input channels. + size_t out_channels = out_shape[1]; // Number of output channels. + + size_t dim2 = in_shape.size() > 2 ? in_shape[2] : 1; // Image width. + size_t dim3 = in_shape.size() > 3 ? in_shape[3] : 1; // Image height. + + // Iterate through pixels. + for (size_t out_y = 0; out_y < dim3; out_y++) { + for (size_t out_x = 0; out_x < dim2; out_x++) { + + // For each crossbar, perform the MVMUL operation. + size_t weightIndex = 0; + for (Attribute weight : weightIndices) { + + // -------------------------------------- + // --- STEP 1: Perform MVUL operation --- + // -------------------------------------- + + // Get the weight matrix ID for this position. + auto weightId = cast(weight).getInt(); + + size_t xKer = cast(applyFiltersOp.getXKernelPositions()[weightIndex]).getInt(); + size_t yKer = cast(applyFiltersOp.getYKernelPositions()[weightIndex]).getInt(); + + weightIndex++; + + if (out_x + xKer >= dim2 || out_y + yKer >= dim3) + continue; + + // Calculate the offset for the input (and output) tensor. + size_t output_offset = (out_y * dim2 + out_x) * 32 * out_channels; + size_t input_offset = ((out_y + yKer) * dim2 + (out_x + xKer)) * 32 * in_channels; + + // Read from the input tensor and store the partial result in the + // accumulator buffer, if this is not the first weight matrix. 
+ + // Note that rs1 is the input tensor, and rd is the output tensor. + // TODO: This order of arguments is confusing, check if the correct + // order is being used in the WMVUL operation. The order below is + // correct. + if (weightIndices[0] != weight) { + createRdRs1(accumBuff, 0, inBuff, input_offset); + } + else { + // Otherwise store directly in the output buffer. + createRdRs1(outBuff, output_offset, inBuff, input_offset); + } + + // Create the MVMUL JSON object + llvm::json::Object mvmOpJson; + mvmOpJson["op"] = "mvmul"; + mvmOpJson["rd"] = 0; + mvmOpJson["rs1"] = 1; + mvmOpJson["group"] = weightId; + mvmOpJson["relu"] = 0; + mvmOpJson["mbiw"] = 8; + + // Write the JSON to the output stream + coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ','; + + // -------------------------------------- + // --- STEP 2: Perform VADD operation --- + // -------------------------------------- + + // If this is the first weight matrix, we don't need to perform a VADD. + if (weightIndices[0] == weight) + continue; + + // We now need to sum the value in the accumulator buffer with the value + // in the output buffer, and store the result in the output buffer. + createRdRs1Rs2(outBuff, output_offset, accumBuff, 0, outBuff, output_offset); + + llvm::json::Object vaddOpJson; + vaddOpJson["op"] = "vvadd"; + vaddOpJson["rd"] = 0; + vaddOpJson["rs1"] = 1; + vaddOpJson["rs2"] = 2; + vaddOpJson["offset"] = createEmptyOffset(); + + coreFileStream << llvm::json::Value(std::move(vaddOpJson)) << ','; + } + } + } +} + +void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) { + auto outBufAlloc = memory.getValueAddress(vaddOp.getOutBuf()); + auto rs1BufferOp = memory.getValueAddress(vaddOp.getA()); + auto rs2BufferOp = memory.getValueAddress(vaddOp.getB()); + + createRdRs1Rs2(outBufAlloc, 0, rs1BufferOp, 0, rs2BufferOp, 0); + + // Get the size of the output buffer. 
+ auto outputType = cast(vaddOp.getOutBuf().getType()); + auto out_shape = outputType.getShape(); + + // Multiply all dimension lengths to get the total number of elements. + size_t totalElements = 1; + for (size_t i = 0; i < out_shape.size(); i++) + totalElements *= out_shape[i]; + auto elementSize = vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8; + + llvm::json::Object mvmOpJson; + mvmOpJson["op"] = "vvadd"; + mvmOpJson["rd"] = 0; + mvmOpJson["rs1"] = 1; + mvmOpJson["rs2"] = 2; + mvmOpJson["offset"] = createEmptyOffset(); + mvmOpJson["len"] = totalElements * elementSize; + + coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ','; +} + +void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) { + + auto outBufAlloc = memory.getValueAddress(vmaxOp.getOutBuf()); + auto rs1BufferOp = memory.getValueAddress(vmaxOp.getA()); + auto rs2BufferOp = memory.getValueAddress(vmaxOp.getB()); + + createRdRs1Rs2(outBufAlloc, 0, rs1BufferOp, 0, rs2BufferOp, 0); + + llvm::json::Object mvmOpJson; + mvmOpJson["op"] = "vvmax"; + mvmOpJson["rd"] = 0; + mvmOpJson["rs1"] = 1; + mvmOpJson["rs2"] = 2; + mvmOpJson["offset"] = createEmptyOffset(); + + coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ','; +} + +void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) { + auto outBufAlloc = memory.getValueAddress(vreluOp.getOutBuf()); + auto rs1BufferOp = memory.getValueAddress(vreluOp.getA()); + + createRdRs1(outBufAlloc, 0, rs1BufferOp, 0); + + llvm::json::Object mvmOpJson; + mvmOpJson["op"] = "vrelu"; + mvmOpJson["rd"] = 0; + mvmOpJson["rs1"] = 1; + mvmOpJson["offset"] = createEmptyOffset(); + + coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ','; +} + +void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) { + + auto destAlloc = memory.getValueAddress(receiveOp.getDst()); + + createRd(destAlloc, /* dest_offset = */ 0); + + llvm::json::Object recvOpJson; + recvOpJson["op"] = "recv"; + recvOpJson["rd"] = 0; + recvOpJson["core"] = 
receiveOp.getSrcCoreId(); + recvOpJson["size"] = receiveOp.getSize(); + recvOpJson["offset"] = createEmptyOffset(); + + coreFileStream << llvm::json::Value(std::move(recvOpJson)) << ','; +} + +void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) { + + auto srcAlloc = memory.getValueAddress(sendOp.getSrc()); + + // Technically a RS1 register, but its just a name.. + createRd(srcAlloc, /* dest_offset = */ 0); + + llvm::json::Object sendOpJson; + sendOpJson["op"] = "send"; + sendOpJson["rd"] = 0; + sendOpJson["core"] = sendOp.getTargetCoreId(); + sendOpJson["size"] = sendOp.getSize(); + sendOpJson["offset"] = createEmptyOffset(); + + coreFileStream << llvm::json::Value(std::move(sendOpJson)) << ','; +} + +size_t getMatrixSize(ShapedType matrixShape) { + if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4) + assert(false && "Unsupported matrix shape"); + return std::max(matrixShape.getDimSize(0), matrixShape.getDimSize(1)); +} + +std::string getMemorySizeAsString(size_t size) { + if (size > 1024 * 1024 * 1024) + return std::to_string(size / 1024 / 1024 / 1024) + " GB"; + if (size > 1024 * 1024) + return std::to_string(size / 1024 / 1024) + " MB"; + if (size > 1024) + return std::to_string(size / 1024) + " KB"; + return std::to_string(size) + " Bytes"; +} + +int compileModuleToPIMJSON(const OwningOpRef& moduleOpRef, std::string& outputDirPath) { + ModuleOp moduleOp = moduleOpRef.get(); + + if (pimEmissionTarget != EmitPimCodegen) { + moduleOp.dump(); + return CompilerSuccess; + } + + if (!outputDirPath.empty()) { + if (auto error = llvm::sys::fs::create_directory(outputDirPath)) { + llvm::errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n'; + return InvalidOutputFileAccess; + } + } + + // For each core, specify the number of crossbar per array group + // This implementation always assigns one crossbar per group + llvm::json::Object xbarsPerArrayGroup; + + auto funcOps = moduleOp.getOps(); + assert(!funcOps.empty() 
&& "No function found in the module"); + auto funcOp = *funcOps.begin(); + + PimAcceleratorMemory memory; + memory.hostMem.allocateHost(moduleOp, funcOp); + + // Write memory binary file + auto memoryFilePath = outputDirPath + "/memory.bin"; + std::error_code errorCode; + llvm::raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, llvm::sys::fs::OF_None); + if (errorCode) { + llvm::errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n'; + return InvalidOutputFileAccess; + } + // Zero-initialized buffer + std::vector memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); + // Write global values at their allocated addresses + funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { + if (getGlobalOp->hasAttr("weightAlways")) + return; + auto globalOp = moduleOp.lookupSymbol(getGlobalOp.getName()); + if (!globalOp) + return; + auto initialValue = globalOp.getInitialValue(); + if (!initialValue) + return; + auto denseAttr = dyn_cast(*initialValue); + if (!denseAttr) + return; + auto memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult()); + auto rawData = denseAttr.getRawData(); + std::memcpy(memoryBuffer.data() + memEntry.address, rawData.data(), std::min(rawData.size(), memEntry.size)); + }); + memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size()); + memoryFileStream.close(); + + size_t coreCount = 0; + for (auto coreOp : funcOp.getOps()) { + auto coreId = coreOp.getCoreId(); + coreCount++; + + std::error_code errorCode; + auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json"; + llvm::raw_fd_ostream coreFileStream(outputCorePath, errorCode); + if (errorCode) { + llvm::errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n'; + return InvalidOutputFileAccess; + } + + coreFileStream << '['; + auto coreNameString = "core" + std::to_string(coreId); + + PimCodeGen coreCodeGen(memory, coreFileStream); + 
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); + + size_t processedOperations = 0; + for (auto& op : coreOp.getBody().front()) { + if (isa(op)) + continue; + if (isa(op)) + continue; + if (auto loadOp = dyn_cast(op)) { + coreCodeGen.codeGenLoadOp(loadOp); + } + else if (auto storeOp = dyn_cast(op)) { + coreCodeGen.codeGenStoreOp(storeOp); + } + else if (auto vmmOp = dyn_cast(op)) { + coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true); + } + else if (auto mvmOp = dyn_cast(op)) { + coreCodeGen.codeGenMVMLikeOp(mvmOp.getWeightIndex(), mvmOp, false); + } + else if (auto applyFiltersOp = dyn_cast(op)) { + coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); + } + else if (auto vaddOp = dyn_cast(op)) { + coreCodeGen.codeGenVAddOp(vaddOp); + } + else if (auto vmaxOp = dyn_cast(op)) { + coreCodeGen.codeGenVMaxOp(vmaxOp); + } + else if (auto vreluOp = dyn_cast(op)) { + coreCodeGen.codeGenVReluOp(vreluOp); + } + else if (auto receiveOp = dyn_cast(op)) { + coreCodeGen.codeGenReceiveOp(receiveOp); + } + else if (auto sendOp = dyn_cast(op)) { + coreCodeGen.codeGenSendOp(sendOp); + } + else if (auto sumOp = dyn_cast(op)) { + // TODO: Implement somehow? + op.emitWarning("Sum operation is not supported"); + continue; + } + else if (auto vsDivOp = dyn_cast(op)) { + // TODO: Implement somehow? + op.emitWarning("VSDiv operation is not supported"); + continue; + } + else if (auto vexpOp = dyn_cast(op)) { + // TODO: Implement somehow? 
+ op.emitWarning("VExp operation is not supported"); + continue; + } + else if (isa(op)) { + continue; + } + else { + op.emitError("Unsupported codegen for this operation"); + op.dump(); + return CompilerFailure; + } + processedOperations++; + } + assert(processedOperations > 0); + // Remove trailing comma + coreFileStream.seek(coreFileStream.tell() - 1); + coreFileStream << ']'; + coreFileStream.close(); + + // Create output directory for this core's crossbar weights + auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); + if (auto error = llvm::sys::fs::create_directory(coreWeightsDirPath)) { + llvm::errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; + return InvalidOutputFileAccess; + } + + int64_t xbarSize = crossbarSize.getValue(); + size_t weightIndex = 0; + llvm::json::Array xbarsPerGroup; + for (auto weight : coreOp.getWeights()) { + xbarsPerGroup.push_back(weightIndex); + auto getGlobalOp = weight.getDefiningOp(); + if (!getGlobalOp) { + coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex)); + weightIndex++; + continue; + } + + auto globalOp = SymbolTable::lookupNearestSymbolFrom(moduleOp, getGlobalOp.getNameAttr()); + if (!globalOp) { + coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex)); + weightIndex++; + continue; + } + + auto initialValue = globalOp.getInitialValue(); + if (!initialValue) { + coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex)); + weightIndex++; + continue; + } + + auto denseAttr = dyn_cast(*initialValue); + if (!denseAttr) { + coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex)); + weightIndex++; + continue; + } + + auto type = denseAttr.getType(); + auto shape = type.getShape(); + assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); + int64_t numRows = 
shape[0]; + int64_t numCols = shape[1]; + assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); + + auto elementType = type.getElementType(); + size_t elementByteWidth = elementType.getIntOrFloatBitWidth() / 8; + + // Write crossbar weights as binary, padded to crossbarSize x crossbarSize + auto weightFilePath = coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin"; + llvm::raw_fd_ostream weightFileStream(weightFilePath, errorCode, llvm::sys::fs::OF_None); + if (errorCode) { + llvm::errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; + return InvalidOutputFileAccess; + } + + uint64_t zero = 0; + for (int64_t row = 0; row < xbarSize; row++) { + for (int64_t col = 0; col < xbarSize; col++) { + if (row < numRows && col < numCols) { + int64_t index = row * numCols + col; + APInt bits = denseAttr.getValues()[index].bitcastToAPInt(); + uint64_t word = bits.getZExtValue(); + weightFileStream.write(reinterpret_cast(&word), elementByteWidth); + } + else { + weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); + } + } + } + + weightFileStream.close(); + weightIndex++; + } + xbarsPerArrayGroup[coreNameString] = std::move(xbarsPerGroup); + } + + // Step 3: Write configuration to JSON + llvm::json::Object configJson; + configJson["core_cnt"] = coreCount; + + // TODO: Should this be based on the floating point type used in the model? 
+ //// The 2 following values determine the bitwidth of the vectors' elements: + //// bitwidth = adc_count * cell_precision + // Number of ADC for MVM units + configJson["adc_count"] = 16; + // Bit precision of each ADC + configJson["cell_precision"] = 2; + + //// Crossbar configuration + configJson["xbar_array_count"] = crossbarCountInCore.getValue(); + configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()}; + + // Store the crossbar sizes + configJson["array_group_map"] = std::move(xbarsPerArrayGroup); + + // Store the memory layout of inputs and outputs + llvm::json::Array inputsAddresses; + for (BlockArgument input : funcOp.getArguments()) + inputsAddresses.push_back(memory.getValueAddress(input)); + configJson["inputs_addresses"] = std::move(inputsAddresses); + llvm::json::Array outputsAddresses; + for (func::ReturnOp returnOp : funcOp.getOps()) + for (Value output : returnOp.getOperands()) + outputsAddresses.push_back(memory.getValueAddress(output)); + configJson["outputs_addresses"] = std::move(outputsAddresses); + + // Step 4: Write config JSON + std::string openOutputErrorMsg; + auto configPath = outputDirPath + "/config.json"; + std::error_code EC; + llvm::raw_fd_ostream jsonOS(configPath, EC); + if (EC) { + llvm::errs() << "Error while opening config file: " << EC.message() << '\n'; + return InvalidOutputFileAccess; + } + jsonOS << llvm::json::Value(std::move(configJson)) << '\n'; + jsonOS.close(); + + showCompilePhase("Code generated into " + configPath); + + return CompilerSuccess; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp new file mode 100644 index 0000000..b6cc382 --- /dev/null +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include "llvm/Support/JSON.h" + +#include "Common/ValueMap.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" +#include 
"src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerPasses.hpp" + +namespace onnx_mlir { + +struct MemEntry { + size_t address; + size_t size; +}; + +class PimMemory { + SmallVector, 32> memEntries; + llvm::SmallDenseMap& globalMemEntriesMap; + + size_t maxSize = 0; // 0 for unbounded memory + size_t startAddress = 0; + size_t minAlignment = 4; + size_t firstAvailableAddress = 0; + + MemEntry* gatherMemEntry(Value value); + void allocateMemoryForValue(Value value, MemEntry& memEntry); + +public: + PimMemory(llvm::SmallDenseMap& globalMemEntriesMap) + : globalMemEntriesMap(globalMemEntriesMap) {} + + void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp); + void allocateCore(Operation* op); + + size_t getFirstAvailableAddress() const { return firstAvailableAddress; } + MemEntry getMemEntry(Value value) const ; +}; + +class PimAcceleratorMemory { +public: + llvm::SmallDenseMap memEntriesMap; + PimMemory hostMem; + +private: + llvm::SmallDenseMap deviceMem; + +public: + PimAcceleratorMemory() + : hostMem(memEntriesMap) {} + + PimMemory getOrCreateDeviceMem(size_t id); + + size_t getValueAddress(Value value) const; +}; + +class PimCodeGen { + PimAcceleratorMemory& memory; + llvm::raw_fd_ostream& coreFileStream; + +public: + PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson) + : memory(memory), coreFileStream(coreJson) {} + + llvm::json::Object createSetImmediate(size_t targetRegister, size_t immediate); + llvm::json::Object createEmptyOffset(); + + void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate); + + void createRd(size_t rdAddress, size_t rdOffset); + void createRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset); + void createRdRs1Rs2( + size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset); + + void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp); + + void 
codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp); + + template + void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix); + + void codeGenReceiveOp(pim::PimReceiveOp receiveOp); + + void codeGenSendOp(pim::PimSendOp sendOp); + + void codeGenVAddOp(pim::PimVAddOp vaddOp); + + void codeGenVMaxOp(pim::PimVMaxOp vmaxOp); + + void codeGenVReluOp(pim::PimVReluOp vreluOp); + + void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp); +}; + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp new file mode 100644 index 0000000..bddd179 --- /dev/null +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------------- PimCompilerOptions.cpp --------------------===// +// +// Copyright 2022 The IBM Research Authors. +// +// ============================================================================= +// +// Compiler Options for PIM +// +//===----------------------------------------------------------------------===// +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" + +#define DEBUG_TYPE "PimCompilerOptions" + +namespace onnx_mlir { + +llvm::cl::opt pimEmissionTarget( + llvm::cl::desc("[Optional] Choose PIM-related target to emit " + "(once selected it will cancel the other targets):"), + llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")), + llvm::cl::values(clEnumVal(EmitPim, "Lower model to PIM IR")), + llvm::cl::values( + clEnumVal(EmitPimBufferized, "Lower model to PIM IR and bufferize it")), + llvm::cl::values(clEnumVal(EmitPimCodegen, "Lower model to PIM IR and " + "generate code for PIM")), + llvm::cl::init(EmitPimCodegen), llvm::cl::cat(OnnxMlirOptions)); + +llvm::cl::opt pimOnlyCodegen("pim-only-codegen", + llvm::cl::desc("Only generate code for PIM (assume input is already in " + "bufferized PIM IR)"), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); + 
+llvm::cl::opt useExperimentalConvImpl("use-experimental-conv-impl", + llvm::cl::desc("Use experimental implementation for convolution"), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); + +llvm::cl::opt crossbarSize("crossbar-size", + llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2)); + +llvm::cl::opt crossbarCountInCore("crossbar-count", + llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2)); + +llvm::cl::opt coresCount("core-count", + llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum " + "amount of cores."), + llvm::cl::init(-1)); + +llvm::cl::opt ignoreConcatError("ignore-concat-error", + llvm::cl::desc( + "Ignore ConcatOp corner case: do not assert and do a simplification"), + llvm::cl::init(false)); + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp new file mode 100644 index 0000000..8b2a51b --- /dev/null +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "llvm/Support/CommandLine.h" + +#define INSTRUMENTSTAGE_ENUM_PIM + +#define INSTRUMENTSTAGE_CL_ENUM_PIM + +#define PROFILEIR_CL_ENUM_PIM + +#define OPTREPORT_ENUM_PIM + +#define OPTREPORT_CL_ENUM_PIM + +namespace onnx_mlir { +typedef enum { + EmitSpatial = 0, + EmitPim = 1, + EmitPimBufferized = 2, + EmitPimCodegen = 3 +} PimEmissionTargetType; + +extern llvm::cl::OptionCategory OnnxMlirOptions; +extern llvm::cl::opt pimEmissionTarget; + +extern llvm::cl::opt pimOnlyCodegen; +extern llvm::cl::opt useExperimentalConvImpl; +extern llvm::cl::opt exportCrossbarWeights; + +extern llvm::cl::opt crossbarSize; +extern llvm::cl::opt crossbarCountInCore; +extern llvm::cl::opt coresCount; + +// This option, by default set to false, will ignore an error when resolving a +// specific tiles of the operands of a concat. 
This specific case is when the +// wanted tile is generated by two separate operands of the concat. If this is +// set to false, this corner case will assert an error. If this is set to true, +// a simplification is performed and only the tile from the first operand is +// taken. +extern llvm::cl::opt ignoreConcatError; + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp new file mode 100644 index 0000000..6186bd4 --- /dev/null +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -0,0 +1,56 @@ +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/Passes.h" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerPasses.hpp" + +#include "llvm/Support/JSON.h" + +#include +#include + +#define DEBUG_TYPE "PimCompilerUtils" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { + +void addPassesPim(OwningOpRef& module, + PassManager& pm, + EmissionTargetType& emissionTarget, + std::string outputNameNoExt) { + + if (pimOnlyCodegen) { + // Skip all the lowering passes and directly generate code for PIM. 
+ return; + } + + if (emissionTarget >= EmitONNXIR) + addONNXToMLIRPasses(pm, /*target CPU*/ false); + + if (pimEmissionTarget >= EmitSpatial) { + pm.addPass(createONNXToSpatialPass()); + // pm.addPass(createCountInstructionPass()); + pm.addPass(createMessagePass("ONNX lowered to SPATIAL")); + } + + if (pimEmissionTarget >= EmitPim) { + pm.addPass(createSpatialToPIMPass()); + // pm.addPass(createCountInstructionPass()); + pm.addPass(createMessagePass("SPATIAL lowered to PIM")); + } + + if (pimEmissionTarget >= EmitPimBufferized) { + pm.addPass(createBufferizePimPass()); + // pm.addPass(createCountInstructionPass()); + pm.addPass(createMessagePass("PIM bufferized")); + } +} + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCompilerUtils.hpp b/src/PIM/Compiler/PimCompilerUtils.hpp new file mode 100644 index 0000000..76581be --- /dev/null +++ b/src/PIM/Compiler/PimCompilerUtils.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" + +#include "onnx-mlir/Compiler/OMCompilerTypes.h" + +namespace onnx_mlir { + +void addPassesPim(mlir::OwningOpRef& module, + mlir::PassManager& pm, + EmissionTargetType& emissionTarget, + std::string outputNameNoExt); + +int compileModuleToPIMJSON(const mlir::OwningOpRef& moduleOpRef, + std::string& outputDirName); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/CMakeLists.txt b/src/PIM/Conversion/CMakeLists.txt new file mode 100644 index 0000000..27c58b8 --- /dev/null +++ b/src/PIM/Conversion/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(ONNXToSpatial) +add_subdirectory(SpatialToGraphviz) +add_subdirectory(SpatialToPIM) \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt new file mode 100644 index 0000000..c0c48ac --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -0,0 +1,34 @@ +set(LLVM_TARGET_DEFINITIONS 
ONNXToSpatial.td) +mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") +add_public_tablegen_target(ONNXToSpatialIncGen) + +add_onnx_mlir_library(OMONNXToSpatial + Math/Gemm.cpp + Math/Conv.cpp + Math/ExperimentalConv.cpp + Math/ExperimentalGemm.cpp + NN/Pooling.cpp + NN/ExperimentalPooling.cpp + NN/ReduceMean.cpp + Tensor/ONNXConcatToTensorConcat.cpp + Tensor/RemoveUnusedHelperOps.cpp + Utils/SpatialReducer.cpp + Utils/WeightSubdivider.cpp + Utils/AnnotateReplication.cpp + ONNXToSpatialPass.hpp + ONNXToSpatialPass.cpp + ONNXToSpatialCommon.cpp + + DEPENDS + ONNXToSpatialIncGen + + LINK_LIBS PUBLIC + OMCompilerOptions + OMPimCompilerOptions + OMONNXOps + SpatialOps + OMPIMCommon + + ACCEL_INCLUDE_DIRS PRIVATE + ${PIM_INCLUDE_PATH} +) diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp new file mode 100644 index 0000000..255ca55 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp @@ -0,0 +1,624 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" +#include + +#include +#include +#include + +using namespace mlir; +using namespace std; + +namespace onnx_mlir { + +// NOTE: +// This might 
be useful to re-implement this considering for loops. +// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount; + +/** + * @brief A momentary representation of a core, to be used within the tiling of + * a convolution operation. + */ +class Core { +public: + Core(const size_t coreId, ConversionPatternRewriter &rewriter) + : coreId(coreId), rewriter(rewriter) {} + + /** + * @brief Add a MVM operation to the core. + * + * @param inputTile The input tile to the MVM operation. + * @param xbarIndex The index of the crossbar weight to use. + * @param outputTileId The id of the output tile. + * @param mvmOutType The result's shape. + * @return Value The result of the MVM operation. + */ + Value addMVM( + Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) { + // Use the inputTile as the reference location for the MVM operation. + Location loc = inputTile.getLoc(); + + // Move the insertion point to the end of the block. + rewriter.setInsertionPointToEnd(block.get()); + + // Add the inputTile to the block arguments, and to the operands. + Value operand = operandMap.lookupOrNull(inputTile); + if (not operand) { + operand = block->addArgument(inputTile.getType(), loc); + operands.push_back(inputTile); + operandMap.map(inputTile, operand); + } + + // TODO: Compute the output type using the matrix, and check if `mvmOutType` + // is correct. + + // Construct the MVM operation + Value result = rewriter.create( + loc, mvmOutType, xbarIndex, operand); + + // Since we are within the same core and no computation can happen in + // paralllel, we can just apply a linear reduction in case we have multiple + // MVM operations for the same outputTile. + auto lastMVM = outputTileToMVM.find(outputTileId); + + // If an entry for this outputTile already exists, apply reduction. + if (lastMVM != outputTileToMVM.end()) { + // MVM results should have the same type for reduction. 
+ assert(lastMVM->second.getType() == result.getType()); + result = rewriter.create( + loc, mvmOutType, lastMVM->second, result); + } + + outputTileToMVM[outputTileId] = result; + return result; + } + + /** + * @brief Mark a result as remappable, and return a shared pointer to it. + * + * This function marks a result as remappable, and returns a shared pointer to + * it. We need to keep track of these values to generate the YieldOp at a + * later stage. + * + * @param result A result to track, for later remapping. + * @return shared_ptr A shared pointer to the result. + */ + shared_ptr makeResultRemappable(Value result) { + // Verify that the result is present in the block. + assert(result.getDefiningOp()->getBlock() == block.get()); + + shared_ptr remappableResult = make_shared(result); + + resultsToRemap.push_back(remappableResult); + results.push_back(result); + + return remappableResult; + } + + /** + * @brief Add a remappable operand to the core, to merge partial results + * inter-core. + * + * @param remappableOperand The operand to add. + * @return Value The block argument representing the operand. + */ + Value addRemappableOperand(std::shared_ptr operand) { + // Check that the operand is not already there. + assert(not operandMap.contains(*operand)); + + Value argument = block->addArgument(operand->getType(), operand->getLoc()); + remappableOperands.push_back(operand); + return argument; + } + + /** + * @brief Generate a spatial::SpatWeightedCompute operation from the core. + * + * @param loc The location of the operation. + * @return spatial::SpatWeightedCompute + */ + spatial::SpatWeightedCompute createWComputeOp(Location loc) { + // Get the shape of the results. + SmallVector resultTypes; + for (const auto &value : results) { + resultTypes.push_back(value.getType()); + } + + // Create the WComputeOp, with non-remappable operands only. + wcomputeOp = rewriter.create( + loc, resultTypes, xbarWeights, operands); + + // Add the body to the WComputeOp. 
+ Block *releasedBlock = block.release(); + wcomputeOp.getBody().push_back(releasedBlock); + + // Add the `yieldOp` at the end, with the results. + rewriter.setInsertionPointToEnd(releasedBlock); + rewriter.create(loc, results); + + return wcomputeOp; + } + + /** + * @brief Remap the results to the WComputeOp results. + */ + void remapResults() { + // Remap all the results to the WComputeOp results. + assert(resultsToRemap.size() == wcomputeOp->getNumResults()); + for (size_t i = 0; i < resultsToRemap.size(); i++) { + *resultsToRemap[i] = wcomputeOp.getResult(i); + } + } + + void addRemappedOperands() { + // Insert the remappableOperands (which were remapped in + // `addRemappableOperand` of another Core) + for (auto remappedValue : remappableOperands) { + wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue); + } + + // Update the wcomputeOp operandSegmentSize + incrementWeightedComputeInputsSegmentSize( + wcomputeOp, static_cast(remappableOperands.size())); + } + + size_t addXbarWeight(Value weight) { + assert(!isXbarsFull()); + xbarWeights.push_back(weight); + return xbarWeights.size() - 1; + } + + bool isXbarsFull() { + assert(xbarWeights.size() <= crossbarCountInCore); + return xbarWeights.size() == crossbarCountInCore; + } + + bool isCoreEmpty() { return block->empty(); } + + void dump() { + // Print the coreId + llvm::outs() << "Core " << coreId << ":\n"; + // Print the weights + llvm::outs() << "Xbar Weights:\n"; + for (auto weight : xbarWeights) { + weight.dump(); + } + // Print the operands + llvm::outs() << "Operands:\n"; + for (auto operand : operands) { + llvm::outs() << operand << "\n"; + } + + // Dump the body block + for (auto &op : block->getOperations()) { + op.dump(); + } + + // Print the results + llvm::outs() << "Results:\n"; + for (auto result : results) { + llvm::outs() << result << "\n"; + } + } + + const size_t coreId; + +private: + ConversionPatternRewriter &rewriter; + + // Should these be set instead? 
But I need to keep the order + vector operands; + vector> remappableOperands; + + vector results; + vector> resultsToRemap; + + // Maps from input tiles to the block operand + IRMapping operandMap; + + // Map from outputTileId to MVM operation producing it + unordered_map outputTileToMVM; + + vector xbarWeights; + + unique_ptr block = make_unique(); + + spatial::SpatWeightedCompute wcomputeOp; +}; + +struct ONNXConvOpTile : public OpConversionPattern { + ONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {} + + struct Producer_t { + Value value; + shared_ptr core; + }; + + LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, + ConversionPatternRewriter &rewriter) const final { + ShapedType xShape = mlir::cast(convAdaptor.getX().getType()); + ShapedType wShape = mlir::cast(convAdaptor.getW().getType()); + ShapedType bShape = mlir::cast(convAdaptor.getB().getType()); + ShapedType yShape = mlir::cast(conv.getY().getType()); + + size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y; + unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y); + unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y); + + auto padUnpackError = + unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); + if (padUnpackError.has_value()) { + return rewriter.notifyMatchFailure(conv, padUnpackError.value()); + } + + // TODO: Pad value at beginning and end of each dimension could be + // different. We should handle this case. 
+ + // MapOperations mapOperation = MapOperations::None; + // + // // If we have just one user, and it is an activation funcion (or more in + // // general a mapping operation) just inline it in the computeOps + // auto firstUserOp = *conv->getUsers().begin(); + // if (conv->hasOneUse()) { + // mapOperation = mlirOpToMapOperationEnum(firstUserOp); + // + // if (mapOperation == MapOperations::ONNXSoftmaxOp) { + // return rewriter.notifyMatchFailure( + // conv, "Softmax not supported as activation for convolutions."); + // } + // } + + size_t input_h = GET_IMAGE_HEIGHT(xShape); + size_t input_w = GET_IMAGE_WIDTH(xShape); + size_t output_h = GET_IMAGE_HEIGHT(yShape); + size_t output_w = GET_IMAGE_WIDTH(yShape); + size_t krn_h = GET_KERNEL_HEIGHT(wShape); + size_t krn_w = GET_KERNEL_WIDTH(wShape); + + Location loc = conv.getLoc(); + + size_t inputTileCount = + ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); + size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize; + size_t outputTileCount = + ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue()); + size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize; + + // Tile the input tensor + // Input tiles need to be indexed by: + // a. Channel Tile + // b. Pixel `x` position + // c. 
Pixel `y` position + // For example: inputTiles[channelTile][x][y] + // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH) + SmallVector>> inputTiles(inputTileCount, + SmallVector>(input_w, SmallVector(input_h))); + + auto resolveErrorOpt = resolveImgInputTiles(convAdaptor.getX(), inputTiles, + inputTileCount, inputTileRemainder, input_h, input_h, rewriter); + if (resolveErrorOpt.has_value()) { + return rewriter.notifyMatchFailure(conv, *resolveErrorOpt); + } + + SmallVector strides = + SmallVector(4, rewriter.getIndexAttr(1)); + SmallVector offsets = + SmallVector(4, rewriter.getIndexAttr(0)); + SmallVector sizes = SmallVector{ + rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + + // Tile the weight tensor + // Weight tiles need to be indexed by: + // a. Filter Tile + // b. Channel Tile + // c. Kernel `x` position + // d. Kernel `y` position + // For example: weightTiles[filterTile][channelTile][x][y] + // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH) + SmallVector>>> weightTiles( + outputTileCount, + SmallVector>>(inputTileCount, + SmallVector>(krn_w, SmallVector(krn_h)))); + strides = SmallVector(4, rewriter.getIndexAttr(1)); + offsets = SmallVector(4, rewriter.getIndexAttr(0)); + sizes = {rewriter.getIndexAttr(crossbarSize), + rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + for (size_t i = 0; i < outputTileCount; i++) { + if (i == outputTileCount - 1 && outputTileRemainder != 0) { + sizes[0] = rewriter.getIndexAttr(outputTileRemainder); + } + sizes[1] = rewriter.getIndexAttr(crossbarSize); + offsets[0] = rewriter.getIndexAttr(i * crossbarSize); + for (size_t j = 0; j < inputTileCount; j++) { + if (j == inputTileCount - 1 && inputTileRemainder != 0) { + sizes[1] = rewriter.getIndexAttr(inputTileRemainder); + } + for (size_t x = 0; x < krn_w; x++) { + for (size_t y = 0; y < krn_h; y++) { + offsets[1] = 
rewriter.getIndexAttr(j * crossbarSize); + offsets[2] = rewriter.getIndexAttr(x); + offsets[3] = rewriter.getIndexAttr(y); + weightTiles[i][j][x][y] = rewriter.create( + loc, convAdaptor.getW(), offsets, sizes, strides); + } + } + } + } + + /* Distribute the computation among many compute cores + * Try to compute in-core the computation for each output tile, and reduce + * over as few cores as possible + */ + + // Tile the output tensor + // Output tiles need to be indexed by: + // a. Filter Tile + // b. Pixel `x` position + // c. Pixel `y` position + // For example: outputTiles[filterTile][x][y] + // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH) + SmallVector>>> outputTiles( + outputTileCount, + SmallVector>>( + output_w, SmallVector>(output_h, nullptr))); + + size_t replicationFactor; + if (!conv->hasAttr(REPLICATION_ATTR_NAME)) { + replicationFactor = 1; + } else { + replicationFactor = + conv->getAttrOfType(REPLICATION_ATTR_NAME).getInt(); + } + // producers[outTile][out_x][out_y][producerIndex] + vector>>> producers = + vector>>>(outputTileCount, + vector>>(output_w, + vector>(output_h, vector()))); + + // Schedule in cores + size_t coreId = 0; + vector> curCores(replicationFactor); + for (size_t i = 0; i < replicationFactor; i++) { + curCores[i] = make_shared(coreId++, rewriter); + } + + vector> cores; + + const size_t replicationSliceSize = + ceilIntegerDivide(input_w, replicationFactor); + + for (size_t krn_x = 0; krn_x < krn_h; krn_x++) { + for (size_t krn_y = 0; krn_y < krn_w; krn_y++) { + + RankedTensorType mvmOutType = + RankedTensorType::get({1, static_cast(crossbarSize), 1, 1}, + bShape.getElementType()); + + for (size_t outTile = 0; outTile < outputTileCount; outTile++) { + + if (outTile == outputTileCount - 1 && outputTileRemainder != 0) { + mvmOutType = mvmOutType.clone( + {1, static_cast(outputTileRemainder), 1, 1}); + } + + for (size_t inTile = 0; inTile < inputTileCount; inTile++) { + + vector xbarIndexes(replicationFactor); + 
for (size_t i = 0; i < replicationFactor; i++) { + xbarIndexes[i] = curCores[i]->addXbarWeight( + weightTiles[outTile][inTile][krn_x][krn_y]); + } + + size_t out_x = 0; + for (size_t in_x = 0; in_x < input_w; in_x += stride_x) { + size_t out_y = 0; + + // I use `replicationFactor` cores. I divide the input_w into + // `replicationFactor` slices, and each slice is distributed to a + // core. `coreIndex` is the index of the core that will be used + // for this slice + size_t coreIndex = in_x / replicationSliceSize; + assert(coreIndex < replicationFactor); + + for (size_t in_y = 0; in_y < input_h; in_y += stride_y) { + // Adjust the input based on the kernel + int actual_in_x = in_x - ((int)krn_w / 2) + krn_x * dilation_x; + int actual_in_y = in_y - ((int)krn_h / 2) + krn_y * dilation_y; + + // Check if we are within the input image + if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, + actual_in_y, pad_x, pad_y) + .failed()) { + out_y++; + continue; + } + + size_t outTileId = + outTile * output_w * output_h + out_x * output_h + out_y; + auto mvm = curCores[coreIndex]->addMVM( + inputTiles[inTile][actual_in_x][actual_in_y], + xbarIndexes[coreIndex], outTileId, mvmOutType); + + producers[outTile][out_x][out_y].push_back( + {mvm, curCores[coreIndex]}); + + out_y++; + } + out_x++; + } + + // Computations for these crossbars are done, check if the cores + // crossbars are fully used. 
If full, swap with new core + for (size_t i = 0; i < replicationFactor; i++) { + if (curCores[i]->isXbarsFull()) { + cores.emplace_back(std::move(curCores[i])); + curCores[i] = make_shared(coreId++, rewriter); + } + } + } + } + } + } + + for (auto &curCore : curCores) { + if (curCore->isCoreEmpty() == false) { + cores.emplace_back(std::move(curCore)); + } + } + curCores.clear(); + // Now, do the reduction of each output pixel tile + for (size_t outTile = 0; outTile < outputTileCount; outTile++) { + for (size_t out_x = 0; out_x < output_w; out_x++) { + for (size_t out_y = 0; out_y < output_h; out_y++) { + // First, check if some producers are within the same core. If this is + // true, `Core::addMVM` have already done the reduction within-core. + // This means that we only need to consider the last producer for that + // core. + + std::unordered_map withinCoreReducedProducers; + for (auto producer : producers[outTile][out_x][out_y]) { + withinCoreReducedProducers[producer.core->coreId] = producer; + } + + // Now, we need to apply inter-core reduction + + // Base case with one producer + if (withinCoreReducedProducers.size() == 1) { + // TODO: Add the bias and apply mapping (if present) + + auto singleProducer = withinCoreReducedProducers.begin()->second; + // Use last producer as the final result + auto reducedValue = + singleProducer.core->makeResultRemappable(singleProducer.value); + outputTiles[outTile][out_x][out_y] = reducedValue; + continue; + } + + // TODO: This is a linear reduction, not a tree reduction. We can do + // better: a tree reduction would make more computations happen in + // parallel. 
+ + Producer_t lastProducer = withinCoreReducedProducers.begin()->second; + + auto it = withinCoreReducedProducers.begin(); + it++; + while (it != withinCoreReducedProducers.end()) { + + Producer_t curProducer = it->second; + + shared_ptr core1; + shared_ptr core2; + Value core1Value; + Value core2Value; + + auto lastProducerCoreId = lastProducer.core->coreId; + auto curProducerCoreId = curProducer.core->coreId; + + assert(lastProducerCoreId != curProducerCoreId && + "We should have already applied within-core reduction, how " + "could we have same cores here?"); + + // Sort the cores by coreId + if (curProducerCoreId < lastProducerCoreId) { + core1 = curProducer.core; + core1Value = curProducer.value; + core2 = lastProducer.core; + core2Value = lastProducer.value; + } else { + core1 = lastProducer.core; + core1Value = lastProducer.value; + core2 = curProducer.core; + core2Value = curProducer.value; + } + + auto newCoreRes = core1->makeResultRemappable(core1Value); + auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes); + + rewriter.setInsertionPointAfterValue(core2Value); + Value vaddRes = + rewriter.create(core2Value.getLoc(), + core2Value.getType(), core2Value, secondCoreBlockArg); + + lastProducer = {vaddRes, core2}; + + it++; + } + + // TODO: Add the bias and apply mapping (if present) + + // Use last producer as the final result + auto reducedValue = + lastProducer.core->makeResultRemappable(lastProducer.value); + outputTiles[outTile][out_x][out_y] = reducedValue; + } + } + } + + // Now, we need to turn the cores into a spatial::SpatWeightedCompute. + rewriter.setInsertionPointAfter(conv); + spatial::SpatWeightedCompute lastWComputeOp; + for (auto &core : cores) { + lastWComputeOp = core->createWComputeOp(loc); + core->remapResults(); + rewriter.setInsertionPointAfter(lastWComputeOp); + } + + for (auto &core : cores) { + core->addRemappedOperands(); + } + + // Set the insertion point after the last WComputeOp. 
+ rewriter.setInsertionPointAfter(lastWComputeOp); + SmallVector tilesToConcat; + tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize); + for (size_t outX = 0; outX < output_h; outX++) + for (size_t outY = 0; outY < output_w; outY++) + for (size_t outTile = 0; outTile < outputTileCount; outTile++) + tilesToConcat.push_back(*outputTiles[outTile][outX][outY]); + + Value outputImage = rewriter.create( + loc, conv.getY().getType(), tilesToConcat); + + // Value outputImage = + // createImgConcatOp(outputTiles, rewriter, loc, Y.getType()); + + // If no mapping (activation) was applied, just replace ConvOp + // if (mapOperation == MapOperations::None) { + // rewriter.replaceOp(conv, outputImage); + // } else { + // // If mapping was applied, erase ConvOp and replace the mapping op + // rewriter.eraseOp(conv); + // rewriter.replaceOp(firstUserOp, outputImage); + // } + + return success(); + } +}; + +void populateTilingConvOpPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp new file mode 100644 index 0000000..df70637 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp @@ -0,0 +1,430 @@ +#include "Compiler/PimCompilerOptions.hpp" +#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "Dialect/Spatial/SpatialOps.hpp" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "llvm/ADT/SmallVector.h" +#include +#include +#include + +using namespace mlir; +using namespace std; + 
+namespace onnx_mlir { + +/** + * @brief A pattern to tile the convolution operation into a series of compute + * units, each one of which applies filters to a subset of the input + * tensor. Results are also reduced and concatenated to form the final + * output tensor. + */ +struct ExperimentalONNXConvOpTile : public OpConversionPattern { + ExperimentalONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {} + + LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, + ConversionPatternRewriter &rewriter) const final { + + // --------------------------------- // + // --- READ OPERATION PARAMETERS --- // + // --------------------------------- // + + // To get each crossbar's weights, we need to slice the weights tensor. + // - Along the input tiles. + // - Along the output tiles. + // - Along the filter x position. + // - Along the filter y position. + ShapedType inputType = cast(convAdaptor.getX().getType()); + ShapedType outputType = cast(conv.getY().getType()); + ShapedType weightsType = cast(convAdaptor.getW().getType()); + + // TODO: Address bigger batches. + assert(GET_IMAGE_N(inputType) == 1 && "Batch size must be 1" + "for convolution."); + + // TODO: Address replication. + assert(coresCount.getValue() == -1 && + "Replication is not yet supported for convolution."); + + // TODO: Address bias addition. + + ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize); + ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize); + size_t kernelWidth = GET_KERNEL_WIDTH(weightsType); + size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType); + + // Assert that the kernel is square. + assert(kernelWidth == kernelHeight && "Only square kernels are supported."); + + // -------------------------------- // + // --- SLICE THE WEIGHTS TENSOR --- // + // -------------------------------- // + + // The core idea of this stage is classifying the weights by input and + // output tile. 
This is because we want the applyFilters operations to be + // tile agnostic, to keep the subsequent lowering stages as simple as + // possible. This data structure does this weight classification: + // - The outer map is indexed by input tile. + // - The inner map is indexed by output tile. + // - The SmallVector contains the weights for the filter. + map>> weightsGroups; + + // During all slicing operations within this stage, we'll use the same + // strides for all dimensions. + SmallVector slicingStrides(4, rewriter.getIndexAttr(1)); + + ldiv_t itc = inputTileCount; + ldiv_t otc = outputTileCount; + + // - Slicing along the input tiles. + // - Slicing along the output tiles. + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize; + for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) { + long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize; + + // The loop above also sets the crossbar's used width and height, + // checking if we're at the last crossbar and if it's incomplete. + + long outputTile = ot; + long inputTile = it; + + // Create the slicing sizes. + SmallVector slicingSizes{ + /* 0 */ rewriter.getIndexAttr(crossbarHeight), + /* 1 */ rewriter.getIndexAttr(crossbarWidth), + /* 2 */ rewriter.getIndexAttr(1), + /* 3 */ rewriter.getIndexAttr(1)}; + + // - Slicing along the filter x position. + // - Slicing along the filter y position. + for (size_t filterX = 0; filterX < kernelWidth; ++filterX) { + for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { + + // Create the slicing offsets. + SmallVector slicingOffsets{ + /* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), + /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), + /* 2 */ rewriter.getIndexAttr(filterX), + /* 3 */ rewriter.getIndexAttr(filterY)}; + + // Create the slice extraction operation. 
+ auto extractSliceOp = rewriter.create( + conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, + slicingStrides); + + // Add a note to the extractSliceOp, with the filterX and filterY. + weightsGroups[inputTile][outputTile].push_back(extractSliceOp); + } + } + } + } + + // TODO: Tree reduction for compute reduction should be implemented. + + // -------------------------------- // + // --- CREATE ALL COMPUTE UNITS --- // + // -------------------------------- // + + // Keep track of input slicing operations to avoid duplication across + // all compute units (global slices). + map globalSlices; + + // Keep track of all partial compute results. + map globalPartialResults; + + // Use a weight subdivider to extract groups of weights for each compute + // unit. We'll keep extracting groups until no more weights are left. + WeightSubdivider weightSubdivider(weightsGroups); + while (!weightSubdivider.isEmpty()) { + + // -------------------------------- // + // --- BEGIN A NEW COMPUTE UNIT --- // + // -------------------------------- // + + // Get the next group of weights for the compute unit. + SmallVector weightsGroups = + weightSubdivider.popGroups(crossbarCountInCore.getValue()); + + SmallVector computeWeights; + SmallVector computeOperands; + + // ------------------------------ // + // --- SLICE THE INPUT TENSOR --- // + // ------------------------------ // + + // Note each tile's index in the compute unit arguments. + map inputTileIndices; + map outputTileIndices; + map reductionTileIndices; // Incoming partial results. + + // Iterate over all weights groups for this compute unit. + map localSlices; // WRT the current compute unit. + for (auto group : weightsGroups) { + for (Value weight : group.weights) { + computeWeights.push_back(weight); + } + + // There might be multiple weight groups for the same input tile, so if + // we've already added the input tile, skip it. 
+ if (localSlices.find(group.inputTile) != localSlices.end()) { + continue; + } + + // We might have already sliced the input tensor for some other compute + // unit, so if we have, reuse the slicing operation without creating a + // new one. + if (globalSlices.find(group.inputTile) != globalSlices.end()) { + computeOperands.push_back(globalSlices[group.inputTile]); + localSlices[group.inputTile] = globalSlices[group.inputTile]; + continue; + } + + // Create the input tensor slicing offsets. + SmallVector slicingOffsets{ + /* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. + /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), + /* 2 */ rewriter.getIndexAttr(0), + /* 3 */ rewriter.getIndexAttr(0)}; + + // Create the input tensor slicing sizes. + size_t tilingSize = group.inputTile == inputTileCount.quot + ? inputTileCount.rem + : crossbarSize; + SmallVector slicingSizes{ + /* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. + /* 1 */ rewriter.getIndexAttr(tilingSize), + /* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), + /* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))}; + + // Create the slice extraction operation. + auto extractSliceOp = rewriter.create( + conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, + slicingStrides); + + computeOperands.push_back(extractSliceOp); + + // Update slicing maps. + globalSlices[group.inputTile] = extractSliceOp; + localSlices[group.inputTile] = extractSliceOp; + + // Update the input tile index. + inputTileIndices[group.inputTile] = computeOperands.size() - 1; + } + + // ------------------------------- // + // --- PREPARE THE OUTPUT TYPE --- // + // ------------------------------- // + + // Fill the compute output's type by looking at the output tiles. + SmallVector computeOutputType; + for (TaggedWeights group : weightsGroups) { + + // There might be multiple weight groups for the same output tile, so if + // we've already added the output tile, skip it. 
+ if (outputTileIndices.find(group.outputTile) != + outputTileIndices.end()) { + continue; + } + + // Additionally, after adding the input slices as operands, also add any + // compatible partial results from previous compute units. + if (globalPartialResults.find(group.outputTile) != + globalPartialResults.end()) { + computeOperands.push_back(globalPartialResults[group.outputTile]); + reductionTileIndices[group.outputTile] = computeOperands.size() - 1; + } + + // Define the output shape for this group. + long outputTileSize = group.outputTile == outputTileCount.quot + ? outputTileCount.rem + : crossbarSize; + + // TODO: Address non-same padding. + SmallVector outputShapeArray{ + /* 0 */ 1, // Batch size is always 1. + /* 1 */ outputTileSize, + /* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed. + /* 3 */ GET_IMAGE_HEIGHT(outputType)}; + + auto elementType = + dyn_cast(conv.getY().getType()).getElementType(); + + computeOutputType.push_back( + RankedTensorType::get(outputShapeArray, elementType)); + + outputTileIndices[group.outputTile] = computeOutputType.size() - 1; + } + + // ----------------------------- // + // --- FILL THE COMPUTE UNIT --- // + // ----------------------------- // + + // Create the compute unit. + spatial::SpatWeightedCompute currentCompute = + rewriter.create(conv.getLoc(), + computeOutputType, computeWeights, computeOperands); + + // Create a new block for the compute unit and add the operands. + Block *block = rewriter.createBlock(¤tCompute.getRegion()); + rewriter.setInsertionPointToStart(block); + for (Value operand : computeOperands) { + block->addArgument(operand.getType(), conv->getLoc()); + } + + // Initialize a map of local partial results. + map localPartialResults; // WRT the current compute unit. + + // If we have any reduction tiles, add them to the local partial results. 
+ for (auto reductionTileIndex : reductionTileIndices) { + localPartialResults[reductionTileIndex.first] = + block->getArgument(reductionTileIndex.second); + } + + // Add all the applyFilters operations to the block. + for (TaggedWeights group : weightsGroups) { + + // Get the outputType for this group. + Type outputType = + computeOutputType[outputTileIndices[group.outputTile]]; + + // Create an apply filters operation. + BlockArgument blockArgument = + block->getArgument(inputTileIndices[group.inputTile]); + + // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, + // ... As many weights as the size of group.weights. + SmallVector weightIndices; + for (size_t i = 0; i < group.weights.size(); ++i) { + weightIndices.push_back(group.startingCrossbarIndex + i); + } + + SmallVector xKerPos; + SmallVector yKerPos; + for (auto weight : group.weights) { + // Assert that the weight is an extract_slice operation. + auto extractSliceOp = weight.getDefiningOp(); + assert(extractSliceOp && "Weight is not an extract_slice operation."); + + // Get the filter x and y positions from the extract_slice operation. + auto offsets = extractSliceOp.getStaticOffsets(); + xKerPos.push_back(offsets[2]); + yKerPos.push_back(offsets[3]); + } + + ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices); + ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); + ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); + + Value result = + rewriter.create(conv.getLoc(), outputType, + weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument); + + // Perform local reduction if necessary. + if (localPartialResults.find(group.outputTile) != + localPartialResults.end()) { + + result = rewriter.create(conv.getLoc(), + result.getType(), localPartialResults[group.outputTile], result); + } + + // Update the partial results map. + localPartialResults[group.outputTile] = result; + } + + // Add a yield operation to the block by concatenating the partial + // results. 
+ SmallVector applyFiltersResults; + for (size_t i = 0; i < computeOutputType.size(); ++i) { + long outputTile; + + // Given an output tile index, find the corresponding output tile. + for (auto outputTileIndex : outputTileIndices) { + if (outputTileIndex.second == i) { + outputTile = outputTileIndex.first; + break; + } + } + + // Get that tile's partial result and add it to the list. + applyFiltersResults.push_back(localPartialResults[outputTile]); + } + + // Create the yield operation with the given results. + rewriter.create(conv.getLoc(), applyFiltersResults); + + // Update the global partial results map. + for (size_t i = 0; i < applyFiltersResults.size(); ++i) { + long outputTile; + + // Given an output tile index, find the corresponding output tile. + for (auto outputTileIndex : outputTileIndices) { + if (outputTileIndex.second == i) { + outputTile = outputTileIndex.first; + break; + } + } + + globalPartialResults[outputTile] = currentCompute.getResult(i); + } + + // Move the rewrite cursor out of the block. + rewriter.setInsertionPointAfter(currentCompute); + } + + // ------------------------------ // + // --- CONCATENATE THE OUTPUT --- // + // ------------------------------ // + + // Turn the values into a SmallVector. + SmallVector outputValues; + for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); + ++i) { + outputValues.push_back(globalPartialResults[i]); + } + + // Assert that the number of output values is correct. + assert(outputValues.size() > 0 && + "No output values were generated for the convolution."); + + // If the conv's user is a ReLU... + if (conv->hasOneUse()) { + Operation *user = *conv->getUsers().begin(); + if (auto relu = dyn_cast(user)) { + // ...then we can just replace the ReLU with the concatenation. + rewriter.replaceOp(relu, + rewriter.create(conv.getLoc(), 1, outputValues)); + + // And erase the convolution. + rewriter.eraseOp(conv); + return success(); + } + } + + // Return the final output. 
+ rewriter.replaceOp(conv, + rewriter.create(conv.getLoc(), 1, outputValues)); + + return success(); + } +}; + +/** + * @brief Populate the tiling pattern for a convolution operation. + * + * @param patterns The pattern set to populate. + * @param ctx The MLIR context. + */ +void populateExperimentalTilingConvOpPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp new file mode 100644 index 0000000..9292875 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp @@ -0,0 +1,400 @@ +#include "Compiler/PimCompilerOptions.hpp" +#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include + +using namespace mlir; +using namespace std; + +namespace onnx_mlir { + +struct ExperimentalGemmConversionPattern + : public OpConversionPattern { + ExperimentalGemmConversionPattern(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + // --------------------------------- // + // --- READ OPERATION PARAMETERS --- // + // --------------------------------- // + + // To get each crossbar's weights, we need to slice the weights tensor. + // - Along the input tiles. + // - Along the output tiles. + // - Along the filter x position. + // - Along the filter y position. 
+ ShapedType inputType = cast(adaptor.getA().getType()); + ShapedType outputType = cast(gemmOp.getY().getType()); + ShapedType matrixType = cast(adaptor.getB().getType()); + + // TODO: Address bigger batches. + assert(inputType.getShape()[0] == 1 && + "Only batch size of 1 is supported for GEMM."); + + // TODO: Address replication. + assert(coresCount.getValue() == -1 && + "Replication is not yet supported for GEMM."); + + // TODO: Address bias addition. + + assert(inputType.getShape()[1] == matrixType.getShape()[0] && + "Input tile size must match the matrix's row size."); + + ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize); + ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize); + size_t kernelWidth = 1; + size_t kernelHeight = 1; + + // Assert that the kernel is square. + assert(kernelWidth == kernelHeight && "Only square kernels are supported."); + + // -------------------------------- // + // --- SLICE THE WEIGHTS TENSOR --- // + // -------------------------------- // + + // The core idea of this stage is classifying the weights by input and + // output tile. This is because we want the applyFilters operations to be + // tile agnostic, to keep the subsequent lowering stages as simple as + // possible. This data structure does this weight classification: + // - The outer map is indexed by input tile. + // - The inner map is indexed by output tile. + // - The SmallVector contains the weights for the filter. + map>> weightsGroups; + + // During all slicing operations within this stage, we'll use the same + // strides for all dimensions. + SmallVector slicingStrides(2, rewriter.getIndexAttr(1)); + + ldiv_t itc = inputTileCount; + ldiv_t otc = outputTileCount; + + // - Slicing along the input tiles. + // - Slicing along the output tiles. + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + long crossbarWidth = it == itc.quot ? 
itc.rem : crossbarSize; + for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) { + long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize; + + // The loop above also sets the crossbar's used width and height, + // checking if we're at the last crossbar and if it's incomplete. + + long outputTile = ot; + long inputTile = it; + + // Create the slicing sizes. + SmallVector slicingSizes{ + /* 0 */ rewriter.getIndexAttr(crossbarHeight), + /* 1 */ rewriter.getIndexAttr(crossbarWidth), + /* 2 */ /* rewriter.getIndexAttr(1), */ + /* 3 */ /* rewriter.getIndexAttr(1) */}; + + // - Slicing along the filter x position. + // - Slicing along the filter y position. + for (size_t filterX = 0; filterX < kernelWidth; ++filterX) { + for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { + + // Create the slicing offsets. + SmallVector slicingOffsets{ + /* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), + /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), + /* 2 */ /* rewriter.getIndexAttr(filterX), */ + /* 3 */ /* rewriter.getIndexAttr(filterY) */}; + + // Create the slice extraction operation. + auto extractSliceOp = rewriter.create( + gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, + slicingStrides); + + // Add a note to the extractSliceOp, with the filterX and filterY. + weightsGroups[inputTile][outputTile].push_back(extractSliceOp); + } + } + } + } + + // TODO: Tree reduction for compute reduction should be implemented. + + // -------------------------------- // + // --- CREATE ALL COMPUTE UNITS --- // + // -------------------------------- // + + // Keep track of input slicing operations to avoid duplication across + // all compute units (global slices). + map globalSlices; + + // Keep track of all partial compute results. + map globalPartialResults; + + // Use a weight subdivider to extract groups of weights for each compute + // unit. We'll keep extracting groups until no more weights are left. 
+ WeightSubdivider weightSubdivider(weightsGroups); + while (!weightSubdivider.isEmpty()) { + + // -------------------------------- // + // --- BEGIN A NEW COMPUTE UNIT --- // + // -------------------------------- // + + // Get the next group of weights for the compute unit. + SmallVector weightsGroups = + weightSubdivider.popGroups(crossbarCountInCore.getValue()); + + SmallVector computeWeights; + SmallVector computeOperands; + + // ------------------------------ // + // --- SLICE THE INPUT TENSOR --- // + // ------------------------------ // + + // Note each tile's index in the compute unit arguments. + map inputTileIndices; + map outputTileIndices; + map reductionTileIndices; // Incoming partial results. + + // Iterate over all weights groups for this compute unit. + map localSlices; // WRT the current compute unit. + for (auto group : weightsGroups) { + for (Value weight : group.weights) { + computeWeights.push_back(weight); + } + + // There might be multiple weight groups for the same input tile, so if + // we've already added the input tile, skip it. + if (localSlices.find(group.inputTile) != localSlices.end()) { + continue; + } + + // We might have already sliced the input tensor for some other compute + // unit, so if we have, reuse the slicing operation without creating a + // new one. + if (globalSlices.find(group.inputTile) != globalSlices.end()) { + computeOperands.push_back(globalSlices[group.inputTile]); + localSlices[group.inputTile] = globalSlices[group.inputTile]; + continue; + } + + // Create the input tensor slicing offsets. + SmallVector slicingOffsets{ + /* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. + /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), + /* 2 */ /* rewriter.getIndexAttr(0), */ + /* 3 */ /* rewriter.getIndexAttr(0) */}; + + // Create the input tensor slicing sizes. + size_t tilingSize = group.inputTile == inputTileCount.quot + ? 
inputTileCount.rem + : crossbarSize; + SmallVector slicingSizes{ + /* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. + /* 1 */ rewriter.getIndexAttr(tilingSize), + /* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */ + /* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */}; + + // Create the slice extraction operation. + auto extractSliceOp = + rewriter.create(gemmOp.getLoc(), + adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides); + + computeOperands.push_back(extractSliceOp); + + // Update slicing maps. + globalSlices[group.inputTile] = extractSliceOp; + localSlices[group.inputTile] = extractSliceOp; + + // Update the input tile index. + inputTileIndices[group.inputTile] = computeOperands.size() - 1; + } + + // ------------------------------- // + // --- PREPARE THE OUTPUT TYPE --- // + // ------------------------------- // + + // Fill the compute output's type by looking at the output tiles. + SmallVector computeOutputType; + for (TaggedWeights group : weightsGroups) { + + // There might be multiple weight groups for the same output tile, so if + // we've already added the output tile, skip it. + if (outputTileIndices.find(group.outputTile) != + outputTileIndices.end()) { + continue; + } + + // Additionally, after adding the input slices as operands, also add any + // compatible partial results from previous compute units. + if (globalPartialResults.find(group.outputTile) != + globalPartialResults.end()) { + computeOperands.push_back(globalPartialResults[group.outputTile]); + reductionTileIndices[group.outputTile] = computeOperands.size() - 1; + } + + // Define the output shape for this group. + long outputTileSize = group.outputTile == outputTileCount.quot + ? outputTileCount.rem + : crossbarSize; + + // TODO: Address non-same padding. + SmallVector outputShapeArray{ + /* 0 */ 1, // Batch size is always 1. + /* 1 */ outputTileSize, + /* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed. 
+ /* 3 */ /* GET_IMAGE_HEIGHT(outputType) */}; + + auto elementType = dyn_cast(gemmOp.getY().getType()) + .getElementType(); + + computeOutputType.push_back( + RankedTensorType::get(outputShapeArray, elementType)); + + outputTileIndices[group.outputTile] = computeOutputType.size() - 1; + } + + // ----------------------------- // + // --- FILL THE COMPUTE UNIT --- // + // ----------------------------- // + + // Create the compute unit. + spatial::SpatWeightedCompute currentCompute = + rewriter.create(gemmOp.getLoc(), + computeOutputType, computeWeights, computeOperands); + + // Create a new block for the compute unit and add the operands. + Block *block = rewriter.createBlock(¤tCompute.getRegion()); + rewriter.setInsertionPointToStart(block); + for (Value operand : computeOperands) { + block->addArgument(operand.getType(), gemmOp->getLoc()); + } + + // Initialize a map of local partial results. + map localPartialResults; // WRT the current compute unit. + + // If we have any reduction tiles, add them to the local partial results. + for (auto reductionTileIndex : reductionTileIndices) { + localPartialResults[reductionTileIndex.first] = + block->getArgument(reductionTileIndex.second); + } + + // Add all the applyFilters operations to the block. + for (TaggedWeights group : weightsGroups) { + + // Get the outputType for this group. + Type outputType = + computeOutputType[outputTileIndices[group.outputTile]]; + + // Create an apply filters operation. + BlockArgument blockArgument = + block->getArgument(inputTileIndices[group.inputTile]); + + // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, + // ... As many weights as the size of group.weights. + SmallVector weightIndices; + for (size_t i = 0; i < group.weights.size(); ++i) { + weightIndices.push_back(group.startingCrossbarIndex + i); + } + + SmallVector xKerPos; + SmallVector yKerPos; + for (auto weight : group.weights) { + // Assert that the weight is an extract_slice operation. 
+ auto extractSliceOp = weight.getDefiningOp(); + assert(extractSliceOp && "Weight is not an extract_slice operation."); + + // Get the filter x and y positions from the extract_slice operation. + xKerPos.push_back(0); + yKerPos.push_back(0); + } + + ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices); + ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); + ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); + + Value result = rewriter.create(gemmOp.getLoc(), + outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, + blockArgument); + + // Perform local reduction if necessary. + if (localPartialResults.find(group.outputTile) != + localPartialResults.end()) { + + result = rewriter.create(gemmOp.getLoc(), + result.getType(), localPartialResults[group.outputTile], result); + } + + // Update the partial results map. + localPartialResults[group.outputTile] = result; + } + + // Add a yield operation to the block by concatenating the partial + // results. + SmallVector applyFiltersResults; + for (size_t i = 0; i < computeOutputType.size(); ++i) { + long outputTile; + + // Given an output tile index, find the corresponding output tile. + for (auto outputTileIndex : outputTileIndices) { + if (outputTileIndex.second == i) { + outputTile = outputTileIndex.first; + break; + } + } + + // Get that tile's partial result and add it to the list. + applyFiltersResults.push_back(localPartialResults[outputTile]); + } + + // Create the yield operation with the given results. + rewriter.create(gemmOp.getLoc(), applyFiltersResults); + + // Update the global partial results map. + for (size_t i = 0; i < applyFiltersResults.size(); ++i) { + long outputTile; + + // Given an output tile index, find the corresponding output tile. 
+ for (auto outputTileIndex : outputTileIndices) { + if (outputTileIndex.second == i) { + outputTile = outputTileIndex.first; + break; + } + } + + globalPartialResults[outputTile] = currentCompute.getResult(i); + } + + // Move the rewrite cursor out of the block. + rewriter.setInsertionPointAfter(currentCompute); + } + + // ------------------------------ // + // --- CONCATENATE THE OUTPUT --- // + // ------------------------------ // + + // Turn the values into a SmallVector. + SmallVector outputValues; + for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); + ++i) { + outputValues.push_back(globalPartialResults[i]); + } + + // Assert that the number of output values is correct. + assert(outputValues.size() > 0 && + "No output values were generated for the GEMM operation."); + + // Return the final output. + rewriter.replaceOp(gemmOp, + rewriter.create(gemmOp.getLoc(), 1, outputValues)); + + return success(); + } +}; + +void populateGemmToConvConversionPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp new file mode 100644 index 0000000..c387ed8 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp @@ -0,0 +1,317 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include + +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include 
"src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; + +struct ONNXGemmOpTile : public OpConversionPattern { + ONNXGemmOpTile(MLIRContext* ctx) + : OpConversionPattern(ctx) {} + + LogicalResult + matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { + Location gemmLoc = gemmOp.getLoc(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + Value out = gemmOp.getY(); + + float alpha = adaptor.getAlpha().convertToFloat(); + float beta = adaptor.getBeta().convertToFloat(); + bool transA = adaptor.getTransA(); + bool transB = adaptor.getTransB(); + + auto aType = cast(a.getType()); + auto bType = cast(b.getType()); + auto outType = cast(out.getType()); + + RankedTensorType cType = nullptr; + bool hasC = !isa(c.getDefiningOp()); + if (hasC) { + cType = cast(c.getType()); + assert("Only support 2 tensor for C" && cType.getRank() == 2); + } + + assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() + && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); + + if (transA) { + auto aShape = aType.getShape(); + auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); + a = rewriter.create(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); + } + if (transB) { + auto bShape = bType.getShape(); + auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); + b = rewriter.create(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); + } + + if (alpha != 1.0f) { + auto alphaTensorType = RankedTensorType::get({1, 1}, cast(a.getType()).getElementType()); + auto alphaTensorValue = 
DenseFPElementsAttr::get(alphaTensorType, {alpha}); + auto alphaTensor = rewriter.create(gemmLoc, alphaTensorType, alphaTensorValue); + a = rewriter.create(gemmLoc, a.getType(), a, alphaTensor); + } + if (hasC && beta != 1.0f) { + auto betaTensorType = RankedTensorType::get({1, 1}, cast(c.getType()).getElementType()); + auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); + auto betaTensor = rewriter.create(gemmLoc, betaTensorType, betaTensorValue); + c = rewriter.create(gemmLoc, c.getType(), c, betaTensor); + } + + auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); + auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue()); + auto bNumVSlices = aNumHSlices; + auto bLastVSliceSize = aLastHSliceSize; + auto cNumHSlices = bNumHSlices; + auto cLastHSliceSize = bLastHSliceSize; + auto outNumHSlices = cNumHSlices; + auto outLastHSliceSize = cLastHSliceSize; + + const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue()); + + DenseMap> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc); + + DenseMap>> bTiles = + tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc); + + SmallVector cHSlices; + if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1) + c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc); + if (hasC) + cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc); + + RankedTensorType outHSliceType = + RankedTensorType::get({1, static_cast(crossbarSize)}, outType.getElementType()); + RankedTensorType outLastHSliceType = + RankedTensorType::get({1, static_cast(bLastHSliceSize)}, outType.getElementType()); + + SmallVector outHSlices; + outHSlices.reserve(outNumHSlices); + for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) { + RankedTensorType currOutHSliceType = outHSliceType; + if (outSliceId == outNumHSlices - 1 && 
outLastHSliceSize != 0) + currOutHSliceType = outLastHSliceType; + + SmallVector partialResults; + partialResults.reserve(coresPerVSlice); + for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) { + SmallVector weights; + weights.reserve(aHSlices[coreId].size()); + + for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) + weights.push_back(bTiles[outSliceId][coreId][aSliceId]); + + auto computeOp = + rewriter.create(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); + + auto* computeBlock = new Block(); + for (auto aHSlice : aHSlices[coreId]) + computeBlock->addArgument(aHSlice.getType(), gemmLoc); + computeOp.getBody().push_back(computeBlock); + rewriter.setInsertionPointToStart(computeBlock); + + auto computeArgs = computeBlock->getArguments(); + SmallVector vmmOutputs; + vmmOutputs.reserve(computeArgs.size()); + for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++) + vmmOutputs.push_back( + rewriter.create(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); + assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); + + Value partialVmmSum = sumTensors(vmmOutputs, rewriter); + rewriter.create(gemmLoc, partialVmmSum); + rewriter.setInsertionPointAfter(computeOp); + + partialResults.push_back(computeOp.getResult(0)); + } + + if (hasC) { + Value cHSlice = cHSlices[outSliceId]; + partialResults.push_back(cHSlice); + } + + auto reduceComputeOp = + rewriter.create(gemmLoc, currOutHSliceType, SmallVector(), partialResults); + + auto* reduceBlock = new Block(); + for (auto partialResult : partialResults) + reduceBlock->addArgument(partialResult.getType(), gemmLoc); + reduceComputeOp.getBody().push_back(reduceBlock); + rewriter.setInsertionPointToStart(reduceBlock); + + auto blockArgs = reduceBlock->getArguments(); + Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter); + rewriter.create(gemmLoc, outHSlice); + rewriter.setInsertionPointAfter(reduceComputeOp); + + 
outHSlices.push_back(reduceComputeOp.getResult(0)); + } + + rewriter.setInsertionPoint(gemmOp); + auto concatOp = rewriter.create(gemmLoc, /*axis=*/1, outHSlices); + rewriter.replaceOp(gemmOp, concatOp); + return success(); + } + +private: + /** + * Resolves the ONNXExpOp from the use chain of the given start value. + * + * This function traverses the use chain of the start value until it finds an + * ONNXExpOp. It returns the value of the ONNXExpOp. + * + * @param startValue The starting value of the use chain. + * @return The value of the ONNXExpOp found in the use chain. + */ + static Value resolveONNXExpOpFromUseChain(Value startValue) { + Value walker = startValue; + + while (!llvm::isa(walker.getDefiningOp())) { + walker = walker.getDefiningOp()->getOperand(0); + + assert(walker && walker.getDefiningOp() + && "Unwinded the whole chain of operations while trying to " + "find ONNXExpOp, but did not find it"); + } + + // Make sure the dividend is actually produced by an ONNXExpOp + assert(llvm::isa(walker.getDefiningOp()) + && "Old output tile (softmax reducer) is not produced by an " + "ONNXExpOp"); + + return walker; + } + + // Softmax is a special case, as it requires another reduction after the + // first one. In the cores, `applyReducePattern` already applied + // f(x) = exp(x) to each tile. This mean that now we just need to + // reduce-sum these tiles, and then divide each tile by the reduced sum, + // which is propagated back to the cores via a broadcast channel. 
+ LogicalResult softmaxReductionApplication(SmallVector& outputOpsAndResNums, + Value& softmaxChannel, + ConversionPatternRewriter& rewriter, + SpatialReducer& reducer, + ONNXGemmOp& gemmOp, + Location& loc) const { + + // TODO: Check case with one compute op + + // Cast vector of Value into vector of ComputeOp + SmallVector softmaxOpsToReduce = + llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) { + return std::make_pair(cast(computeAndResNum.first), computeAndResNum.second); + })); + + RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr); + const TensorType scalarTensorType = tensorTypeBuilder; + + reducer.applyReducePattern( + softmaxOpsToReduce, + [&](Value a, Value b) { return rewriter.create(loc, scalarTensorType, a, b); }, + /* preprocess = */ + [&](Value a) { return rewriter.create(loc, scalarTensorType, a); }, + [&](Value softmaxDivisor) { + // Signal that this is the compute with the softmax divisor + auto computeOp = cast(softmaxDivisor.getDefiningOp()->getParentOp()); + computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr()); + + // Broadcast the divisor to all the cores + rewriter.setInsertionPointAfterValue(softmaxDivisor); + rewriter.create(loc, softmaxChannel, softmaxDivisor); + + /* + * softmaxDividend = onnx.exp (...) + * sum = spat.SumOp(softmaxDividend) + * [following can be repeated N times, thus walk the use chain] + * softmaxDivisor = spat.sadd(sum, ...) + */ + Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0)); + + // Make sure the dividend is actually produced by an ONNXExpOp + assert(llvm::isa(softmaxDividend.getDefiningOp()) + && "Dividend of softmax reduction is not an ONNXExpOp"); + + // Do not divide here, divide after this + return softmaxDivisor; + }); + + // In all the cores, insert a ChannelRecvOp and divide the output tile by + // the reduced denominator. 
+ outputOpsAndResNums.clear(); + outputOpsAndResNums.reserve(softmaxOpsToReduce.size()); + for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) { + + auto yieldOp = cast(computeToDivideOpAndResNum.first.getBody().front().getTerminator()); + + Value divisor; + + // Check if this compute contains the softmax divisor: if so, find the + // ChannelBroadcastSendOp, otherwise receive the value from the channel + // using ChannelBroadcastReceiveOp + if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) { + + bool found = false; + for (auto broadcastOp : + computeToDivideOpAndResNum.first.getBody().front().getOps()) { + assert(found == false + && "More than one ChannelBroadcastSendOp in " + "compute? How is this possible?"); + found = true; + + divisor = broadcastOp.getData(); + } + + assert(found + && "No ChannelBroadcastSendOp in compute where softmax " + "divisor was specified to be?"); + } + else { + rewriter.setInsertionPoint(yieldOp); + divisor = rewriter.create(loc, scalarTensorType, softmaxChannel); + } + + // Walk the chain of operations until we find the ONNXExpOp: this is + // needed because some some may have a different amount of `VAddOp`s due + // to the tree reduction (e.g. 
some may have no VAddOp, some may have + // multiples) + Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second)); + + rewriter.setInsertionPoint(yieldOp); + Value newOutputTile = rewriter.create(loc, oldOutputTile.getType(), oldOutputTile, divisor); + auto yieldOperandNum = yieldOp->getNumOperands(); + yieldOp->insertOperands(yieldOperandNum, newOutputTile); + + outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum}); + } + + return success(); + } +}; + +void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp new file mode 100644 index 0000000..2530e72 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp @@ -0,0 +1,327 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include + +using namespace mlir; + +namespace onnx_mlir { + +template +bool hasPostProcessExperimentalPoolingWindow() { + return false; +} + +template <> +bool hasPostProcessExperimentalPoolingWindow() { + return true; +} + 
+template +Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter &rewriter, + Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size, + size_t tilesSkippedByPadding) { + return nullptr; +} + +template <> +Value postProcessExperimentalPoolingWindow( + ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp, + Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) { + bool countIncludePad = poolOp.getCountIncludePad() == 1; + + size_t divisorNumber = + countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; + + RankedTensorType scalarTensor = + RankedTensorType::get({1}, rewriter.getF32Type()); + + // Put a spat.const before the computeOp, and use its value. We do this to be + // compatible with the current code generation, which assumes constant to be + // loaded in global memory, which is allocated by adding a spat.const OP + // directly under func.func (i.e. alongside ComputeOps) + auto computeOp = cast( + valueToDivide.getDefiningOp()->getParentOp()); + rewriter.setInsertionPoint(computeOp); + auto divisorValue = rewriter.create(loc, scalarTensor, + rewriter.getI64IntegerAttr(divisorNumber), + /* should_allocate = */ rewriter.getBoolAttr(true)); + + rewriter.setInsertionPointAfterValue(valueToDivide); + return rewriter.create( + loc, valueToDivide.getType(), valueToDivide, divisorValue); +} + +template +Value reduceInputTiles( + SmallVector &inputTiles, ConversionPatternRewriter &rewriter) { + if (inputTiles.size() == 1) { + return inputTiles[0]; + } + + if (inputTiles.size() == 2) { + return rewriter.create(inputTiles[0].getLoc(), + inputTiles[0].getType(), inputTiles[0], inputTiles[1]); + } + + SmallVector left( + inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2); + SmallVector right( + inputTiles.begin() + inputTiles.size() / 2, inputTiles.end()); + + Value leftReduced = reduceInputTiles(left, rewriter); + Value rightReduced = reduceInputTiles(right, rewriter); + + return 
rewriter.create( + inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced); +} + +template +struct ExperimentalPoolingBaseConverter : public OpConversionPattern { + ExperimentalPoolingBaseConverter(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Value X = adaptor.getX(); + ShapedType xShape = mlir::cast(X.getType()); + Value Y = poolOp.getResult(); + ShapedType yShape = mlir::cast(Y.getType()); + + size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h; + unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y); + unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); + unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); + + if (adaptor.getAutoPad() != "NOTSET") { + return rewriter.notifyMatchFailure( + poolOp, "auto_pad != NOTSET is deprecated."); + } + + size_t pad_x, pad_y; + auto padUnpackError = + unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); + if (padUnpackError.has_value()) { + return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); + } + + Location loc = poolOp.getLoc(); + + size_t input_h = GET_IMAGE_HEIGHT(xShape); + size_t input_w = GET_IMAGE_WIDTH(xShape); + size_t output_h = GET_IMAGE_HEIGHT(yShape); + size_t output_w = GET_IMAGE_WIDTH(yShape); + + ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize); + + // Assert that the input is a tensor.ConcatOp. + auto concat = X.getDefiningOp(); + if (!concat) { + return rewriter.notifyMatchFailure( + poolOp, "Expected input to be a tensor.ConcatOp"); + } + + // Create a [channel_tile][x][y] array to store the input tiles. + std::map>> inputTiles; + + // For each argument of the tensor.ConcatOp, resolve the input tiles. 
+ for (size_t y = 0; y < input_h; ++y) { + for (size_t x = 0; x < input_w; ++x) { + for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) { + size_t tilingSize = + it == tileCount.quot ? tileCount.rem : crossbarSize; + + SmallVector strides(4, rewriter.getIndexAttr(1)); + SmallVector offsets = {/* 0 */ rewriter.getIndexAttr(0), + /* 1 */ rewriter.getIndexAttr(0), + /* 2 */ rewriter.getIndexAttr(x), + /* 3 */ rewriter.getIndexAttr(y)}; + SmallVector sizes = { + /* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. + /* 1 */ rewriter.getIndexAttr(tilingSize), + /* 2 */ rewriter.getIndexAttr(1), + /* 3 */ rewriter.getIndexAttr(1)}; + + // Get the concat's operand that we want to slice. + Value concatInput = concat.getOperand(it); + Value slicedTile = rewriter.create( + loc, concatInput, offsets, sizes, strides); + + inputTiles[it][x][y] = slicedTile; + } + } + } + + // Prepare the shape of the compute's output. + ldiv_t itc = tileCount; + SmallVector outputTileTypes; + for (size_t y = 0; y < output_h; ++y) { + for (size_t x = 0; x < output_w; ++x) { + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + SmallVector outputShapeArray{ + /* 0 */ 1, // Batch size is always 1. + /* 1 */ + cast(inputTiles[it][0][0].getType()) + .getShape()[1], + /* 2 */ 1, + /* 3 */ 1}; + + auto elementType = + dyn_cast(xShape).getElementType(); + + outputTileTypes.push_back( + RankedTensorType::get(outputShapeArray, elementType)); + } + } + } + + // Create a plain value list of the input tiles. + SmallVector inputTilesList; + for (size_t y = 0; y < input_h; ++y) { + for (size_t x = 0; x < input_w; ++x) { + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + inputTilesList.push_back(inputTiles[it][y][x]); + } + } + } + + // Create a single compute to calculate the output. + auto computeOp = rewriter.create( + loc, outputTileTypes, SmallVector(), inputTilesList); + + // Create a new block for the compute unit and add the operands. 
+ Block *block = rewriter.createBlock(&computeOp.getRegion()); + + // Fill the block arguments and keep a reference to them. + std::map>> inputTilesArgs; + for (size_t y = 0; y < input_h; ++y) { + for (size_t x = 0; x < input_w; ++x) { + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + + x * (itc.quot + (itc.rem > 0)) + it; + inputTilesArgs[it][y][x] = block->addArgument( + computeOp->getOperand(tileIndex).getType(), loc); + } + } + } + + // Begin writing in the block. + rewriter.setInsertionPointToStart(block); + + // Go through all pooling blocks. + SmallVector outputTiles; + for (size_t y = 0; y < output_h; ++y) { + for (size_t x = 0; x < output_w; ++x) { + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + size_t start_x = x * stride_x; + size_t start_y = y * stride_y; + size_t end_x = std::min(start_x + krn_w, input_w); + size_t end_y = std::min(start_y + krn_h, input_h); + + SmallVector inputTilesToReduce; + for (size_t ky = start_y; ky < end_y; ++ky) { + for (size_t kx = start_x; kx < end_x; ++kx) { + inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]); + } + } + + auto reduceResult = + reduceInputTiles(inputTilesToReduce, rewriter); + + // If the reduce op is add, we need to divide the result by the + // number of elements in the pooling window. + if (hasPostProcessExperimentalPoolingWindow()) { + // Add a spat.const before the computeOp. + rewriter.setInsertionPoint(computeOp); + auto divisorValue = rewriter.create(loc, + RankedTensorType::get({1}, rewriter.getF32Type()), + rewriter.getI64IntegerAttr(krn_w * krn_h), + rewriter.getBoolAttr(true)); + + rewriter.setInsertionPointAfter(reduceResult.getDefiningOp()); + reduceResult = rewriter.create( + loc, reduceResult.getType(), reduceResult, divisorValue); + } + outputTiles.push_back(reduceResult); + } + } + } + + // Create a YieldOp to return the output tiles. 
+ rewriter.create(loc, outputTiles); + + // Set the rewrite cursor right after the computeOp. + rewriter.setInsertionPointAfter(computeOp); + + std::map>> computeOutput; + for (size_t y = 0; y < output_h; ++y) { + for (size_t x = 0; x < output_w; ++x) { + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + + x * (itc.quot + (itc.rem > 0)) + it; + computeOutput[it][y][x] = computeOp.getResult(tileIndex); + } + } + } + + // We'll now create spat.img.concat ops to concatenate the output tiles. + SmallVector outputTilesList; + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + SmallVector imgConcatTiles; + for (size_t y = 0; y < output_h; ++y) { + for (size_t x = 0; x < output_w; ++x) { + imgConcatTiles.push_back(computeOutput[it][y][x]); + } + } + + size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize; + + SmallVector outputShapeArray{ + /* 0 */ 1, // Batch size is always 1. + /* 1 */ (long)tilingSize, + /* 2 */ (long)output_w, + /* 3 */ (long)output_h}; + + auto elementType = dyn_cast(xShape).getElementType(); + + outputTilesList.push_back(rewriter.create(loc, + RankedTensorType::get(outputShapeArray, elementType), + imgConcatTiles)); + } + + // Create a new tensor.ConcatOp to concatenate the output tiles. 
+ Value outputTensor = + rewriter.create(loc, 1, outputTilesList); + + rewriter.replaceOp(poolOp, outputTensor); + + return success(); + } +}; + +void populateExperimentalPoolingTilingPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert>(ctx); + patterns.insert>(ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp new file mode 100644 index 0000000..58dc549 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp @@ -0,0 +1,452 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include + +using namespace mlir; + +namespace onnx_mlir { + +llvm::SmallPtrSet oldComputeOpsReplaced; + +Value applyReducePatternNew(SmallVector &valuesToReduce, + ConversionPatternRewriter &rewriter, + std::function reduce, + std::function preprocess, + std::function postprocess) { + // Simple case: if we have only one input, just return it + if (valuesToReduce.size() == 1) { + return valuesToReduce[0]; + } + + if (preprocess) { + for (auto &valToReduce : valuesToReduce) { + rewriter.setInsertionPointAfterValue(valToReduce); + valToReduce = preprocess(valToReduce); + } + } + 
+ // It is possible that `valuesToReduce` contains two entries for the same + // computeOp. In this case, we need to apply the reduction within-computef + + // Keep a map between a computeOp and the last Value for this reduction + std::unordered_map lastValueForCompute; + for (auto &valToReduce : valuesToReduce) { + Operation *computeOp = valToReduce.getParentBlock()->getParentOp(); + // if (valToReduce.getDefiningOp()) { + // // If the value is defined by an operation, we take the parent + // operation computeOp = valToReduce.getDefiningOp()->getParentOp(); + // } else { + // // Otherwise it is a block argument, + // computeOp->getBlock()->getParentOp(); + // } + + assert(isa(computeOp) && "Expected a ComputeOp"); + + auto it = lastValueForCompute.find(computeOp); + + if (it != lastValueForCompute.end()) { + // If we have already seen this computeOp, apply the reduction + // within-compute + Value lastWithinComputeValue = it->second; + + if (valToReduce.getDefiningOp()->isBeforeInBlock( + lastWithinComputeValue.getDefiningOp())) { + rewriter.setInsertionPointAfterValue(lastWithinComputeValue); + } else { + rewriter.setInsertionPointAfterValue(valToReduce); + } + valToReduce = reduce(lastWithinComputeValue, valToReduce); + lastValueForCompute[computeOp] = valToReduce; + } + + lastValueForCompute[computeOp] = valToReduce; + } + + // Now, reconstruct from the map the valuesToReduce list + valuesToReduce.clear(); + valuesToReduce.reserve(lastValueForCompute.size()); + for (auto &entry : lastValueForCompute) { + valuesToReduce.push_back(entry.second); + } + + Location loc = valuesToReduce[0].getLoc(); + auto channelType = spatial::SpatChannelType::get(rewriter.getContext()); + + // Recursive algorithm to reduce the inputs to a single one: + // - Take two inputs at a time, and reduce them into a single one, updating + // the valuesToReduce list which becomes half the size. + // - Repeat until there is only one input left. 
+ llvm::OwningArrayRef valuesToReduceRef(valuesToReduce); + while (valuesToReduceRef.size() > 1) { + SmallVector nextValuesToReduce; + nextValuesToReduce.reserve(valuesToReduceRef.size() / 2); + for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) { + auto firstValue = valuesToReduceRef[i]; + auto secondValue = valuesToReduceRef[i + 1]; + + auto firstCompute = firstValue.getParentBlock()->getParentOp(); + auto secondCompute = secondValue.getParentBlock()->getParentOp(); + + assert(isa(firstCompute)); + assert(isa(secondCompute)); + + if (secondCompute->isBeforeInBlock(firstCompute)) { + std::swap(firstValue, secondValue); + std::swap(firstCompute, secondCompute); + } + + // 1. Add a channel before the first computeOp + rewriter.setInsertionPoint(firstCompute); + auto channel = rewriter.create(loc, channelType); + + // 2. Add a sendOp after the first value + rewriter.setInsertionPointAfterValue(firstValue); + rewriter.create(loc, channel, firstValue); + + // 3. Add a receiveOp after the second value + rewriter.setInsertionPointAfterValue(secondValue); + auto receivedValue = rewriter.create( + loc, secondValue.getType(), channel); + + // 4. Apply reduction between second value and received value + rewriter.setInsertionPointAfterValue(receivedValue); + Value reduced = reduce(receivedValue, secondValue); + + nextValuesToReduce.push_back(reduced); + } + + // If we have an odd number of inputs, we need to add the last one to the + // newInputs list. + if (valuesToReduceRef.size() % 2 == 1) { + nextValuesToReduce.push_back(valuesToReduceRef.back()); + } + + // Replace the inputOps list with the new one. 
+ valuesToReduceRef = + llvm::OwningArrayRef(std::move(nextValuesToReduce)); + } + + assert(valuesToReduceRef.size() == 1 && + "Internal error: expected a single input at this point."); + + auto finalValue = valuesToReduceRef[0]; + + if (postprocess) { + rewriter.setInsertionPointAfterValue(finalValue); + finalValue = postprocess(finalValue); + } + + return finalValue; +} + +template +bool hasPostProcessPoolingWindow() { + return false; +} + +template <> +bool hasPostProcessPoolingWindow() { + return true; +} + +template +Value postProcessPoolingWindow(ConversionPatternRewriter &rewriter, + Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size, + size_t tilesSkippedByPadding) { + return nullptr; +} + +template <> +Value postProcessPoolingWindow( + ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp, + Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) { + bool countIncludePad = poolOp.getCountIncludePad() == 1; + + size_t divisorNumber = + countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; + + RankedTensorType scalarTensor = + RankedTensorType::get({1}, rewriter.getF32Type()); + + // Put a spat.const before the computeOp, and use its value. We do this to be + // compatible with the current code generation, which assumes constant to be + // loaded in global memory, which is allocated by adding a spat.const OP + // directly under func.func (i.e. 
alongside ComputeOps) + auto computeOp = cast( + valueToDivide.getDefiningOp()->getParentOp()); + rewriter.setInsertionPoint(computeOp); + auto divisorValue = rewriter.create(loc, scalarTensor, + rewriter.getI64IntegerAttr(divisorNumber), + /* should_allocate = */ rewriter.getBoolAttr(true)); + + rewriter.setInsertionPointAfterValue(valueToDivide); + return rewriter.create( + loc, valueToDivide.getType(), valueToDivide, divisorValue); +} + +template +struct PoolingBaseConverter : public OpConversionPattern { + PoolingBaseConverter(MLIRContext *ctx) : OpConversionPattern(ctx) {} + + LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Value X = adaptor.getX(); + ShapedType xShape = mlir::cast(X.getType()); + Value Y = poolOp.getResult(); + ShapedType yShape = mlir::cast(Y.getType()); + + size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h; + unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y); + unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); + unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); + + if (adaptor.getAutoPad() != "NOTSET") { + return rewriter.notifyMatchFailure( + poolOp, "auto_pad != NOTSET is deprecated."); + } + + size_t pad_x, pad_y; + auto padUnpackError = + unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); + if (padUnpackError.has_value()) { + return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); + } + + Location loc = poolOp.getLoc(); + + size_t input_h = GET_IMAGE_HEIGHT(xShape); + size_t input_w = GET_IMAGE_WIDTH(xShape); + size_t output_h = GET_IMAGE_HEIGHT(yShape); + size_t output_w = GET_IMAGE_WIDTH(yShape); + size_t channelTileCount = + ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); + size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize; + + // 1: Tile the input tensor + // Input tiles need to be indexed by: + // a. Channel Tile + // b. 
Pixel `x` position + // c. Pixel `y` position + // For example: inputTiles[channelTile][x][y] + // Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH) + // Suppose that the input tensor is produced by concatenating the results of + // many ComputeOps. Get the result tiles from these ComputeOps. + SmallVector>> inputTiles(channelTileCount, + SmallVector>(input_w, SmallVector(input_h))); + + auto resolveErrorOpt = resolveImgInputTiles(X, inputTiles, channelTileCount, + channelTileRest, input_w, input_h, rewriter); + if (resolveErrorOpt.has_value()) { + return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt); + } + + // TODO: This requires a core for each input tile, which is not ideal. We + // can do better. + // If some input tiles come from the func.func operands, load + // them into a computeOp and yield them + for (size_t t = 0; t < channelTileCount; t++) { + for (size_t x = 0; x < input_w; x++) { + for (size_t y = 0; y < input_h; y++) { + if (auto extractSliceOp = + inputTiles[t][x][y].getDefiningOp()) { + Location tileLoc = extractSliceOp.getLoc(); + + auto tempComputeOp = rewriter.create( + tileLoc, extractSliceOp.getResultType(), + /* xbarWeights =*/ValueRange(), extractSliceOp.getResult()); + + Block *tempComputeOpBlock = new Block(); + tempComputeOp.getBody().push_back(tempComputeOpBlock); + auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument( + extractSliceOp.getType(), tileLoc); + + rewriter.setInsertionPointToStart(tempComputeOpBlock); + rewriter.create(tileLoc, tempComputeOpBlockArg); + rewriter.setInsertionPointAfter(tempComputeOp); + inputTiles[t][x][y] = tempComputeOp.getResult(0); + } + } + } + } + + // 2: Tile the output tensor + // Output tiles need to be indexed by: + // a. Channel Tile + // b. Pixel `x` position + // c. 
Pixel `y` position + // For example: outputTiles[channelTile][x][y] + // Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH) + SmallVector>> outputTiles( + channelTileCount, SmallVector>( + output_w, SmallVector(output_h, nullptr))); + + // List of values to pool for each output pixel + SmallVector valuesToPool; + + // Iterate each output tile + for (size_t outTile = 0; outTile < channelTileCount; outTile++) { + // Iterate each output pixel + for (size_t outX = 0; outX < output_w; outX++) { + for (size_t outY = 0; outY < output_h; outY++) { + + // Each output pixel tile is computed by pooling a window of input + // pixel tiles + valuesToPool.clear(); + size_t tilesSkippedByPadding = 0; + + auto [start_x, end_x] = kernel_get_start_and_end( + outX, input_w, krn_w, stride_x, dilation_x, pad_x); + auto [start_y, end_y] = kernel_get_start_and_end( + outY, input_h, krn_h, stride_y, dilation_y, pad_y); + + for (size_t inX = start_x; inX < end_x; inX += dilation_x) { + for (size_t inY = start_y; inY < end_y; inY += dilation_y) { + if (failed(verifyWithinBoundsAndPaddings( + input_w, input_h, inX, inY, pad_x, pad_y))) { + tilesSkippedByPadding++; + continue; + } + + Value inputTile = inputTiles[outTile][inX][inY]; + + Value valueToPool; + if (auto computeProducer = + inputTile.getDefiningOp()) { + + int resultNumber = getResultIndex(computeProducer, inputTile); + + auto yieldInComputeOp = cast( + computeProducer.getBody().front().getTerminator()); + valueToPool = yieldInComputeOp.getOperand(resultNumber); + } else if (auto receiveProducer = + inputTile + .getDefiningOp()) { + auto sendOpOpt = + getOtherEndOfChannel(receiveProducer, true, rewriter); + if (failed(sendOpOpt)) { + return rewriter.notifyMatchFailure(poolOp, + "ChannelReceiveOp does not have a matching " + "ChannelSendOp."); + } + auto sendOp = cast(*sendOpOpt); + + valueToPool = sendOp.getData(); + } else { + return rewriter.notifyMatchFailure(poolOp, + "Input tile for Pooling is not produced by a " + 
"WeightedComputeOp nor a receiveOp"); + } + + valuesToPool.push_back(valueToPool); + } + } + + assert(valuesToPool.size() != 0 && + "Pooling computed on zero tiles make no sense."); + // assert(computeOpsForPooling.size() != 1 && + // "Pooling computed on one tiles make no sense??? Or maybe + // this " "should have been simplified earlier???"); + + std::function postProcessFn = nullptr; + if (hasPostProcessPoolingWindow()) { + postProcessFn = [&](const Value prevFinalRes) { + return postProcessPoolingWindow(rewriter, loc, poolOp, + prevFinalRes, krn_h * krn_w, tilesSkippedByPadding); + }; + } + + Value reducedWithinCompute = applyReducePatternNew( + valuesToPool, rewriter, + [&](const Value lhs, const Value rhs) { + return rewriter.create(loc, lhs.getType(), lhs, rhs); + }, + nullptr, postProcessFn); + + // Send this value through a channel, and receive it in the + // `func.func`. During lowering, we will need to "move it" into the + // users computeOps + auto computeOpOfReduced = cast( + reducedWithinCompute.getDefiningOp()->getParentOp()); + + // Create a new channel before the computeOp + rewriter.setInsertionPoint(computeOpOfReduced); + auto reduceChannel = rewriter.create( + loc, spatial::SpatChannelType::get(rewriter.getContext())); + + // Send value through the channel + rewriter.setInsertionPointAfterValue(reducedWithinCompute); + rewriter.create( + loc, reduceChannel, reducedWithinCompute); + + // Receive after the computeOp + rewriter.setInsertionPointAfter(computeOpOfReduced); + auto receivedValue = rewriter.create( + loc, reducedWithinCompute.getType(), reduceChannel); + + outputTiles[outTile][outX][outY] = receivedValue; + } + } + } + + // TODO: outputTiles are not the results of the computeOps! We need to add + // them! 
+ + std::unordered_map>> + computeOpNeedingResults; + + // Iterate each output tile + for (size_t outTile = 0; outTile < channelTileCount; outTile++) { + // Iterate each output pixel + for (size_t outX = 0; outX < output_w; outX++) { + for (size_t outY = 0; outY < output_h; outY++) { + auto outputTile = outputTiles[outTile][outX][outY]; + auto outputTileProducer = outputTile.getDefiningOp()->getParentOp(); + if (!outputTileProducer) { + return rewriter.notifyMatchFailure(poolOp, + "Output tile for Pooling is not produced by a " + "WeightedComputeOp."); + } + + computeOpNeedingResults[outputTileProducer].push_back( + std::make_tuple(outTile, outX, outY, outputTile)); + } + } + } + + Value outputImage = + createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType()); + + rewriter.replaceOp(poolOp, outputImage); + + return success(); + } +}; + +void populatePoolingTilingPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert>(ctx); + patterns.insert>(ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp new file mode 100644 index 0000000..f8e44ab --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp @@ -0,0 +1,90 @@ + + +#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +namespace onnx_mlir { + +struct ReduceMeanConversionPattern + : public OpConversionPattern { + + ReduceMeanConversionPattern(MLIRContext *ctx) : OpConversionPattern(ctx) {} + + LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean, + ONNXReduceMeanV13OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + // Get the input tensor. 
+ Value inputTensor = adaptor.getData(); + auto inputTensorType = cast(inputTensor.getType()); + + // This pattern will substitute the ONNXReduceMeanV13Op with a + // ONNXAveragePoolOp with the same input tensor and an appropriate kernel + // shape and strides. + + // To get the stride and shape of the kernel, we need to read the tensor + // shape. + int image_height = inputTensorType.getShape()[2]; + int image_width = inputTensorType.getShape()[3]; + + // Define the kernel shape and strides. + SmallVector kernelShapeVals = {image_height, image_width}; + SmallVector stridesVals = {image_height, image_width}; + SmallVector dilationsVals = {1, 1}; + + // Set the pads to 0. + SmallVector padsVals = {0, 0, 0, 0}; + + // Create the ArrayAttrs + auto kernelShape = mlir::ArrayAttr::get(rewriter.getContext(), + llvm::to_vector( + llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); + + auto strides = mlir::ArrayAttr::get(rewriter.getContext(), + llvm::to_vector( + llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); + + auto dilations = mlir::ArrayAttr::get(rewriter.getContext(), + llvm::to_vector( + llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); + + auto pads = mlir::ArrayAttr::get(rewriter.getContext(), + llvm::to_vector( + llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); + + // Create the resulting tensor type. + auto resultType = RankedTensorType::get( + /*shape=*/{inputTensorType.getShape()[0], inputTensorType.getShape()[1], + 1, 1}, + /*elementType=*/inputTensorType.getElementType()); + + // Create the ONNXAveragePoolOp. 
+ auto averagePool = rewriter.create(reduceMean.getLoc(), + resultType, inputTensor, /*auto_pad=*/"NOTSET", + /*ceil_mode=*/0, /*count_include_pad=*/1, dilations, + /*kernel_shape=*/kernelShape, + /*pads=*/pads, /*strides=*/strides); + + // Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp. + rewriter.replaceOp(reduceMean, averagePool.getResult()); + + return success(); + } +}; + +void populateReduceMeanConversionPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td new file mode 100644 index 0000000..529d579 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td @@ -0,0 +1,79 @@ +#ifndef ONNX_TO_SPATIAL +#define ONNX_TO_SPATIAL + +#ifndef OP_BASE +include "mlir/Dialect/Tensor/IR/TensorOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "src/Dialect/ONNX/ONNX.td" +include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" +#endif // OP_BASE + +def onnxToArithConstantOp : Pat< + (ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings), + (Arith_ConstantOp $value) +>; + +//===----------------------------------------------------------------------===// +// ONNXMatMulOp to ONNXGemmOp patterns +//===----------------------------------------------------------------------===// + +def matMulAddToGemmPattern : Pat< + (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C), + (ONNXGemmOp $A, $B, $C, + /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), + /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), + /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), + /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">) + ) +>; + +def matMulToGemmPattern : Pat< + (ONNXMatMulOp:$matmulres $A, 
$B), + ( + ONNXGemmOp $A, $B, + /* C = */ (NativeCodeCall<"$_builder.create($_loc, cast(matmulres.getY().getType()).getShape(), cast(matmulres.getY().getType()).getElementType());">), + /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), + /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">), + /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), + /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">) + ) +>; + +//===----------------------------------------------------------------------===// +// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern +//===----------------------------------------------------------------------===// + +// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single +// ONNXConvOp with a bias. +def convAddToConvWithBiasPatternLeft : Pat< + (ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)), + (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides) +>; + +def convAddToConvWithBiasPatternRight : Pat< + (ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand), + (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides) +>; + +//===----------------------------------------------------------------------===// +// Operation to ignore (i.e. 
remove) +//===----------------------------------------------------------------------===// + +def replaceWithOperationOfValue : NativeCodeCall<"$0">; + +def removeLRNPattern : Pat< + (ONNXLRNOp $A, $_, $_, $_, $_), + (replaceWithOperationOfValue $A) +>; + +def HaveSameStaticShape: Constraint< + CPred<"onnx_mlir::haveSameStaticShape($0, $1)">, + "Two tensors have the same static shape">; + +def removeFlattenSameShapePattern : Pat< + (ONNXFlattenOp:$flattenOp $A, $axis), + (replaceWithOperationOfValue $A), + [(HaveSameStaticShape $flattenOp, $A)] +>; // Add closing parenthesis here + +#endif // ONNX_TO_SPATIAL \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp new file mode 100644 index 0000000..bd1ecac --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp @@ -0,0 +1,499 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" + +#include +#include +#include + +#include "ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +SmallVector sliceTensor( + const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { + ArrayRef shape = getTensorShape(tensorToSlice); + assert("Invalid axis" && axis < shape.size()); + + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + SmallVector sizes; + 
sizes.reserve(shape.size()); + for (const auto size : shape) + sizes.push_back(rewriter.getIndexAttr(size)); + sizes[axis] = rewriter.getIndexAttr(sliceSize); + + long length = shape[axis]; + auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize); + SmallVector slices; + slices.reserve(numSlices); + + for (int64_t i = 0; i < numSlices; i++) { + offsets[axis] = rewriter.getIndexAttr(i * sliceSize); + if (i == numSlices - 1 && lastSliceSize != 0) + sizes[axis] = rewriter.getIndexAttr(lastSliceSize); + + Value slice = rewriter.create(loc, tensorToSlice, offsets, sizes, strides); + slices.push_back(slice); + } + + return slices; +} + +SmallVector +sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { + ArrayRef shape = getTensorShape(vectorToSlice); + assert("Not a vector" && isVectorShape(shape)); + size_t axis = shape[0] != 1 ? 0 : 1; + return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc); +} + +DenseMap> +sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) { + SmallVector slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc); + DenseMap> slicesPerCore; + for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) { + size_t coreId = sliceId / crossbarCountInCore; + slicesPerCore[coreId].push_back(slices[sliceId]); + } + return slicesPerCore; +} + +DenseMap>> tileMatrix( + Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) { + assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile))); + + DenseMap>> tiles; + + SmallVector hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc); + size_t numHSlices = hSlices.size(); + for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) { + Value hSlice = hSlices[hSliceId]; + SmallVector vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc); + for (size_t vSliceId = 0; 
vSliceId < vSlices.size(); vSliceId++) { + size_t coreId = vSliceId / crossbarCountInCore; + Value vSlice = vSlices[vSliceId]; + tiles[hSliceId][coreId].push_back(vSlice); + } + } + return tiles; +} + +tensor::SplatOp +broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) { + auto oldType = cast(scalarToBroadcast.getType()); + Type elementType = oldType.getElementType(); + int64_t shape[2] = {1, length}; + Type type = oldType.cloneWith(ArrayRef(shape), elementType); + + auto zero = rewriter.create(loc, 0).getResult(); + SmallVector index(oldType.getRank(), zero); + auto elementValue = rewriter.create(loc, scalarToBroadcast, index).getResult(); + + return rewriter.create(loc, type, elementValue); +} + +Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter) { + if (tensors.size() == 1) + return tensors[0]; + + SmallVector tensors1 = {tensors.begin(), tensors.end()}; + SmallVector tensors2; + tensors2.reserve(tensors.size() / 2); + + auto* currTensors = &tensors1; + auto* nextTensors = &tensors2; + while (currTensors->size() > 1) { + for (size_t i = 0; i < currTensors->size() - 1; i += 2) { + Value a = (*currTensors)[i]; + Value b = (*currTensors)[i + 1]; + rewriter.setInsertionPointAfterValue(b); + auto addedValue = rewriter.create(a.getLoc(), a.getType(), a, b); + nextTensors->push_back(addedValue); + } + if (currTensors->size() % 2 == 1) + nextTensors->push_back(currTensors->back()); + std::swap(currTensors, nextTensors); + nextTensors->clear(); + } + assert(currTensors->size() == 1 && "Expected a single input at this point."); + return (*currTensors)[0]; +} + +Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) { + switch (mapOp) { + case MapOperations::None: assert(false && "Invalid map operation during map operation creation."); + case MapOperations::ONNXSoftmaxOp: return rewriter.create(input.getLoc(), input.getType(), input); + case 
MapOperations::ONNXReluOp: return rewriter.create(input.getLoc(), input.getType(), input); + case MapOperations::ONNXLeakyReluOp: return rewriter.create(input.getLoc(), input.getType(), input); + case MapOperations::ONNXExpOp: return rewriter.create(input.getLoc(), input.getType(), input); + } +} + +void unpackOptionalPairVector(std::optional valuesArray, size_t& value1, size_t& value2) { + if (auto unpackedStrides = valuesArray) { + value1 = mlir::cast(unpackedStrides->getValue()[0]).getInt(); + value2 = mlir::cast(unpackedStrides->getValue()[1]).getInt(); + } + else { + value1 = 1; + value2 = 1; + } +} + +std::optional +unpackOptionalPadsVector(std::optional valuesArray, size_t& pad_x, size_t& pad_y) { + if (valuesArray.has_value()) { + auto pads = mlir::ArrayAttr(*valuesArray); + if (pads.size() != 4) + return "pads must have 4 elements."; + + pad_x = cast(pads[2]).getInt(); + pad_y = cast(pads[3]).getInt(); + } + else { + // Default padding is 0 unless specified otherwise. + // https://onnx.ai/onnx/operators/onnx__Conv.html + pad_x = pad_y = 0; + } + + return std::nullopt; +} + +void tileImageTensorByChannel(Value imageTensor, + SmallVector>>& tiles, + size_t tileSize, + ConversionPatternRewriter& rewriter) { + ShapedType imageShape = mlir::cast(imageTensor.getType()); + + size_t input_h = GET_IMAGE_HEIGHT(imageShape); + size_t input_w = GET_IMAGE_WIDTH(imageShape); + size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize); + size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize; + + SmallVector strides(4, rewriter.getIndexAttr(1)); + SmallVector offsets(4, rewriter.getIndexAttr(0)); + SmallVector sizes = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + + Location loc = imageTensor.getLoc(); + + for (size_t i = 0; i < tileCount; i++) { + if (i == tileCount - 1 && tileRest != 0) + sizes[1] = rewriter.getIndexAttr(tileRest); + for (size_t x = 0; x < input_w; x++) { + 
for (size_t y = 0; y < input_h; y++) { + offsets[1] = rewriter.getIndexAttr(i * tileSize); + offsets[2] = rewriter.getIndexAttr(x); + offsets[3] = rewriter.getIndexAttr(y); + + tiles[i][x][y] = rewriter.create(loc, imageTensor, offsets, sizes, strides); + } + } + } +} + +Value createImgConcatOp(SmallVector>>& outputTiles, + ConversionPatternRewriter& rewriter, + Location& loc, + Type outputType) { + // Populate the outputTiles for the concat in the given order: + // 1. Start top left pixel + // 2. Continue on its right pixel till the end of the row + // 3. Restart on the next row + size_t outputTileCount = outputTiles.size(); + size_t output_w = outputTiles[0].size(); + size_t output_h = outputTiles[0][0].size(); + SmallVector tilesToConcat; + tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize); + for (size_t outX = 0; outX < output_h; outX++) + for (size_t outY = 0; outY < output_w; outY++) + for (size_t outTile = 0; outTile < outputTileCount; outTile++) + tilesToConcat.push_back(outputTiles[outTile][outX][outY]); + + return rewriter.create(loc, outputType, tilesToConcat); +} + +LogicalResult +verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y) { + + if (inX < 0) { + assert((size_t) (-inX) <= pad_x && "verifyWithinBoundsAndPaddings: Negative x value out of padding"); + return failure(); + } + + if (inY < 0) { + assert((size_t) (-inY) <= pad_y && "verifyWithinBoundsAndPaddings: Negative y value out of padding"); + return failure(); + } + + if ((size_t) inX >= input_w || (size_t) inY >= input_h) { + assert((size_t) inX < input_w + pad_x && "verifyWithinBoundsAndPaddings: Positive x out of bounds"); + assert((size_t) inY < input_h + pad_y && "verifyWithinBoundsAndPaddings: Positive y out of bounds"); + return failure(); + } + + return success(); +} + +Value createExtractSliceImg(Value valToSlice, + size_t x, + size_t y, + size_t t, + size_t channelTileCount, + size_t channelTileRest, + 
size_t input_w, + size_t input_h, + PatternRewriter& rewriter) { + SmallVector strides(4, rewriter.getIndexAttr(1)); + SmallVector offsets(4, rewriter.getIndexAttr(0)); + SmallVector sizes = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + + if (t == channelTileCount - 1 && channelTileRest != 0) + sizes[1] = rewriter.getIndexAttr(channelTileRest); + + offsets[1] = rewriter.getIndexAttr(t * crossbarSize); + offsets[2] = rewriter.getIndexAttr(x); + offsets[3] = rewriter.getIndexAttr(y); + + return rewriter.create(valToSlice.getLoc(), valToSlice, offsets, sizes, strides); +} + +Value indexImgValue(Value v, + size_t x, + size_t y, + size_t t, + size_t channelTileCount, + size_t channelTileRest, + size_t input_w, + size_t input_h, + ConversionPatternRewriter& rewriter) { + + auto newV = rewriter.getRemappedValue(v); + if (newV) + v = newV; + + if (!v.getDefiningOp()) + return createExtractSliceImg(v, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter); + + if (auto computeOp = v.getDefiningOp()) { + // We found the computeOp that produces the tile we want, just return this + // value. + // TODO: Should we assert that x,y,t are zero? + assert(x == 0 && y == 0 && t == 0 && "indexImgValue: WeightedComputeOp tile indeces should be zero"); + return v; + } + + if (auto receiveOp = v.getDefiningOp()) { + // This is a receiveOp, just return its value which will be resolved later + assert(x == 0 && y == 0 && t == 0 && "indexImgValue: receiveOp tile indeces should be zero"); + return v; + } + + if (auto imgConcatOp = v.getDefiningOp()) { + auto imgConcatInput = imgConcatOp.getInputTile(x, y, t); + // TODO: Is this correct? 
+ // Above we already index exactly the tile we want, so `x=y=t=0` in + // recursive call + + return indexImgValue(imgConcatInput, 0, 0, 0, channelTileCount, channelTileRest, input_w, input_h, rewriter); + } + + if (auto tensorConcatOp = v.getDefiningOp()) { + // This can be recursive. + // First, get the input tensors of the tensor.concatOp + // Then, find the input tensor that contains the tile we want + // Finally, recursive call asking for the tile + auto concatAxis = tensorConcatOp.getDim(); + assert(concatAxis != 0 && "Expecting to concat on channel/x/y axis"); + assert(concatAxis == 1 && "TODO: Make sure this works and makes sense for other axis."); + SmallVector indexDims = {1, t * crossbarSize, x, y}; + + // Find the input tensor that contains the tile we want + size_t currentTile = 0; + for (auto concatInput : tensorConcatOp.getInputs()) { + auto concatInputShape = cast(concatInput.getType()); + assert(concatInputShape.getRank() == 4 && "Expecting an image tensor"); + auto concatInputSizeOnAxis = concatInputShape.getDimSize(concatAxis); + + if (currentTile + concatInputSizeOnAxis > indexDims[concatAxis]) { + // This input tensor contains the tile we want + indexDims[concatAxis] -= currentTile; + if (indexDims[1] % crossbarSize != 0) { + assert(ignoreConcatError + && "TODO: Handle non-tile aligned tensor, or set " + "--ignore-concat-error=true"); + } + return indexImgValue(concatInput, + indexDims[2], + indexDims[3], + indexDims[1] / crossbarSize, + channelTileCount, + channelTileRest, + input_w, + input_h, + rewriter); + } + currentTile += concatInputSizeOnAxis; + } + + assert(false + && "Could not find the input tensor that contains the tile " + "within tensor.ConcatOp"); + } + + v.dump(); + + assert(false && "indexImgValue: unsupported operation"); +} + +void resolveInputTensorTilesBlockArg(Value wholeInputTensor, + SmallVector>>& inputTiles, + size_t channelTileCount, + size_t channelTileRest, + size_t input_w, + size_t input_h, + PatternRewriter& 
rewriter) { + SmallVector strides(4, rewriter.getIndexAttr(1)); + SmallVector offsets(4, rewriter.getIndexAttr(0)); + SmallVector sizes = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Location loc = wholeInputTensor.getLoc(); + + for (size_t t = 0; t < channelTileCount; t++) { + if (t == channelTileCount - 1 && channelTileRest != 0) + sizes[1] = rewriter.getIndexAttr(channelTileRest); + for (size_t x = 0; x < input_w; x++) { + for (size_t y = 0; y < input_h; y++) { + offsets[1] = rewriter.getIndexAttr(t * crossbarSize); + offsets[2] = rewriter.getIndexAttr(x); + offsets[3] = rewriter.getIndexAttr(y); + + inputTiles[t][x][y] = rewriter.create(loc, wholeInputTensor, offsets, sizes, strides); + } + } + } +} + +std::optional resolveImgInputTiles(Value wholeInputTensor, + SmallVector>>& inputTiles, + size_t channelTileCount, + size_t channelTileRest, + size_t input_w, + size_t input_h, + ConversionPatternRewriter& rewriter) { + + for (size_t t = 0; t < channelTileCount; t++) { + for (size_t x = 0; x < input_w; x++) { + for (size_t y = 0; y < input_h; y++) { + inputTiles[t][x][y] = + indexImgValue(wholeInputTensor, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter); + } + } + } + + return std::nullopt; +} + +LogicalResult handleFlattenLikeOp(SmallVector>& inputTiles, + const size_t inputTilesCount, + const size_t lastInputTileDimension, + TensorType inputShape, + TensorType outputShape, + Value reshapeInput, + ConversionPatternRewriter& rewriter) { + // Only support reshape between an image and a vector (i.e. 
flatten) + if (inputShape.getRank() != 4 || outputShape.getRank() != 2) { + return rewriter.notifyMatchFailure(reshapeInput.getDefiningOp(), + "resolveVecInputTiles only supports reshapes from 4D to 2D tensors"); + } + + /* + * From a 4D tensor to a 2D tensor + */ + auto N = inputShape.getDimSize(0); + auto C = inputShape.getDimSize(1); + auto H = inputShape.getDimSize(2); + auto W = inputShape.getDimSize(3); + assert(N == 1 && "Only support N = 1 for image tensors"); + + for (size_t i = 0; i < inputTilesCount; i++) { + auto c = (i / (H * W)) % C; + // TODO: Is this correct? Or should I invert h and w? + auto w = (i / H) % W; + auto h = i % H; + + Value curTile = indexImgValue(reshapeInput, w, h, c, inputTilesCount, lastInputTileDimension, W, H, rewriter); + + // Assert the shape of the tile, and reshape it + auto curTileShape = cast(curTile.getType()); + assert(curTileShape.getRank() == 4 && "We just reshaped an image tensor, why rank != 4?"); + assert(curTileShape.getDimSize(0) == 1 && "We just reshaped an image tensor with N = 1, why is it now != 1?"); + assert(curTileShape.getDimSize(2) == 1 && "We should have just looked up a single pixel why W != 1?"); + assert(curTileShape.getDimSize(3) == 1 && "We should have just looked up a single pixel why H != 1?"); + + // Reshape this pixel tensor into a vector, for compatibility with the + // rest + SmallVector newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)}; + auto shapeType = RankedTensorType::get({static_cast(newShapeVals.size())}, rewriter.getI64Type()); + Value shapeTensor = + rewriter.create(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals)); + auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType()); + auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor); + + size_t coreIndex = i / crossbarCountInCore; + inputTiles[coreIndex].push_back(reshapedCurTile); + } + + return 
success(); +} + +std::pair kernel_get_start_and_end( + int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad) { + int64_t firstValid = std::ceil(static_cast(pad) / dilation) * dilation - pad; + int64_t start = std::max(firstValid, out_pos * stride - pad); + int64_t end = std::min(input_width, out_pos * stride + (krn_width - 1) * dilation + 1 - pad); + + assert(start >= 0 && "Start position must be non-negative."); + assert(end >= 0 && "End position must be non-negative."); + return std::make_pair(start, end); +} + +void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment) { + auto oldSegmentSizes = wcomputeOp->getAttrOfType(wcomputeOp.getOperandSegmentSizesAttrName()); + + auto newSegmentSizes = + DenseI32ArrayAttr::get(wcomputeOp->getContext(), {oldSegmentSizes[0], oldSegmentSizes[1] + increment}); + + wcomputeOp->setAttr(wcomputeOp.getOperandSegmentSizesAttrName(), newSegmentSizes); +} + +int getResultIndex(Operation* op, Value v) { + int resultNumber = -1; + for (auto result : op->getResults()) { + if (result == v) { + resultNumber = result.getResultNumber(); + break; + } + } + assert(resultNumber >= 0 && "Value not found in given operation's results."); + + return resultNumber; +} + +}; // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp new file mode 100644 index 0000000..32fd6ca --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp @@ -0,0 +1,262 @@ +#pragma once + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "llvm/Support/LogicalResult.h" + +#define 
DEFINE_MAP_OP(opname) opname, + +#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2) +#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3) +#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1) +#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0) +#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2) +#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3) +#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0) + +using namespace mlir; + +namespace onnx_mlir { + +const StringRef REPLICATION_ATTR_NAME = "replication_factor"; + +using HSliceId = size_t; +using CoreId = size_t; + +enum class MapOperations { + None, + ONNXSoftmaxOp, + ONNXReluOp, + ONNXLeakyReluOp, + ONNXExpOp +}; + +template > +constexpr C ceilIntegerDivide(A a, B b) { + static_assert(std::is_integral_v, "A must be an integer type"); + static_assert(std::is_integral_v, "B must be an integer type"); + C ac = static_cast(a); + C bc = static_cast(b); + return 1 + (ac - 1) / bc; +} + +template > +constexpr std::pair ceilIntegerDivideWithRemainder(A a, B b) { + static_assert(std::is_integral_v, "A must be an integer type"); + static_assert(std::is_integral_v, "B must be an integer type"); + C ac = static_cast(a); + C bc = static_cast(b); + return {ceilIntegerDivide(ac, bc), ac % bc}; +} + +template +bool isVectorShape(const ArrayRef shape) { + return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); +} + +template +bool isMatrixShape(const ArrayRef shape) { + return shape.size() == 2; +} + +template +bool isHVectorShape(const ArrayRef shape) { + return shape.size() == 2 && shape[0] == 1; +} + +template +bool isVVectorShape(const ArrayRef shape) { + return shape.size() == 2 && shape[1] == 1; +} + +template +T getVectorLength(const ArrayRef shape) { + assert(isVectorShape(shape)); + return shape[0] != 1 ? 
shape[0] : shape[1]; +} + +inline auto getTensorShape(const Value tensor) { return cast(tensor.getType()).getShape(); } + +SmallVector sliceTensor( + const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); + +SmallVector +sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); + +DenseMap> +sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc); + +DenseMap>> tileMatrix( + Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc); + +tensor::SplatOp +broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc); + +Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter); + +Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input); + +/** + * Unpacks an optional pair vector into two size_t values. + * + * @param valuesArray The optional `mlir::ArrayAttr` containing the pair of + * values. + * @param value1 The reference to the first `size_t` variable to store the + * unpacked value. + * @param value2 The reference to the second `size_t` variable to store the + * unpacked value. + */ +void unpackOptionalPairVector(std::optional valuesArray, size_t& value1, size_t& value2); + +/** + * Unpacks the optional pads vector. + * + * @param valuesArray The optional array attribute containing the values. + * @param pad_x The output variable to store the value of pad_x. + * @param pad_y The output variable to store the value of pad_y. + * @param rewriter The rewriter to notify failure + * + * @return llvm::Optional The error message if the pads are invalid + */ +std::optional unpackOptionalPadsVector(std::optional valuesArray, size_t& pad_x, size_t& pad_y); + +/** + * Tiles the image tensor by channel. 
+ * + * This function takes an image tensor and tiles it into smaller tiles based on + * the channel dimension. The size of each tile is specified by the tileSize + * parameter. + * + * @param imageTensor The input image tensor (NxCxWxH) to be tiled. + * @param tiles The output tiles vector to store the tiled image tensors. + * @param tileSize The size of each tile. + * @param rewriter The ConversionPatternRewriter used for creating operations. + */ +void tileImageTensorByChannel(Value imageTensor, + SmallVector>>& tiles, + size_t tileSize, + ConversionPatternRewriter& rewriter); + +/** + * Creates an ImgConcatOp based on the given tiles. + * + * This function takes a 3-dimensional vector `outputTiles` representing the + * tiles to concatenate. The tiles are indexed by [tile][x][y]. + * + * @param outputTiles The tiles to concatenate. + * @param rewriter The ConversionPatternRewriter used for creating the + * ImgConcatOp. + * @param loc The location of the operation. + * @param outputType The type of the output tensor. + * + * @return The created ImgConcatOp. + */ +Value createImgConcatOp(SmallVector>>& outputTiles, + ConversionPatternRewriter& rewriter, + Location& loc, + Type outputType); + +/** + * @brief Verifies if the given input coordinates and padding values are within + * the bounds of the input tensor. + * + * @param input_w The width of the input tensor. + * @param input_h The height of the input tensor. + * @param inX The X-coordinate of the input. + * @param inY The Y-coordinate of the input. + * @param pad_x The padding value in the X-direction. + * @param pad_y The padding value in the Y-direction. + * @return LogicalResult Returns success if the coordinates and padding are + * within bounds, failure otherwise. + */ +LogicalResult +verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y); + +/** + * Resolves the tiling of the input tensor into smaller tiles. 
+ * + * This function takes a whole input tensor and tiles it into smaller tiles + * using the provided parameters. The resulting tiles are stored in the + * `inputTiles` vector. + * Input tiles need to be indexed by: + * a. Channel Tile + * b. Pixel `x` position + * c. Pixel `y` position + * For example: inputTiles[channelTile][x][y] + * + * @param wholeInputTensor The whole input tensor to be tiled. + * @param inputTiles A vector of vectors of vectors of Values representing the + * tiles of the input tensor. The outermost vector represents + * the channels, the middle vector represents the rows, and + * the innermost vector represents the columns of the tiles. + * @param channelTileCount The number of tiles for the `channel` axis. + * @param channelTileRest The size of the last channelTile. Set as 0 if tiles + * fit exactly + * @param input_w The width of the input tensor. + * @param input_h The height of the input tensor. + * @param rewriter The ConversionPatternRewriter used for creating operations. + * + * @return std::optional An error message if the input tensor could + * not be resolved into tiles. + */ +std::optional resolveImgInputTiles(Value wholeInputTensor, + SmallVector>>& inputTiles, + size_t channelTileCount, + size_t channelTileRest, + size_t input_w, + size_t input_h, + mlir::ConversionPatternRewriter& rewriter); + +/** + * Computes the boundaries of an image kernel application. + * + * @param out_pos The position of the output element. + * @param input_width The width of the input image. + * @param krn_width The width of the kernel. + * @param stride The stride value. + * @param dilation The dilation value. + * @param pad The padding value. + * @return A pair of size_t values representing the start and end positions of + * the kernel application. 
+ */ +std::pair kernel_get_start_and_end( + int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad); + +/** + * @brief Increment the `operandSegmentSizes` in the WeightedCompute operation + * for the `inputs` operand. + * + * This function increments the size of the `inputs` operand segment in the + * `operandSegmentSizes` of the given WeightedCompute operation by the specified + * increment. This is necessary when new operands are programmatically added to + * the WeightedCompute operation. + * + * @param wcomputeOp The WeightedCompute operation whose `operandSegmentSizes` + * is to be incremented. + * @param increment The value by which to increment the `inputs` operand segment + * size. + */ +void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment); + +/** + * @brief Finds the result index of the given operation that produces the + * specified value. + * + * This function takes an operation and a value, and returns the index of the + * result of the operation that corresponds to the given value. + * + * @param op Operation whose result index is to be found. + * @param v The value for which the result index is to be determined. + * @return The index of the result of the operation that produces the specified + * value. 
+ */ +int getResultIndex(Operation* op, Value v); + +}; // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp new file mode 100644 index 0000000..ed22fbe --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -0,0 +1,131 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_os_ostream.h" + +#include +#include + +#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" +#include "ONNXToSpatialPass.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerOptions.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace spatial { + +void ONNXToSpatialPass::runOnOperation() { + llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n"; + + ModuleOp module = getOperation(); + MLIRContext* ctx = &getContext(); + + RewritePatternSet mergeActivationPatterns(ctx); + mergeActivationPatterns.add(ctx); + mergeActivationPatterns.add(ctx); + mergeActivationPatterns.add(ctx); + mergeActivationPatterns.add(ctx); + mergeActivationPatterns.add(ctx); + mergeActivationPatterns.add(ctx); + + if (failed(applyPatternsAndFoldGreedily(module, std::move(mergeActivationPatterns)))) + llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; + + IRRewriter rewriter(module); + func::FuncOp funcOp = *module.getOps().begin(); + if (annotateReplication(funcOp, rewriter).failed()) { + llvm::dbgs() << "Failed during annotation for replication analysis\n"; + signalPassFailure(); 
+ return; + } + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + if (useExperimentalConvImpl) { + populateExperimentalTilingConvOpPattern(patterns, ctx); + populateExperimentalPoolingTilingPattern(patterns, ctx); + populateGemmToConvConversionPattern(patterns, ctx); + } + else { + populateTilingConvOpPattern(patterns, ctx); + populatePoolingTilingPattern(patterns, ctx); + populateTilingGemmOpPattern(patterns, ctx); + } + + populateONNXConcatToTensorConcatPattern(patterns, ctx); + populateReduceMeanConversionPattern(patterns, ctx); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + + // Count the number of compute ops and check they do not exceed the core count + if (coresCount != -1) { + int computeOpsCount = 0; + for (auto& op : funcOp.getFunctionBody().front().getOperations()) + if (isa(op)) + computeOpsCount++; + + if (computeOpsCount > coresCount) { + llvm::dbgs() << "Number of compute ops exceeds the core count\n"; + signalPassFailure(); + return; + } + } + + // Remove trailing "helper ops" i.e. concat,img_concat,reshape. 
+ RewritePatternSet removeUnusedHelperOpsPatterns(ctx); + populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx); + + if (failed(applyPatternsAndFoldGreedily(module, std::move(removeUnusedHelperOpsPatterns)))) + llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n"; + + annotateWeightsConstants(funcOp); + + // Dump to file for debug + std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); + std::filesystem::create_directory(outputDir); + std::fstream file(outputDir + "/spatial.mlir", std::ios::out); + llvm::raw_os_ostream os(file); + os << *module; + os.flush(); + file.close(); +} + +void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { + MLIRContext* ctx = funcOp.getContext(); + funcOp.walk([&](arith::ConstantOp constantOp) { + bool isAlwaysWeight = + llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); }); + if (isAlwaysWeight) + constantOp->setAttr("weightAlways", UnitAttr::get(ctx)); + }); +} + +} // namespace spatial + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.hpp new file mode 100644 index 0000000..8fc2c82 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include "mlir/Pass/Pass.h" + +#include "src/Dialect/ONNX/ONNXOps.hpp" + +namespace onnx_mlir { + +using namespace mlir; +extern bool haveSameStaticShape(Value lhs, Value rhs); + +namespace spatial { + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" + +struct ONNXToSpatialPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass) + StringRef getArgument() const override { return "convert-onnx-to-spatial"; } + StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; } + + ONNXToSpatialPass() = default; + ONNXToSpatialPass(const 
ONNXToSpatialPass& pass) {} + + void runOnOperation() override; + +private: + void annotateWeightsConstants(func::FuncOp funcOp) const; +}; + +} // namespace spatial + +std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp new file mode 100644 index 0000000..31fd82f --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp @@ -0,0 +1,40 @@ +#pragma once +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + +void populateLoweringONNXMatMulOpToSpatialPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +void populateTilingGemmOpPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateTilingConvOpPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +void populatePoolingTilingPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +void populateDistributeReducePattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +void populateFoldComputePattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +void populateONNXConcatToTensorConcatPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +void populateRemoveUnusedHelperOpsPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +void populateReduceMeanConversionPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +// Experimental patterns. 
+void populateExperimentalTilingConvOpPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateGemmToConvConversionPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateExperimentalPoolingTilingPattern( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp b/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp new file mode 100644 index 0000000..fa93f91 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp @@ -0,0 +1,31 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +struct ONNXConcatToTensorConcat : public OpConversionPattern { + ONNXConcatToTensorConcat(MLIRContext *ctx) : OpConversionPattern(ctx) {} + + LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp, + ONNXConcatOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto inputs = adaptor.getInputs(); + int64_t axis = adaptor.getAxis(); + + rewriter.replaceOpWithNewOp(maxpoolOp, axis, inputs); + + return success(); + } +}; + +void populateONNXConcatToTensorConcatPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp b/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp new file mode 100644 index 0000000..abc87ed --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp @@ -0,0 +1,34 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include 
"src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +template +struct RemoveUnusedHelperOps : public OpRewritePattern { + RemoveUnusedHelperOps(MLIRContext* ctx) + : OpRewritePattern(ctx) {} + + void initialize() { this->setHasBoundedRewriteRecursion(); } + + LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final { + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + + return failure(); + } +}; + +void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp new file mode 100644 index 0000000..c335fd8 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp @@ -0,0 +1,119 @@ +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include + +using namespace mlir; + +namespace onnx_mlir { + +/** + * @brief Structure that describes the replication of a convolution operation, + * along the image height axis. 
 */
struct ConvReplication {
  ONNXConvOp convOp;            // Convolution operation
  size_t input_w;               // Width of the input image
  size_t replicationFactor;     // Replication factor on the image height axis
  size_t coresNeededPerReplica; // Number of cores needed for each replica

  // Max-heap order by rows-per-replica: the conv whose replicas each still
  // cover the most input rows sits on top and is replicated first.
  friend bool operator<(const ConvReplication& a, const ConvReplication& b) {
    return a.input_w / a.replicationFactor < b.input_w / b.replicationFactor;
  }

  ConvReplication(ONNXConvOp convOp, size_t input_w, size_t replicationFactor, size_t coresNeededPerReplica)
      : convOp(convOp),
        input_w(input_w),
        replicationFactor(replicationFactor),
        coresNeededPerReplica(coresNeededPerReplica) {}
};

// Computes, for every convolution in `funcOp`, how often it can be replicated
// along the image height axis given the global core budget (`coresCount`),
// and annotates each conv with REPLICATION_ATTR_NAME. Fails when even one
// instance of every layer does not fit the budget.
LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter) {

  if (coresCount == -1) {
    // No need for annotation, implicitly set replication to 1
    return success();
  }

  std::priority_queue<ConvReplication> convOpsReplicationQueue;

  // Cores required to place exactly one instance of every layer.
  size_t minimumCores = 0;

  for (auto& op : funcOp.getFunctionBody().begin()->getOperations()) {
    if (auto convOp = dyn_cast<ONNXConvOp>(op)) {
      // Convolution layer

      Value X = convOp.getX(), W = convOp.getW();
      ShapedType xShape = mlir::cast<ShapedType>(X.getType());
      ShapedType wShape = mlir::cast<ShapedType>(W.getType());

      size_t input_w = GET_IMAGE_WIDTH(xShape);
      size_t krn_h = GET_KERNEL_HEIGHT(wShape);
      size_t krn_w = GET_KERNEL_WIDTH(wShape);

      // One crossbar tile per `crossbarSize` input channels / output filters;
      // every kernel element needs its own tile set.
      size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
      size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());

      auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
      auto neededCores = ceilIntegerDivide(neededXbars, crossbarCountInCore.getValue());

      minimumCores += neededCores;

      convOpsReplicationQueue.emplace(convOp, input_w, 1, neededCores);
    }
    else if (auto gemmOp = dyn_cast<ONNXGemmOp>(op)) {
      // Fully connected layer
      auto matrixTensorShape = cast<ShapedType>(gemmOp.getB().getType());
      auto inputSize = matrixTensorShape.getDimSize(0);
      auto outputSize = matrixTensorShape.getDimSize(1);
      if (gemmOp.getTransB())
        std::swap(inputSize, outputSize);

      const size_t inputTilesCount = ceilIntegerDivide(inputSize, crossbarSize.getValue());
      const size_t outputTilesCount = ceilIntegerDivide(outputSize, crossbarSize.getValue());

      // Each output tile is computed by `coresPerOutputTile` cores. The
      // entire input is given to each of these cores.
      const size_t coresPerOutputTile = ceilIntegerDivide(inputTilesCount, crossbarCountInCore.getValue());

      auto neededCores = coresPerOutputTile * outputTilesCount;

      // Gemm layers are never replicated; they only consume budget.
      minimumCores += neededCores;
    }
  }

  if (static_cast<size_t>(coresCount) < minimumCores) {
    return funcOp->emitError("Not enough cores for this network: ")
           << minimumCores << " cores needed, but only " << static_cast<size_t>(coresCount) << " available.";
  }

  size_t availableCores = static_cast<size_t>(coresCount) - minimumCores;

  // Consume all the elements in the queue
  while (!convOpsReplicationQueue.empty()) {
    auto convOpReplication = convOpsReplicationQueue.top();
    convOpsReplicationQueue.pop();

    // Check if we can replicate this convolution (e.g. we have enough cores)
    // NOTE(review): the test requires headroom of coresNeededPerReplica *
    // (replicationFactor + 1) while only coresNeededPerReplica cores are
    // consumed per added replica — confirm the growing headroom requirement
    // is an intentional fairness heuristic.
    if (availableCores > convOpReplication.coresNeededPerReplica * (convOpReplication.replicationFactor + 1)) {
      // We can replicate this convolution: increment replicationFactor and put
      // back in queue
      availableCores -= convOpReplication.coresNeededPerReplica;
      convOpReplication.replicationFactor++;

      convOpsReplicationQueue.push(convOpReplication);
    }
    else {
      // Cannot replicate this convolution anymore, annotate the operation
      // with the replication factor
      convOpReplication.convOp->setAttr(REPLICATION_ATTR_NAME,
          rewriter.getI64IntegerAttr(convOpReplication.replicationFactor));
    }
  }

  return success();
}

} // namespace onnx_mlir

#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"

namespace onnx_mlir {

// Annotates every ONNXConvOp in `funcOp` with its height-axis replication
// factor (REPLICATION_ATTR_NAME), honoring the global core budget.
mlir::LogicalResult annotateReplication(
    mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter);

} // namespace onnx_mlir

#include "SpatialReducer.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "llvm/Support/raw_ostream.h"
#include <functional>    // NOTE(review): standard headers reconstructed — confirm original list
#include <unordered_map>
#include <utility>

// Accessors for the (compute op, result index) pairs used throughout.
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)

namespace onnx_mlir {

// Set of compute ops that have already been replaced by a cloned op.
// NOTE(review): element type / inline size reconstructed — confirm against
// the declaration in SpatialReducer.hpp.
llvm::SmallPtrSet<mlir::Operation *, 16>
    onnx_mlir::SpatialReducer::oldComputeOpsReplaced;

/**
 * Applies `processFun` to the value yielded at index `resultNum` of the given
 * compute op and returns the yield-operand index of the processed value.
 * If `processFun` returns its input unchanged the original index is kept;
 * otherwise the processed value is appended as a new yield operand.
 */
ResNum SpatialReducer::applyResultProcessing(
    ComputeAndResNum computeOpAndResNum,
    std::function<Value(Value)> processFun,
    ConversionPatternRewriter &rewriter) {
  assert(processFun);

  auto computeOp = GET_COMP(computeOpAndResNum);
  auto resultNum = GET_RES_NUM(computeOpAndResNum);

  spatial::SpatYieldOp yieldOp =
      cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());

  Value result = yieldOp->getOperand(resultNum);
  rewriter.setInsertionPointAfterValue(result);
  Value processedResult = processFun(result);
  if (processedResult == result) {
    // Sometimes we want processedResult to return the same value but do
    // something else with it (e.g. in softmax we want to broadcast the value
    // using a channel). In this case, we can just return the same value.
    return resultNum;
  }

  // New value: expose it by appending it to the yield's operand list.
  yieldOp->insertOperands(yieldOp->getNumOperands(), processedResult);

  return yieldOp.getNumOperands() - 1;
}

/**
 * Reduces the yielded values identified by `computeOpsAndResNum` into a
 * single value via the binary `reduce` callback, optionally applying
 * `preprocess` to every input and `postprocess` to the final value.
 * Returns the (operation, result index) pair of the reduced value.
 * Result/operand bookkeeping is deliberately deferred to
 * finalizeReduceUpdates() so previously handed-out references stay valid.
 */
OpAndResNum SpatialReducer::applyReducePattern(
    SmallVector<ComputeAndResNum> &computeOpsAndResNum,
    std::function<Value(Value, Value)> reduce,
    std::function<Value(Value)> preprocess,
    std::function<Value(Value)> postprocess) {

  if (preprocess) {
    for (auto &computeOpAndResNum : computeOpsAndResNum) {
      GET_RES_NUM(computeOpAndResNum) =
          applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
    }
  }

  // It is possible that `computeOpsAndResNum` contains two entries for the same
  // computeOp. In this case, we need to apply the reduction within-compute

  // Keep a map between a computeOp and the last Value for this reduction
  std::unordered_map<Operation *, Value> lastValueForCompute;
  for (auto &computeOpAndResNum : computeOpsAndResNum) {
    auto computeOp = GET_COMP(computeOpAndResNum);
    auto yieldOp =
        cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
    Value valueWithinCompute =
        yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));

    auto it = lastValueForCompute.find(computeOp.getOperation());

    if (it != lastValueForCompute.end()) {
      // If we have already seen this computeOp, apply the reduction
      // within-compute
      Value lastWithinComputeValue = it->second;

      assert(valueWithinCompute.getDefiningOp() &&
             lastWithinComputeValue.getDefiningOp());

      // Insert after whichever operand is defined later so that both
      // operands dominate the reduction.
      if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(
              lastWithinComputeValue.getDefiningOp())) {
        rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
      } else {
        rewriter.setInsertionPointAfterValue(valueWithinCompute);
      }
      valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
      // NOTE(review): this store is redundant — the unconditional store
      // right below writes the same value.
      lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
    }

    lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
  }

  // Now, reconstruct from the map the computeOpsAndResNum list
  computeOpsAndResNum.clear();
  computeOpsAndResNum.reserve(lastValueForCompute.size());
  for (auto &entry : lastValueForCompute) {
    auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
    auto valueWithinCompute = entry.second;

    // We check if `valueWithinCompute` is already used by the yieldOp, in that
    // case no need to add it
    auto yieldOp =
        cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
    bool yieldOpUseFound = false;
    for (auto &use : valueWithinCompute.getUses()) {
      if (use.getOwner() == yieldOp.getOperation()) {
        // If the value is already used by the yieldOp, we can just use it
        computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
        yieldOpUseFound = true;
        break;
      }
    }
    if (yieldOpUseFound) {
      continue;
    }

    // If this result is not used within a yieldOp, then add it
    auto resultNum = yieldOp->getNumOperands();
    yieldOp->insertOperands(resultNum, valueWithinCompute);

    computeOpsAndResNum.push_back({computeOp, resultNum});
  }

  Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();

  // Recursive algorithm to reduce the inputs to a single one:
  // - Take two inputs at a time, and reduce them into a single one, updating
  //   the computeOpsAndResNum list which becomes half the size.
  // - Repeat until there is only one input left.
  llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
  while (computeOpsRef.size() > 1) {
    SmallVector<ComputeAndResNum> nextComputeOps;
    nextComputeOps.reserve(computeOpsRef.size() / 2);
    for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
      auto [firstCompute, firstResultNum] = computeOpsRef[i];
      auto [secondCompute, secondResultNum] = computeOpsRef[i + 1];

      // Always reduce into the op that appears later in the block.
      if (secondCompute->isBeforeInBlock(firstCompute)) {
        std::swap(firstCompute, secondCompute);
        std::swap(firstResultNum, secondResultNum);
      }

      // We do not immediately alter the computeOps results/operands, instead we
      // do it in a delayed manner, to avoid invalidating the references to the
      // computeOps (which must be replaced by a cloned ComputeOp when changing
      // the number of results)
      // See below `reducerChanges.push_back` and `finalizeReduceUpdates`

      auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(
          firstCompute.getBody().front().getTerminator());

      // Add a new operand to the block of the second computeOp
      Block &secondBlock = secondCompute.getBody().front();
      Value formerRes1 = secondBlock.addArgument(
          yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);

      // The operand index of the new block argument: weights come first in
      // the operand list, then the inputs (one per block argument).
      auto secondComputeWeightsNum =
          secondCompute->getAttrOfType<DenseI32ArrayAttr>(
              secondCompute.getOperandSegmentSizesAttrName())[0];
      auto secondComputeOperandNum =
          secondComputeWeightsNum + secondBlock.getNumArguments() - 1;

      // Take the "former-result" from the second computeOp
      spatial::SpatYieldOp secondYield =
          cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
      Value formerRes2 = secondYield.getOperand(secondResultNum);

      // Apply reduction operation
      rewriter.setInsertionPoint(secondYield);
      Value reduced = reduce(formerRes2, formerRes1);

      // Unfortunately, it is not possible to update the result in place,
      // because we may have already referenced it by
      // outside of this function, thus replacing it would invalidate the
      // reference. Therefore, we need to append a new result to the yieldOp,
      // and then at a later stage update the computeOp accordingly.

      // Add `reduced` to the second yieldOp
      auto secondYieldOperandNum = secondYield.getNumOperands();
      secondYield->insertOperands(secondYieldOperandNum, reduced);
      secondResultNum = secondYieldOperandNum;

      // We should also add an entry for updating the results of the last
      // operation (the one which never becomes a `firstCompute`): because it is
      // not tracked by reducerChanges as `fromOp`
      reducerChanges.push_back({firstCompute.getOperation(), firstResultNum,
          secondCompute.getOperation(), secondComputeOperandNum});
      nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
    }

    // If we have an odd number of inputs, we need to add the last one to the
    // newInputs list.
    if (computeOpsRef.size() % 2 == 1) {
      nextComputeOps.push_back(computeOpsRef.back());
    }

    // Replace the inputOps list with the new one.
    computeOpsRef =
        llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
  }

  assert(computeOpsRef.size() == 1 &&
         "Internal error: expected a single input at this point.");

  auto finalComputeAndResNum = computeOpsRef[0];

  // Force the update of the results of this computeOp, when finalizing
  computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));

  if (postprocess) {
    GET_RES_NUM(finalComputeAndResNum) =
        applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
  }

  return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(),
      GET_RES_NUM(finalComputeAndResNum));
}

/**
 * Materializes the result/operand updates recorded by applyReducePattern():
 * clones every compute op whose yield gained operands (so its result list
 * matches) and wires the reduced results into their consumer compute ops.
 * Must be called exactly once.
 */
void SpatialReducer::finalizeReduceUpdates() {
  assert(reducesFinalized == false && "Cannot finalize two times.");

  reducesFinalized = true;

  // First, add the results to the computeOps
  for (auto &reduceChange : reducerChanges) {
    updateResultsOfCompute(reduceChange.fromOp);
  }

  for (auto &c : computeOpNeedingResUpdate) {
    updateResultsOfCompute(c.getOperation());
  }

  for (auto &reducerChange : this->reducerChanges) {
    auto fromOp = reducerChange.fromOp;
    auto toOp = reducerChange.toOp;
    auto fromOpResNum = reducerChange.fromOpResNum;
    auto toOpOperandNum = reducerChange.toOpOperandNum;

    auto fromComputeOp = opToReplacedCompute[fromOp];
    assert(fromComputeOp && "fromOp should have been mapped before!");

    // toComputeOp could be the existing pointer, or we have to remap it with
    // `opToReplacedCompute`
    auto toComputeOp = opToReplacedCompute[toOp];
    if (!toComputeOp) {
      toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
    }

    assert(toComputeOp != fromComputeOp &&
           "Oops should have caught this earlier!");

    assert(toComputeOp->getNumOperands() == toOpOperandNum &&
           "toOpOperandNum should be the last operand of toComputeOp, are the "
           "operations in the right order?");

    // Add the new operand to `toComputeOp`
    auto fromResult = fromComputeOp.getResult(fromOpResNum);
    toComputeOp->insertOperands(toOpOperandNum, fromResult);
incrementWeightedComputeInputsSegmentSize(toComputeOp, 1); + } +} + +Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum &opAndResNum) { + assert(reducesFinalized && + "Cannot create resolve values before finalizing the reduce updates."); + + Operation *opToCast; + auto it = opToReplacedCompute.find(opAndResNum.first); + if (it != opToReplacedCompute.end()) { + opToCast = it->second; + } else { + opToCast = opAndResNum.first; + } + + auto computeOp = cast(opToCast); + + return computeOp.getResult(opAndResNum.second); +} + +void SpatialReducer::updateResultsOfCompute(Operation *computeOp) { + if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) { + // If we have already replaced the fromOp, we do not need to do it again + return; + } + auto oldComputeOp = cast(computeOp); + + auto oldComputeOpNum = oldComputeOp->getNumOperands(); + + auto yieldOp = + cast(oldComputeOp.getBody().front().getTerminator()); + + if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) { + // No result was added, just add itself to the map + opToReplacedCompute[oldComputeOp.getOperation()] = oldComputeOp; + return; + } + + // Add the results by inspecting its YieldOp + auto newResultTypes = yieldOp.getOperandTypes(); + + // Create a new ComputeOp with the new result type, but same operands + rewriter.setInsertionPoint(oldComputeOp); + auto newComputeOp = + rewriter.create(oldComputeOp->getLoc(), + newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs()); + + newComputeOp.getBody().takeBody(oldComputeOp.getBody()); + + auto newComputeOpNum = newComputeOp->getNumOperands(); + + assert(oldComputeOpNum == newComputeOpNum); + + // Since we replaced the old ComputeOp with a new one, we need to replace + // all its results' uses + for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) { + Value oldResult = oldComputeOp.getResult(i); + Value newResult = newComputeOp.getResult(i); + + // Replace the uses, except the uses of the compute ops which got 
deleted + // previously + rewriter.replaceAllUsesExcept(oldResult, newResult, oldComputeOpsReplaced); + } + + // Finally, erase the old computeOp and update the map + opToReplacedCompute[oldComputeOp.getOperation()] = newComputeOp; + oldComputeOpsReplaced.insert(oldComputeOp.getOperation()); + rewriter.setInsertionPoint(oldComputeOp); + rewriter.eraseOp(oldComputeOp); +} + +Value SpatialReducer::createImgConcatOp( + SmallVector>> &outputTiles, + Location &loc, Type outputType) { + + assert(reducesFinalized && + "Cannot create ImgConcatOp before finalizing the reduce updates."); + + // outputTiles are indexed like this: [channelTile][x][y] + auto tilesCount = outputTiles.size(); + auto width = outputTiles[0].size(); + auto height = outputTiles[0][0].size(); + + SmallVector>> remappedOutputTiles(tilesCount, + SmallVector>(width, SmallVector(height))); + + for (size_t t = 0; t < tilesCount; t++) + for (size_t x = 0; x < width; x++) + for (size_t y = 0; y < height; y++) + remappedOutputTiles[t][x][y] = + resolveValueFromOpAndResNum(outputTiles[t][x][y]); + + return ::onnx_mlir::createImgConcatOp( + remappedOutputTiles, rewriter, loc, outputType); +} + +OpAndResNum SpatialReducer::applyAddMapReduction( + SmallVector &computeOps, + ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp) { + + std::function postprocessing = nullptr; + + if (mapOp != MapOperations::None) { + postprocessing = [&](const Value a) { + Value mapOperand = a; + if (biasTile) { + mapOperand = rewriter.create( + a.getLoc(), a.getType(), a, biasTile); + } + return createMapOperation(rewriter, mapOp, mapOperand); + }; + } + + return this->applyReducePattern( + computeOps, + [&](Value a, Value b) { + return rewriter.create(a.getLoc(), a.getType(), a, b); + }, + /* preprocess = */ nullptr, postprocessing); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp 
b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp new file mode 100644 index 0000000..4dcfe67 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" + +namespace onnx_mlir { + +using ResNum = unsigned int; + +using ComputeAndResNum = std::pair; + +struct SpatialReducerChange { + Operation *fromOp; + unsigned int fromOpResNum; + Operation *toOp; + unsigned int toOpOperandNum; +}; + +using OpAndResNum = std::pair; + +class SpatialReducer { + +public: + SpatialReducer(ConversionPatternRewriter &rewriter) : rewriter(rewriter) {} + + OpAndResNum applyReducePattern( + SmallVector &computeOpsAndResNum, + std::function reduce, + std::function preprocess, + std::function postprocess); + + OpAndResNum applyAddMapReduction(SmallVector &computeOps, + ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp); + + void finalizeReduceUpdates(); + + ~SpatialReducer() { + if (!reducesFinalized) { + finalizeReduceUpdates(); + } + } + + Value createImgConcatOp( + llvm::SmallVector>> + &outputTiles, + Location &loc, Type outputType); + + Value resolveValueFromOpAndResNum(OpAndResNum &opAndResNum); + +private: + [[nodiscard("computeOp result number gets updated")]] ResNum + applyResultProcessing(ComputeAndResNum computeOpAndResNum, + std::function processFun, + ConversionPatternRewriter &rewriter); + + /** + * @brief Update the results of a ComputeOp. + * + * This function updates the results of a ComputeOp by taking a look at the + operands of its yieldOp. + * If the ComputeOp was replaced, it updates `opToReplacedCompute` with the + replaced ComputeOp. + * + * @param computeOp The ComputeOp to update the results of. 
+ */ + void updateResultsOfCompute(Operation *computeOp); + + ConversionPatternRewriter &rewriter; + bool reducesFinalized = false; + + // List of changes to be applied after the reduction is finalized + SmallVector reducerChanges; + // List of computeOps that need to be replaced with new results + SmallVector computeOpNeedingResUpdate; + + std::unordered_map opToReplacedCompute; + + static llvm::SmallPtrSet oldComputeOpsReplaced; +}; + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp new file mode 100644 index 0000000..81314ac --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp @@ -0,0 +1,53 @@ +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" +#include + +namespace onnx_mlir { + +WeightSubdivider::WeightSubdivider( + map>> weights) + : weights(std::move(weights)) {} + +bool WeightSubdivider::isEmpty() const { return weights.empty(); } + +TaggedWeights WeightSubdivider::popGroup(size_t amount) { + assert(!weights.empty() && "No weights to extract."); + + auto it = weights.begin(); + SmallVector &values = it->second.begin()->second; + + long inputTile = it->first; + long outputTile = it->second.begin()->first; + + size_t n = std::min(amount, values.size()); + crossbarsUsed += n; + + SmallVector result; + result.assign(values.begin(), values.begin() + n); + + if (n < values.size()) { + values.erase(values.begin(), values.begin() + n); + } else { + it->second.erase(outputTile); + if (it->second.empty()) { + weights.erase(inputTile); + } + } + + return {inputTile, outputTile, crossbarsUsed - n, result}; +} + +SmallVector WeightSubdivider::popGroups(size_t n) { + crossbarsUsed = 0; + SmallVector result; + size_t remaining = n; + + while (remaining > 0 && !weights.empty()) { + auto group = popGroup(remaining); + result.push_back(group); + remaining -= 
group.weights.size(); + } + + return result; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp new file mode 100644 index 0000000..6e5c5f1 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" +#include + +using namespace mlir; +using namespace std; + +namespace onnx_mlir { + +/** + * @brief A helper struct to store a group of weights. + * + */ +struct TaggedWeights { + long inputTile; + long outputTile; + size_t startingCrossbarIndex; + SmallVector weights; +}; + +/** + * @brief A helper class to subdivide weights into groups. + * + * Weights are stored as a map of maps of SmallVectors. The outer map is indexed + * by input tile, the inner map is indexed by output tile, and the SmallVector + * contains the weights for the filter. This class allows us to extract groups + * of weights from the map until we've extracted a certain number of elements, + * namely as many as we need to fill a compute unit. 
+ */ +class WeightSubdivider { +private: + map>> weights; + size_t crossbarsUsed = 0; + + TaggedWeights popGroup(size_t amount); + +public: + WeightSubdivider(map>> weights); + + bool isEmpty() const; + SmallVector popGroups(size_t n); +}; + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt b/src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt new file mode 100644 index 0000000..1e849a4 --- /dev/null +++ b/src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt @@ -0,0 +1,14 @@ +add_onnx_mlir_rewriter(SpatialToGraphviz) + +add_onnx_mlir_library(OMSpatialToGraphviz + SpatialToGraphviz.cpp + + LINK_LIBS PUBLIC + OMCompilerOptions + OMPIMCommon + OMONNXOps + SpatialOps + + ACCEL_INCLUDE_DIRS PRIVATE + ${PIM_INCLUDE_PATH} +) diff --git a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp new file mode 100644 index 0000000..bdda694 --- /dev/null +++ b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp @@ -0,0 +1,283 @@ +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/Format.h" + +#define FORMAT_OPERATION(op) \ + 'x' << llvm::format_hex_no_prefix(reinterpret_cast(op), 0) +#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) \ + llvm::format("Arg_%p_%u", computeOpPointer, argumentNum) + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +struct SpatialToGraphvizPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass) + + StringRef 
getArgument() const override { + return "convert-spatial-to-graphviz"; + } + + StringRef getDescription() const override { + return "Lower ONNX ops to Spatial ops."; + } + + SpatialToGraphvizPass(raw_ostream &os = llvm::errs()) : os(os) {} + SpatialToGraphvizPass(const SpatialToGraphvizPass &pass) + : SpatialToGraphvizPass(pass.os) {} + void runOnOperation() final; + +private: + raw_ostream &os; + + /** + * Draws the subgraph for a given spatial::SpatWeightedCompute, including: + * 1. Input nodes (block arguments) + * 2. Operations + * 3. Edges between yield (output) and its users + * + * @param op The spatial::SpatWeightedCompute to draw the subgraph for. + * @param computeNum The number of the compute operation. + */ + void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) { + os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" + << computeNum << "\";\n" + << "\t\tstyle=filled;\n" + << "\t\tcolor=lightblue;\n"; + + Block &block = op.getBody().front(); + + // Inputs + size_t inputNum = 0; + for (BlockArgument &input : block.getArguments()) { + + auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum); + + os << "\t\t" << fromOp << " [label=\"Arg" << inputNum + << "\",shape=box];\n"; + for (auto userOp : input.getUsers()) { + os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n"; + } + inputNum++; + } + + // Iterate operations + for (auto &childOp : block.getOperations()) { + os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" + << childOp.getName() << "\"];\n"; + + drawEdgesFromOpToItsUsers(&childOp); + } + + os << "\t}\n"; + + // Draw edges from the yield to the users of this computeOp + Operation *yieldOp = block.getTerminator(); + if (!isa(yieldOp)) { + yieldOp->emitError("Terminator of block must be YieldOp ???"); + signalPassFailure(); + return; + } + + for (auto computeOpResult : op->getResults()) { + for (auto &computeOpUse : computeOpResult.getUses()) { + auto toOp = FORMAT_ARGUMENT( + 
computeOpUse.getOwner(), computeOpUse.getOperandNumber()); + os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n"; + } + } + } + + /** + * @brief Draws the subgraph for a concatOp. + * + * This function draws a subgraph for a concatOp. The subgraph consists of a + * node for each input of the concatOp, as well as an output node. Edges are + * created from the output node to each user of the concatOp. + * + * @param concatOp The concatOp for which the subgraph is drawn. + * @param concatOpNum The number of the concatOp. + */ + void drawConcatOpSubgraph(Operation *concatOp, size_t concatOpNum) { + os << "\tsubgraph clusterconcat" << concatOpNum + << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n" + << "\t\tstyle=filled;\n" + << "\t\tcolor=orange;\n"; + + // Inputs + size_t inputNum = 0; + for (Value input : concatOp->getOperands()) { + auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum); + + os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n"; + for (auto userOp : input.getUsers()) { + os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n"; + } + inputNum++; + } + + // Output + os << "\t\t" << FORMAT_OPERATION(concatOp) << " [label=Out];\n"; + + os << "\t}\n"; + + // Edges from output to users + + for (auto &computeOpUse : concatOp->getResult(0).getUses()) { + os << "\t" << FORMAT_OPERATION(concatOp) << " -> " + << FORMAT_ARGUMENT( + computeOpUse.getOwner(), computeOpUse.getOperandNumber()) + << ";\n"; + } + } + + /** + * Draws the ExtractSliceOp in the graph visualization. + * + * This function takes a tensor::ExtractSliceOp and adds the corresponding + * node and edges to the graph visualization. It creates a node with the + * label as the static offsets attribute of the sliceOp, and connects it to + * the compute operations that use the result of the sliceOp. + * + * @param sliceOp The tensor::ExtractSliceOp to be drawn in the graph + * visualization. 
+ */ + void drawExtractSliceOp(tensor::ExtractSliceOp sliceOp) { + auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0); + os << "\t" << nodeId << " [label=\"Slice: "; + sliceOp.getStaticOffsetsAttr().print(os); + os << "\",color=lawngreen];\n"; + + for (auto &computeOpUse : sliceOp.getResult().getUses()) { + os << "\t" << nodeId << " -> " + << FORMAT_ARGUMENT( + computeOpUse.getOwner(), computeOpUse.getOperandNumber()) + << ";\n"; + } + } + + void drawBiasTileOp(tensor::ExtractSliceOp sliceOp) { + auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0); + os << "\t" << nodeId << " [label=\"Bias: "; + sliceOp.getStaticOffsetsAttr().print(os); + os << "\",color=lightpink];\n"; + + for (auto user : sliceOp.getResult().getUsers()) { + os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n"; + } + } + + /** + * Draws edges from the given operation to its users. + * + * @param fromOp The operation from which the edges are drawn. + */ + void drawEdgesFromOpToItsUsers(mlir::Operation *fromOp) { + for (auto result : fromOp->getResults()) { + for (auto userOp : result.getUsers()) { + os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " + << FORMAT_OPERATION(userOp) << ";\n"; + } + } + } + + /** + * Draws input node and edges for the given `funcOp`. + * + * @param funcOp The `funcOp` for which to draw input nodes and edges. 
+ */ + void drawInputNodesAndEdges(func::FuncOp &funcOp) { + os << "\tinput [label=\"Module Input\",color=green];\n"; + + size_t funcOpArgNum = 0; + for (BlockArgument &arg : funcOp.getArguments()) { + + for (auto &useOp : arg.getUses()) { + os << "\tinput -> " + << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) + << "[label=" << funcOpArgNum << "];\n"; + } + funcOpArgNum++; + } + } +}; + +void SpatialToGraphvizPass::runOnOperation() { + ModuleOp module = getOperation(); + + // Get the first OP, must be a FuncOp + func::FuncOp func = *module.getOps().begin(); + if (!func) { + module->emitError("No FuncOp found in the begin of module"); + signalPassFailure(); + } + + os << "digraph G {\n" + << "\tnode [style=filled,color=white];\n"; + + size_t computeNum = 0; + size_t concatNum = 0; + + // Iterate over the ComputeOps within FuncOp: + // 1. Print their subgraph + // 2. Print the edges from its inputs to its outputs + for (Operation &op : func.getOps()) { + if (auto computeOp = dyn_cast(op)) { + drawComputeOpSubgraph(computeOp, computeNum++); + } else if (auto concatOp = dyn_cast(op)) { + drawConcatOpSubgraph(concatOp, concatNum++); + } else if (auto imgConcatOp = dyn_cast(op)) { + drawConcatOpSubgraph(imgConcatOp, concatNum++); + } else if (auto extractSliceOp = dyn_cast(op)) { + auto producerOp = extractSliceOp->getOperand(0).getDefiningOp(); + if (producerOp) { + // Skip extractSliceOp if producer is constant weights (ONNXConstantOp) + if (llvm::isa(producerOp)) { + continue; + } + // If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect + // directly to its user, which is not a ComputeOp argument. 
+ if (llvm::isa(producerOp)) { + drawBiasTileOp(extractSliceOp); + continue; + } + } + + drawExtractSliceOp(extractSliceOp); + } + } + + // Draw input node, and edges to it users + drawInputNodesAndEdges(func); + + // Draw output node (use the return Operation - argument number=0 - as nodeId) + auto returnOp = func.getBody().front().getTerminator(); + os << '\t' << FORMAT_ARGUMENT(returnOp, 0) + << " [label=\"Module Output\",color=green];\n"; + + os << "}\n"; +} + +} // namespace + +std::unique_ptr createSpatialToGraphvizPass() { + return std::make_unique(); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt b/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt new file mode 100644 index 0000000..1e66336 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td) +mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") +add_public_tablegen_target(SpatialToPIMIncGen) + +add_onnx_mlir_library(OMSpatialToPIM + SpatialToPIMPass.hpp + SpatialToPIMPass.cpp + SpatialToPIMCommon.cpp + + DEPENDS + SpatialToPIMIncGen + + LINK_LIBS PUBLIC + OMCompilerOptions + OMPIMCommon + SpatialOps + PimOps + + ACCEL_INCLUDE_DIRS PRIVATE + ${PIM_INCLUDE_PATH} +) diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td b/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td new file mode 100644 index 0000000..d6e2e1a --- /dev/null +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td @@ -0,0 +1,28 @@ +#ifndef SPATIAL_TO_PIM +#define SPATIAL_TO_PIM + +#ifndef OP_BASE +include "mlir/IR/PatternBase.td" +include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" +include "src/Accelerators/PIM/Dialect/PIM/Pim.td" +#endif // OP_BASE + +def spatToPimVMMOp : Pat< + (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector), + (PimVMMOp $weightIndex, $vector, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, 
$0.getDefiningOp())"> $srcOpRes)) +>; + +def spatToPimMVMOp : Pat< + (SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector), + (PimMVMOp $weightIndex, $vector, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + +def spatToPimVAddOp : Pat< + (SpatVAddOp:$srcOpRes $a, $b), + (PimVAddOp $a, $b, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + +#endif // SPATIAL_TO_PIM \ No newline at end of file diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp new file mode 100644 index 0000000..6d2e319 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp @@ -0,0 +1,97 @@ +#include "mlir/IR/ValueRange.h" + +#include "llvm/ADT/STLExtras.h" + +#include +#include + +#include "SpatialToPIMCommon.hpp" + +using namespace llvm; + +namespace onnx_mlir { + +size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) { + /* + EXAMPLE RUN: + [1, 10, 3, 4] inputShape + [0, 2, 1, 3] offsets + + acc = 1 + --- + ret = 3 + acc = 4 + --- + ret = 3 + 4 * 1 = 7 + acc = 12 + --- + ret = 7 + 12 * 2 = 31 + acc = 120 + --- + ret = 31 + 120 * 0 = 31 + acc = 120 + */ + + size_t returnValue = 0; + + auto sliceOffsets = sliceOp.getStaticOffsets(); + auto inputDimSizes = inputShape.getShape(); + + assert(sliceOffsets.size() == inputDimSizes.size()); + + size_t accumulatedDimensionSize = 1; + + // Reverse iterate the two vectors + for (auto it : reverse(zip(sliceOffsets, inputDimSizes))) { + auto curSliceOffset = std::get<0>(it); + auto curInputDimSize = std::get<1>(it); + + returnValue += accumulatedDimensionSize * curSliceOffset; + accumulatedDimensionSize *= curInputDimSize; + } + + return returnValue; +} + +Operation* getEarliestUserWithinBlock(Value value) { + auto users = value.getUsers(); + + assert(!users.empty()); + + Operation* earliestUser = 
*users.begin(); + for (auto curUser : users) + if (curUser->isBeforeInBlock(earliestUser)) + earliestUser = curUser; + + return earliestUser; +} + +SmallVector getOpOperandsSortedByUses(Operation* operation) { + auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair { + return {operand, std::distance(operand.use_begin(), operand.use_end())}; + }); + sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; }); + return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; }); +} + +Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) { + assert("Only support operations with a single result" && operation->getNumResults() == 1); + Value result = operation->getResult(0); + auto resultType = result.getType(); + assert("Only support result ShapedType as result type" && isa(resultType)); + + SmallVector operands = getOpOperandsSortedByUses(operation); + auto validOperands = + make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; }); + auto bestOperand = validOperands.begin(); + + if (bestOperand != validOperands.end()) + return *bestOperand; + + auto resultShapedType = cast(resultType); + rewriter.setInsertionPoint(operation); + return rewriter.create( + operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp new file mode 100644 index 0000000..fcca029 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp @@ -0,0 +1,108 @@ +#pragma once + +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +/** + * \brief Get the offset of the ExtractSliceOp based on its static offsets and + * its static 
tensor input. + * + * The static offsets represent the starting position of the slice in each + * dimension, while the static tensor input gives its dimension size. + * + * \param sliceOp The ExtractSliceOp for which the actual offset needs to be + * calculated. + * \param inputShape The ShapedType of the ExtractSliceOp's input tensor + * \return The actual offset of the ExtractSliceOp. + */ +size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape); + +template +size_t rangeLength(const iterator_range range) { + return std::distance(range.begin(), range.end()); +} + +/** + * Retrieves the earliest operation that uses the given value within the value's + * block. + * + * @param value The value for which to find the earliest user operation. + * @return The earliest user operation that uses the given value within the + * current block. + */ +Operation* getEarliestUserWithinBlock(Value value); + +SmallVector getOpOperandsSortedByUses(Operation* operation); + +Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation); + +static bool isMemoryContiguous(const ArrayRef srcShape, + const ArrayRef offsets, + const ArrayRef sizes, + const ArrayRef strides) { + // Check that all strides are 1 + if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; })) + return false; + + // Check offsets from right to left: + // The first offset_n at position n different from 0: + // - limits all sizes to the left to 1 + // - limits size_n to dimension_n - offset_n + auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()), + llvm::make_range(sizes.rbegin(), sizes.rend()), + llvm::make_range(srcShape.rbegin(), srcShape.rend())); + + auto firstNonZeroOffset = std::find_if( + offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool { + auto [offset, _size, _dimension] = offsetAndSizeAndShape; + return 
offset != 0; + }); + + if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) { + auto [offset, size, dimension] = *firstNonZeroOffset; + if (size > dimension - offset) + return false; + ++firstNonZeroOffset; + + if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool { + auto [_offset, size, _dimension] = offsetAndSizeAndShape; + return size != 1; + })) + return false; + } + + // Check sizes from right to left: + // The first size_n at position n different from shape_n limits all sizes to the left to 1 + auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()), + llvm::make_range(srcShape.rbegin(), srcShape.rend())); + + auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool { + auto [size, dimension] = sizeAndShape; + return size != dimension; + }); + + if (firstDifferentSize != sizesAndShape.end()) { + ++firstDifferentSize; + + if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool { + auto [size, _] = sizeAndShape; + return size != 1; + })) + return false; + } + + return true; +} + +inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) { + return rewriter.create(loc, shapedType.getShape(), shapedType.getElementType()); +} + +inline bool isAConcatOp(Operation* op) { return isa(op) || isa(op); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp new file mode 100644 index 0000000..0d1b8fe --- /dev/null +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp @@ -0,0 +1,491 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include 
"mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_os_ostream.h" + +#include +#include +#include +#include +#include + +#include "SpatialToPIMPass.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace pim { + +void SpatialToPIMPass::runOnOperation() { + coreId = 0; + ModuleOp moduleOp = getOperation(); + MLIRContext* ctx = moduleOp.getContext(); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + + func::FuncOp funcOp = *moduleOp.getOps().begin(); + if (!funcOp) + llvm_unreachable("No FuncOp found in the begin of module"); + + IRRewriter rewriter(&getContext()); + auto returnOp = cast(funcOp.front().getTerminator()); + + addResultBuffer(returnOp, rewriter); + allocateAndInitializeCoreLocalVariables(funcOp, rewriter); + + for (auto receiveOp : funcOp.getOps()) { + operationsToRemove.push_back(receiveOp); + runOnReceiveOp(receiveOp, rewriter); + } + for (auto computeOp : funcOp.getOps()) { + operationsToRemove.push_back(computeOp); + runOnComputeOp(computeOp, rewriter); + } + + enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); + replaceReturnOpOperands(returnOp, rewriter); + + // Remove all ComputeOps + for (auto opToRemove : llvm::reverse(operationsToRemove)) { + if (!opToRemove->use_empty()) { + opToRemove->dump(); + for (auto user : opToRemove->getUsers()) + user->dump(); + assert(false && "opToRemove should be unused at this point"); + } + rewriter.eraseOp(opToRemove); + } + + // Dump to file for debug + std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); + std::filesystem::create_directory(outputDir); + std::fstream file(outputDir + "/pim.mlir", std::ios::out); + 
llvm::raw_os_ostream os(file); + os << *moduleOp; + os.flush(); + file.close(); +} + +void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { + Location loc = computeOp->getLoc(); + + auto& block = computeOp.getRegion().front(); + auto yieldOp = cast(block.getTerminator()); + + if (computeOp.getNumResults() != yieldOp.getNumOperands()) + llvm_unreachable("ComputeOp must have same number of results as yieldOp operands"); + + for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) { + // If this result has no uses, then just skip it + if (result.use_empty()) + continue; + + auto yieldType = cast(yieldValue.getType()); + + /* + * Here we assume that ReturnOp are only reachable by the following patterns: + * + * 1) + * %0 = spat.compute([...]) + * [%0 has one user, which is a ConcatOp] + * %1 = tensor.concat(%0) + * [%1 has one user, which is a ReturnOp] + * return %1 + * + * 2) + * %0 = spat.compute([...]) + * [%0 has one user, which is a ReturnOp] + * return %0 + * + * If the IR is like 2), then we can store the tensor to the output global memory location + */ + auto resultUses = result.getUses(); + auto numResultUses = rangeLength(resultUses); + if (numResultUses == 1) { + OpOperand& resultUse = *resultUses.begin(); + Operation* resultUser = resultUse.getOwner(); + + if (isa(resultUser)) { + size_t resultIndexInReturn = resultUse.getOperandNumber(); + size_t offset = 0; + size_t numElements = yieldType.getNumElements(); + size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8; + + // Store to global memory + Value outputTensor = outputTensors[resultIndexInReturn]; + rewriter.setInsertionPointAfterValue(yieldValue); + rewriter.create(loc, + outputTensor.getType(), + outputTensor, + yieldValue, + rewriter.getI32IntegerAttr(offset), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(numElements * elementSize)); + continue; + } + + if (isa(resultUser) || 
isa(resultUser)) { + auto concatOp = resultUser; + auto concatValue = concatOp->getResult(0); + auto concatUses = concatValue.getUses(); + auto numConcatUses = rangeLength(concatUses); + if (numConcatUses == 1) { + OpOperand& concatUse = *concatUses.begin(); + Operation* concatUser = concatUse.getOwner(); + if (isa(concatUser)) { + size_t concatIndexInReturn = concatUse.getOperandNumber(); + size_t resultIndexInConcat = resultUses.begin()->getOperandNumber(); + size_t offset = 0; + for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat)) + offset += cast(operand.getType()).getNumElements() * cast(operand.getType()).getElementTypeBitWidth() / 8; + + size_t elementSize = yieldType.getElementTypeBitWidth() / 8; + + // Store to global memory + Value outputTensor = outputTensors[concatIndexInReturn]; + rewriter.setInsertionPointAfterValue(yieldValue); + rewriter.create( + loc, + outputTensor.getType(), + outputTensor, + yieldValue, + rewriter.getI32IntegerAttr(offset), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize)); + continue; + } + } + } + } + + // If this pattern was not found, then create a channel and send the value + + // 1. Create a new ChannelOp + rewriter.setInsertionPoint(computeOp); + auto channelType = spatial::SpatChannelType::get(computeOp.getContext()); + auto channelOp = rewriter.create(loc, channelType); + + // 2. Receive value through the channel + // If this result is used by more than one user, then use a "Broadcast" + // channel operation. However, there is a special case: we have a single + // user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this + // case, we need to use a "Broadcast" channel operation. `addReceiveOps` + // will detect this case and update `useBroadcastOp` accordingly. + bool useBroadcastOp = (numResultUses > 1); + addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter); + + // 3. 
Send the value through the channel + rewriter.setInsertionPointAfterValue(yieldValue); + if (useBroadcastOp) + rewriter.create(loc, channelOp, yieldValue); + else + rewriter.create(loc, channelOp, yieldValue); + } + + // Use `HaltOp` instead of `YieldOp` + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp); + + // Replace `spat.compute` with `pim.core` + rewriter.setInsertionPointAfter(computeOp); + auto coreOp = rewriter.create(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++)); + auto& coreOpBlocks = coreOp.getBody().getBlocks(); + block.eraseArguments(0, block.getNumArguments()); + coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); + Block* tempComputeBlock = new Block(); + computeOp.getBody().push_back(tempComputeBlock); + rewriter.setInsertionPointToEnd(tempComputeBlock); + rewriter.create(computeOp.getLoc()); +} + +void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { + auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void { + auto* definingOp = value.getDefiningOp(); + if (!definingOp) + return; + auto dpsDefiningOp = dyn_cast(definingOp); + if (!dpsDefiningOp) + return; + auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast(value)); + if (!tiedOperand) + return; + Value tiedValue = tiedOperand->get(); + assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use"); + tiedValue.setType(newType); + self(tiedValue, newType, self); + }; + + funcOp.walk([&](PimVMMOp vmmOp) { + auto outTensorOperand = vmmOp.getOutBuf(); + auto resultTensor = vmmOp.getOutRes(); + auto outShape = getTensorShape(outTensorOperand); + assert(isHVectorShape(outShape)); + if (outShape[1] != static_cast(crossbarSize)) { + auto newShape = SmallVector {outShape[0], static_cast(crossbarSize)}; + auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType()); + 
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain); + outTensorOperand.setType(newType); + resultTensor.setType(newType); + + IntegerAttr zeroAttr = rewriter.getIndexAttr(0); + IntegerAttr oneAttr = rewriter.getIndexAttr(1); + IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]); + IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]); + SmallVector offsets = {zeroAttr, zeroAttr}; + SmallVector sizes = {oldShapeZeroAttr, oldShapeOneAttr}; + SmallVector strides = {oneAttr, oneAttr}; + rewriter.setInsertionPointAfter(vmmOp); + auto sliceOp = rewriter.create(vmmOp.getLoc(), resultTensor, offsets, sizes, strides); + SmallPtrSet exceptions = {vmmOp, sliceOp}; + resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions); + } + }); +} + +void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { + outputTensors.reserve(returnOp->getNumOperands()); + rewriter.setInsertionPointToStart(returnOp->getBlock()); + for (auto returnValue : returnOp->getOperands()) { + auto newOutputTensor = + createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast(returnValue.getType())); + outputTensors.push_back(newOutputTensor); + } +} + +void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { + Location loc = funcOp.getLoc(); + + auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { + auto tensorType = cast(valueToReplace.getType()); + Type elementType = tensorType.getElementType(); + size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; + rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); + + auto deviceTensor = rewriter.create(loc, tensorType.getShape(), elementType); + + auto memCopyHostToDevOp = rewriter.create( + loc, + tensorType, + deviceTensor, + hostTensor, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(static_cast(elementsOffset * elementByteSize)), + 
rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize))); + + rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult()); + }; + + // Replace input tensors with memRefs + SmallVector inputTensors; + for (size_t i = 0; i < funcOp.getNumArguments(); i++) { + BlockArgument tensorArg = funcOp.getArgument(i); + DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i); + ShapedType tensorArgType = cast(tensorArg.getType()); + MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType()); + + funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc); + BlockArgument memRefArg = funcOp.getArgument(i + 1); + + Block& block = funcOp.getBody().front(); + rewriter.setInsertionPoint(&block.front()); + auto toTensorOp = rewriter.create(loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); + inputTensors.push_back(toTensorOp); + + tensorArg.replaceAllUsesWith(toTensorOp); + funcOp.eraseArgument(i); + } + + llvm::SmallSet sliceOpsToRemove; + for (auto& op : funcOp.getBody().getOps()) + if (auto computeOp = dyn_cast(op)) { + unsigned numComputeWeights = computeOp.getWeights().size(); + for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) { + TypedValue tensorSource; + int64_t elementsOffset = 0; + + if (auto sliceOp = dyn_cast(computeOpInput.getDefiningOp())) { + tensorSource = cast>(sliceOp.getSource()); + + ArrayRef sourceShape = tensorSource.getType().getShape(); + ArrayRef sliceOffsets = sliceOp.getStaticOffsets(); + ArrayRef sliceSizes = sliceOp.getStaticSizes(); + ArrayRef sliceStrides = sliceOp.getStaticStrides(); + assert("Extracting slice non-contiguous in memory" + && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides)); + + for (size_t i = 0; i < sliceOffsets.size(); i++) { + int64_t partialOffset = sliceOffsets[i]; + if (partialOffset != 0) + for (size_t j = i + 1; j < sourceShape.size(); j++) + partialOffset *= 
sourceShape[j]; + elementsOffset += partialOffset; + } + + computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource); + sliceOpsToRemove.insert(sliceOp); + } + else + tensorSource = cast>(computeOpInput); + + // Compute results must be transferred through channels via send/receive + if (isa(tensorSource.getDefiningOp())) + continue; + + BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx); + insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset); + } + } + + for (auto sliceOp : sliceOpsToRemove) + if (sliceOp->getUses().empty()) + rewriter.eraseOp(sliceOp); +} + +void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, + unsigned int argIndex, + spatial::SpatChannelNewOp& channel, + Type& tensorType, + bool useBroadcastOp, + IRRewriter& rewriter) { + auto& computeBlock = computeOp.getRegion().front(); + //(remember that WeightedCompute have weights as first operands, however these + // weights are not included in the block arguments. 
Thus, when indexing the + // block argument we need to remove the weights count) + auto computeWeightsCount = computeOp.getWeights().size(); + auto blockArg = computeBlock.getArgument(argIndex - computeWeightsCount); + // Receive the tensor just before the first use of the value + rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); + Value receivedValue; + if (useBroadcastOp) + receivedValue = rewriter.create(computeOp.getLoc(), tensorType, channel); + else + receivedValue = rewriter.create(computeOp.getLoc(), tensorType, channel); + + blockArg.replaceAllUsesWith(receivedValue); +} + +void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp, + spatial::SpatChannelNewOp& channel, + Type& channelTensorType, + bool& useBroadcastOp, + IRRewriter& rewriter) { + auto sourceOpUses = channelSourceOp.getUses(); + + // Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users + if (useBroadcastOp == false) { + // if useBroadcastOp is false, then sourceOp must have only one user + assert(rangeLength(sourceOpUses) == 1); + + if (auto reshapeOp = dyn_cast(sourceOpUses.begin()->getOwner())) { + auto reshapeOpUses = reshapeOp.getOutput().getUses(); + auto reshapeOpUsesCount = rangeLength(reshapeOpUses); + if (reshapeOpUsesCount > 1) + useBroadcastOp = true; + } + } + + for (auto& resultUse : sourceOpUses) { + // The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps + spatial::SpatWeightedCompute computeUser = dyn_cast(resultUse.getOwner()); + + if (computeUser) { + replaceBlockArgumentWithRecvOp( + computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter); + continue; + } + + if (!computeUser) { + auto reshapeOp = dyn_cast(resultUse.getOwner()); + if (!reshapeOp) { + resultUse.getOwner()->dump(); + llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp"); + } + + // The tensorType now becomes the one of 
the reshapeOp + channelTensorType = reshapeOp.getResult().getType(); + + for (auto& reshapeUse : reshapeOp.getOutput().getUses()) { + computeUser = dyn_cast(reshapeUse.getOwner()); + + if (!computeUser) + llvm_unreachable("ReshapeOp users must be ComputeOps"); + + replaceBlockArgumentWithRecvOp( + computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter); + } + + // Remove the reshapeOp, so that the sourceOp has no users + operationsToRemove.push_back(reshapeOp); + } + } +} + +void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { + for (auto it : llvm::enumerate(returnOp.getOperands())) { + Operation* returnOperand = it.value().getDefiningOp(); + + size_t orderWithinReturn = it.index(); + + rewriter.modifyOpInPlace(returnOp, + [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); }); + + // If the operand is a concatenation operation and the returnOp was the only + // user of the returnOperand, we can safely remove it + if (isAConcatOp(returnOperand)) { + auto returnOperandUses = it.value().getUses(); + if (rangeLength(returnOperandUses) == 0) + rewriter.eraseOp(returnOperand); + } + } +} + +void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { + + auto channel = cast(receiveOp.getChannel().getDefiningOp()); + + auto sendOpOpt = getOtherEndOfChannel(receiveOp, true, rewriter); + if (failed(sendOpOpt)) + llvm_unreachable("ChannelReceiveOp has no matching SendOp"); + + auto sendOp = cast(*sendOpOpt); + + auto tensorType = receiveOp.getType(); + Value receiveRes = receiveOp.getResult(); + + // Check if the receiveOp value has more than one user + auto receiveUses = receiveRes.getUses(); + auto receiveUsesCount = rangeLength(receiveUses); + assert(receiveUsesCount > 0); + bool useBroadcastOp = receiveUsesCount > 1; + addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter); + + if (useBroadcastOp) 
{ + // When receiving, we actually noticed that the value has more than one + // user. This means that we need to get the replace the original SendOp with + // a BroadcastSendOp + rewriter.setInsertionPoint(sendOp); + rewriter.replaceOpWithNewOp(sendOp, sendOp.getChannel(), sendOp.getData()); + } +} + +} // namespace pim + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp new file mode 100644 index 0000000..37f0826 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerOptions.hpp" + +namespace onnx_mlir { + +namespace pim { + +#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc" + +struct SpatialToPIMPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass) + StringRef getArgument() const override { return "convert-spatial-to-pim"; } + StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } + + SpatialToPIMPass() = default; + SpatialToPIMPass(const SpatialToPIMPass& pass) {} + + void runOnOperation() final; + +private: + SmallVector outputTensors; + size_t coreId = 0; + SmallVector operationsToRemove; + + void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter); + + void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); + + void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); + void addReceiveOps(Value& channelSourceOp, + spatial::SpatChannelNewOp& channel, + 
Type& channelTensorType, + bool& useBroadcastOp, + IRRewriter& rewriter); + void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, + unsigned int argIndex, + spatial::SpatChannelNewOp& channel, + Type& tensorType, + bool useBroadcastOp, + IRRewriter& rewriter); + + void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); + + void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); + + void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter); +}; + +} // namespace pim + +std::unique_ptr createSpatialToPIMPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp new file mode 100644 index 0000000..b972265 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp @@ -0,0 +1,12 @@ +#pragma once +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + +namespace spatial { + +// TODO: Add here eventual patterns + +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Dialect/CMakeLists.txt b/src/PIM/Dialect/CMakeLists.txt new file mode 100644 index 0000000..15cf316 --- /dev/null +++ b/src/PIM/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(PIM) +add_subdirectory(Spatial) \ No newline at end of file diff --git a/src/PIM/Dialect/PIM/CMakeLists.txt b/src/PIM/Dialect/PIM/CMakeLists.txt new file mode 100644 index 0000000..29f995e --- /dev/null +++ b/src/PIM/Dialect/PIM/CMakeLists.txt @@ -0,0 +1,15 @@ +add_onnx_mlir_dialect(Pim pim) +add_onnx_mlir_dialect_doc(pim Pim.td) + + +add_onnx_mlir_library(PimOps + PimOps.cpp + Transforms/PimBufferizableOpInterface.cpp + + DEPENDS + OMPimIncGen + + LINK_LIBS PUBLIC + OMMlirDialects + MLIRIR +) diff --git a/src/PIM/Dialect/PIM/Pim.td b/src/PIM/Dialect/PIM/Pim.td new file mode 100644 index 0000000..fb9ac57 --- /dev/null +++ b/src/PIM/Dialect/PIM/Pim.td 
@@ -0,0 +1,345 @@ +#ifndef PIM_DIALECT_H +#define PIM_DIALECT_H + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/Dialect/MemRef/IR/MemRefBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td" + +def PimDialect : Dialect { + let name = "pim"; + let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars"; + let cppNamespace = "::onnx_mlir::pim"; +} + +// Base class for Pim dialect operations. This operation inherits from the +// base `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class PimOp traits = []> : + Op; + +def PimTensor : + AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; + +//===----------------------------------------------------------------------===// +// Communication operations +//===----------------------------------------------------------------------===// + +def PimSendOp: PimOp<"send", []> { + let arguments = (ins + PimTensor: $src, + I32Attr: $size, + I32Attr: $targetCoreId + ); + + let assemblyFormat = [{ + `(` $src `)` attr-dict `:` type($src) `->` `(` `)` + }]; +} + +def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> { + let arguments = (ins + PimTensor: $dst, + I32Attr: $size, + I32Attr: $srcCoreId + ); + + let results = (outs + PimTensor: $out + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getDstMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $dst `)` attr-dict `:` type($dst) `->` type($out) + }]; +} + +//===----------------------------------------------------------------------===// +// Core operations +//===----------------------------------------------------------------------===// + +def PimCoreOp: PimOp<"core", 
[SingleBlock]> { + + let regions = (region SizedRegion<1>:$body); + + let arguments = (ins + Variadic:$weights, + I32Attr: $coreId + ); + + let assemblyFormat = [{ + `(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)` + }]; +} + +//===----------------------------------------------------------------------===// +// Memory Operations +//===----------------------------------------------------------------------===// + +def PimConstantOp: PimOp<"constant", []> { + let description = [{ + Allocate a constant value in global memory + }]; + + let arguments = (ins + AnyAttr: $value, + BoolAttr: $shouldAllocate + ); + + let results = (outs + PimTensor: $out + ); +} + +def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> { + let description = [{ + Copy a memory region from host memory into device memory + }]; + + let arguments = (ins + PimTensor: $deviceDst, + PimTensor: $hostSrc, + I32Attr: $deviceDstOffset, + I32Attr: $hostSrcOffset, + I32Attr: $size + ); + + let results = (outs + PimTensor: $deviceDstOut + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getDeviceDstMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut) + }]; +} + +def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> { + let description = [{ + Copy a memory region from device memory into host memory + }]; + + let arguments = (ins + PimTensor: $hostDst, + PimTensor: $deviceSrc, + I32Attr: $hostDstOffset, + I32Attr: $deviceSrcOffset, + I32Attr: $size + ); + + let results = (outs + PimTensor: $hostDstOut + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getHostDstMutable(); + } + }]; + + + let assemblyFormat = [{ + `(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut) + }]; +} + 
+//===----------------------------------------------------------------------===// +// Core.Compute operations +//===----------------------------------------------------------------------===// + +def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { + let description = [{ + Vector-matrix multiplication: c = a * b + }]; + + let arguments = (ins + I32Attr: $weightIndex, + PimTensor: $vectorInput, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutBufMutable(); + } + }]; +} + +def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> { + let description = [{ + Matrix-vector multiplication: c = a * b + }]; + + let arguments = (ins + I32Attr: $weightIndex, + PimTensor: $vectorInput, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutBufMutable(); + } + }]; +} + +def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> { + let description = [{ + Element-wise addition: c = a + b + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $b, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutBufMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes) + }]; +} + +def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods]> { + let description = [{ + Element-wise max: c = max(a, b) + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $b, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); +} + +def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods]> { + let description = [{ + Apply filters to a tensor + }]; + + let arguments = (ins + 
I64ArrayAttr: $weightIndices, + I64ArrayAttr: $xKernelPositions, + I64ArrayAttr: $yKernelPositions, + PimTensor: $input, + PimTensor: $outBuf, + PimTensor: $accumBuf + ); + + let results = (outs + PimTensor: $outRes + ); + + let assemblyFormat = [{ + `(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:` + type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes) + }]; +} + +def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods]> { + let description = [{ + Sum all elements into a single one + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); +} + +def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods]> { + let description = [{ + Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience) + }]; + + let arguments = (ins + PimTensor: $dividend, + PimTensor: $divisor, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); +} + +def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods]> { + let description = [{ + Element-wise ReLU: c = max(a, 0) + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); +} + +def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods]> { + let description = [{ + Element-wise exp: c = exp(a) + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); +} + +def PimHaltOp: PimOp<"halt", [Terminator]> { + let description = [{ + Halts the execution of the core + }]; + + let assemblyFormat = [{ + attr-dict + }]; +} + +#endif // PIM_DIALECT_H \ No newline at end of file diff --git a/src/PIM/Dialect/PIM/PimOps.cpp b/src/PIM/Dialect/PIM/PimOps.cpp new file mode 100644 index 0000000..ebaed4a --- /dev/null +++ b/src/PIM/Dialect/PIM/PimOps.cpp @@ -0,0 +1,49 @@ +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/IR/Block.h" 
+#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" + +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace pim { + +void PimDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" + + >(); +} + +#define POPULATE_DEPENDENCIES(OP_NAME) \ + void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \ + registerDependenciesFn(this->getOutBuf(), this->getResult()); \ + } + +POPULATE_DEPENDENCIES(PimVMaxOp) +POPULATE_DEPENDENCIES(PimApplyFiltersOp) +POPULATE_DEPENDENCIES(PimSumOp) +POPULATE_DEPENDENCIES(PimVSDivOp) +POPULATE_DEPENDENCIES(PimVReluOp) +POPULATE_DEPENDENCIES(PimVExpOp) + +} // namespace pim +} // namespace onnx_mlir + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" diff --git a/src/PIM/Dialect/PIM/PimOps.hpp b/src/PIM/Dialect/PIM/PimOps.hpp new file mode 100644 index 0000000..24b6711 --- /dev/null +++ b/src/PIM/Dialect/PIM/PimOps.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +/// Include the auto-generated header files containing the declarations +#include 
"src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc" + +#define GET_OP_CLASSES +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc" diff --git a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp new file mode 100644 index 0000000..dc09326 --- /dev/null +++ b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp @@ -0,0 +1,172 @@ +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp" + +using namespace mlir; +using namespace bufferization; + +namespace onnx_mlir { +namespace pim { + +struct MemCopyHostToDevOpInterface +: DstBufferizableOpInterfaceExternalModel { + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto memCopyHostToDevOp = cast(op); + auto deviceDst = memCopyHostToDevOp.getDeviceDst(); + auto hostSrc = memCopyHostToDevOp.getHostSrc(); + + auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state); + if (failed(deviceDstOpt)) + return failure(); + auto deviceDstMemRef = *deviceDstOpt; + + auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state); + if (failed(hostSrcOpt)) + return failure(); + auto hostSrcMemRef = *hostSrcOpt; + + replaceOpWithNewBufferizedOp(rewriter, + memCopyHostToDevOp, + deviceDstMemRef.getType(), + deviceDstMemRef, + hostSrcMemRef, + memCopyHostToDevOp.getDeviceDstOffsetAttr(), + memCopyHostToDevOp.getHostSrcOffsetAttr(), + memCopyHostToDevOp.getSizeAttr()); + return success(); + } +}; + +struct MemCopyDevToHostOpInterface +: DstBufferizableOpInterfaceExternalModel { + 
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto memCopyDevToHostOp = cast(op); + + auto globalDst = memCopyDevToHostOp.getHostDst(); + auto globalDstOpt = getBuffer(rewriter, globalDst, options, state); + if (failed(globalDstOpt)) + return failure(); + auto globalDstMemRef = *globalDstOpt; + + auto localSrc = memCopyDevToHostOp.getDeviceSrc(); + auto localSrcOpt = getBuffer(rewriter, localSrc, options, state); + if (failed(localSrcOpt)) + return failure(); + auto localSrcMemRef = *localSrcOpt; + + replaceOpWithNewBufferizedOp(rewriter, + memCopyDevToHostOp, + globalDstMemRef.getType(), + globalDstMemRef, + localSrcMemRef, + memCopyDevToHostOp.getHostDstOffsetAttr(), + memCopyDevToHostOp.getDeviceSrcOffsetAttr(), + memCopyDevToHostOp.getSizeAttr()); + return success(); + } +}; + +struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const { + auto vmmOp = cast(op); + Value readVal = uRead->get(); + Value writeVal = uWrite->get(); + if (writeVal != vmmOp.getOutBuf()) + return false; + if (readVal == vmmOp.getVectorInput()) + if (state.areEquivalentBufferizedValues(readVal, writeVal)) + return true; + return false; + } + + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto vmmOp = cast(op); + + auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state); + if (failed(vectorInputOpt)) + return failure(); + + auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state); + if (failed(outBufOpt)) + return failure(); + + replaceOpWithNewBufferizedOp( + rewriter, op, outBufOpt->getType(), 
vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt); + return success(); + } +}; + +struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto mvmOp = cast(op); + + auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state); + if (failed(vectorInputOpt)) + return failure(); + + auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state); + if (failed(outBufOpt)) + return failure(); + + replaceOpWithNewBufferizedOp( + rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt); + return success(); + } +}; + +struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val, + const AnalysisState& state, + ArrayRef opOperands) const { + return true; + } + + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto vaddOp = cast(op); + + auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state); + if (failed(aOpt)) + return failure(); + + auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state); + if (failed(bOpt)) + return failure(); + + auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state); + if (failed(outBufOpt)) + return failure(); + + replaceOpWithNewBufferizedOp(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt); + return success(); + } +}; + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { + 
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { + PimMemCopyHostToDevOp::attachInterface(*ctx); + PimMemCopyDevToHostOp::attachInterface(*ctx); + PimVMMOp::attachInterface(*ctx); + PimMVMOp::attachInterface(*ctx); + PimVAddOp::attachInterface(*ctx); + }); +} + +} // namespace pim +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp new file mode 100644 index 0000000..aa9e47c --- /dev/null +++ b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "mlir/IR/DialectRegistry.h" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace pim { + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); + +} // namespace pim +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt new file mode 100644 index 0000000..2ff4c97 --- /dev/null +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -0,0 +1,15 @@ +add_onnx_mlir_dialect(Spatial spat) +add_onnx_mlir_dialect_doc(spat Spatial.td) + + +add_onnx_mlir_library(SpatialOps + SpatialOps.cpp + Transforms/SpatialBufferizableOpInterface.cpp + + DEPENDS + OMSpatialIncGen + + LINK_LIBS PUBLIC + MLIRIR + OMMlirDialects +) \ No newline at end of file diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td new file mode 100644 index 0000000..5397ec4 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -0,0 +1,355 @@ +#ifndef SPATIAL_DIALECT_H +#define SPATIAL_DIALECT_H + +include "mlir/IR/OpBase.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/AttrTypeBase.td" + +def SpatialDialect : Dialect { + let name = "spat"; + let summary = "Dialect designed for deep learning computation in a spatial architecture"; + let cppNamespace = "::onnx_mlir::spatial"; + let 
useDefaultTypePrinterParser = 1; +} + +class SpatOp traits = []> : + Op; + +// TODO maybe remove and use AnyRankedTensor directly +def SpatTensor: + AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; + +class SpatType traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def SpatChannelType : SpatType<"SpatChannel", "ch"> { + let summary = "Virtual channel type"; +} + +def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { + let summary = "Compute operation, with constant weights already attached"; + + let arguments = (ins + Variadic:$weights, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + + let regions = (region SizedRegion<1>:$body); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body + }]; +} + +def SpatYieldOp: SpatOp<"yield", [Terminator]> { + let arguments = (ins + Variadic:$outputs + ); + + let assemblyFormat = [{ + $outputs attr-dict `:` type($outputs) + }]; +} + +//===----------------------------------------------------------------------===// +// Data movement operations +//===----------------------------------------------------------------------===// + +def SpatChannelNewOp: SpatOp<"channel_new", []> { + let results = (outs + SpatChannelType:$new_channel + ); + + let builders = [ + OpBuilder<(ins ), [{ + $_state.addTypes(SpatChannelType()); + }]> + ]; + + let assemblyFormat = [{ + attr-dict + }]; +} + +def SpatChannelSendOp: SpatOp<"channel_send", []> { + let arguments = (ins + SpatChannelType: $channel, + SpatTensor: $data + ); + + let assemblyFormat = [{ + $data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)` + }]; +} + +def SpatChannelReceiveOp: SpatOp<"channel_receive", []> { + let arguments = (ins + SpatChannelType: $channel + ); + + let results = (outs + SpatTensor: $data + ); + + let assemblyFormat = [{ + $channel attr-dict 
`:` `(` type($channel) `->` type($data) `)` + }]; +} + +def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> { + let arguments = (ins + SpatChannelType: $channel, + SpatTensor: $data + ); +} + +def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> { + let arguments = (ins + SpatChannelType: $channel + ); + + let results = (outs + SpatTensor: $data + ); +} + +//===----------------------------------------------------------------------===// +// Math operations +//===----------------------------------------------------------------------===// + +def SpatConstantOp: SpatOp<"constant", []> { + let description = [{ + "Constant value, should be used for weights and biases" + }]; + + let arguments = (ins + AnyAttr: $value, + BoolAttr: $shouldAllocate + ); + + let results = (outs + SpatTensor: $out + ); +} + +def SpatWeightedVMMOp: SpatOp<"Wvmm", []> { + let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute."; + + let arguments = (ins + I32Attr: $weightIndex, + SpatTensor:$vector + ); + + let results = (outs + SpatTensor:$output + ); + + // TODO: Verifier that checks it is within a WeightedCompute operation, + // that the weightIndex is valid, and that the matrix is of the right size. + let hasVerifier = 1; +} + +def SpatWeightedMVMOp: SpatOp<"Wmvm", []> { + let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute."; + + let arguments = (ins + I32Attr: $weightIndex, + SpatTensor:$vector + ); + + let results = (outs + SpatTensor:$output + ); + + // TODO: Verifier that checks it is within a WeightedCompute operation, + // that the weightIndex is valid, and that the matrix is of the right size. 
+ let hasVerifier = 1; +} + + +def SpatVAddOp: SpatOp<"vadd", []> { + let summary = "Element-wise add between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1"; + + let arguments = (ins + SpatTensor: $a, + SpatTensor: $b + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output) + }]; +} + +def SpatVMulOp: SpatOp<"vmul", []> { + let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1"; + + let arguments = (ins + SpatTensor: $a, + SpatTensor: $b + ); + + let results = (outs + SpatTensor:$output + ); + + //let hasVerifier = 1; + + let assemblyFormat = [{ + $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output) + }]; +} + +def SpatVDivOp: SpatOp<"vdiv", []> { + let summary = "Element-wise division between tensors a and b. Tensor b must have the same size as tensor a or be a 1x1"; + + let arguments = (ins + SpatTensor:$a, + SpatTensor:$b + ); + + let results = (outs + SpatTensor:$output + ); + + //let hasVerifier = 1; + + let assemblyFormat = [{ + $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output) + }]; +} + +//TODO: remove +def SpatVSDivOp: SpatOp<"vsdiv", []> { + + let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)"; + + let arguments = (ins + SpatTensor:$dividend, + SpatTensor:$divisor + ); + + let results = (outs + SpatTensor:$output + ); +} + +def SpatSumOp: SpatOp<"sum", []> { + let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience"; + + let arguments = (ins + SpatTensor: $input + ); + + let results = (outs + SpatTensor:$output + ); +} + +def SpatSigmoidOp: SpatOp<"sigmoid", []> { + let arguments = (ins + SpatTensor:$input + ); + + let results = (outs + SpatTensor:$output + ); +} + 
+def SpatReluOp: SpatOp<"relu", []> { + let arguments = (ins + SpatTensor:$input + ); + + let results = (outs + SpatTensor:$output + ); +} + +def SpatVMaxOp: SpatOp<"vmax", []> { + + let summary = "Element-wise max function"; + + let arguments = (ins + SpatTensor: $a, + SpatTensor: $b + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; +} + +def SpatApplyFiltersOp : SpatOp<"apply_filters", []> { + let summary = "Apply multiple crossbar weights to a convolutional input tile."; + let description = [{ + Applies a variable number of crossbar weights to a single large image tensor tile, + producing a corresponding output tile. This essentially encapsulates a big for loop + over all pixels in the input tile, where each pixel is multiplied by all the weights + in the operation. + }]; + + let arguments = (ins + I64ArrayAttr: $weightIndices, + I64ArrayAttr: $xKernelPositions, + I64ArrayAttr: $yKernelPositions, + SpatTensor: $input + ); + let results = (outs SpatTensor); + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type(results) + }]; +} + +//===----------------------------------------------------------------------===// +// Other operations +//===----------------------------------------------------------------------===// + +def SpatImgConcatOp: SpatOp<"img_concat", []> { + + let summary = "Concatenate pixel tiles into a single image"; + + let description = [{ + Concatenate pixel tiles into a single image: + 1. First, concatenate the pixel tiles along the "channel" axis (axis 1). + 2. Next, concatenate the pixel tiles along the "width" axis (axis 2). + 3. Finally, concatenate the pixel tiles along the "height" axis (axis 3). + + The input tiles should be provided in a specific order: + start from the top left pixel, + then continue with the pixel on its right, + and once you finish the first row of pixels, go to the next row. 
+ }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + mlir::Value getInputTile(size_t x, size_t y, size_t tile); + }]; +} + +#endif // SPATIAL_DIALECT_H \ No newline at end of file diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp new file mode 100644 index 0000000..35e0176 --- /dev/null +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -0,0 +1,339 @@ +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/LogicalResult.h" + +#include + +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { + +void SpatialDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc" + + >(); + addOperations< +#define GET_OP_LIST +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc" + + >(); +} + +inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter, + ArrayRef& matrixShape, + ArrayRef& 
vectorShape, + ArrayRef& outputShape) { + + // Verify that the matrix, vector and output shapes have rank 2 + if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) + return emitter->emitError("matrix, vector and output must have rank 2"); + + // Verify that the matrix shape is (N, M) + int64_t N = matrixShape[0]; + int64_t M = matrixShape[1]; + if (N <= 0 || M <= 0) + return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0"); + + // Verify that the vector shape is (M, 1) + int64_t vectorM = vectorShape[0]; + int64_t vector1 = vectorShape[1]; + if (vectorM != M || vector1 != 1) + return emitter->emitError("vector shape must be (M, 1)"); + + // Verify that the output shape is (N, 1) + int64_t outputN = outputShape[0]; + int64_t output1 = outputShape[1]; + if (outputN != N || output1 != 1) + return emitter->emitError("output shape must be (N, 1)"); + + return success(); +} + +inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, + ArrayRef& matrixShape, + ArrayRef& vectorShape, + ArrayRef& outputShape) { + + // Verify that the matrix, vector and output shapes have rank 4 + if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4) + return emitter->emitError("matrix, vector and output must have rank 4"); + + // Verify that the matrix shape is (N, M, 1, 1) + int64_t N = matrixShape[0]; + int64_t M = matrixShape[1]; + int64_t matrix1First = matrixShape[2]; + int64_t matrix1Second = matrixShape[3]; + if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1) + return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0"); + + // Verify that the vector shape is (1, M, 1, 1) + int64_t vector1First = vectorShape[0]; + int64_t vectorM = vectorShape[1]; + int64_t vector1Second = vectorShape[2]; + int64_t vector1Third = vectorShape[3]; + if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) { + if (vector1First == 1 && vector1Second == 1 
&& vector1Third == 1 && ignoreConcatError == true) { + // This is ok, it was caused by the simplification of the concat error + } + else { + return emitter->emitError("vector shape must be (1, M, 1, 1)"); + } + } + + // Verify that the output shape is (1, N, 1, 1) + int64_t output1First = outputShape[0]; + int64_t outputN = outputShape[1]; + int64_t output1Second = outputShape[2]; + int64_t output1Third = outputShape[3]; + if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1) + return emitter->emitError("output shape must be (1, N, 1, 1)"); + + return success(); +} + +llvm::FailureOr> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) { + auto wcomputeOp = dyn_cast(weigthedOp->getParentOp()); + if (wcomputeOp) + return cast(wcomputeOp.getWeights()[weightIndex].getType()).getShape(); + + auto coreOp = dyn_cast(weigthedOp->getParentOp()); + + if (coreOp) + return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); + + return failure(); +} + +LogicalResult SpatWeightedMVMOp::verify() { + auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); + if (failed(matrixShapeOpt)) + return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op"); + auto matrixShape = *matrixShapeOpt; + auto vectorShape = getVector().getType().getShape(); + auto outputShape = getOutput().getType().getShape(); + + /* Two possible accepted shapes: + 1. matrix: (N, M); vector: (M, 1); output: (N, 1) + 2. 
matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1) + */ + + if (matrixShape.size() == 2) + return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape); + else if (matrixShape.size() == 4) + return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape); + else + return emitError("matrix rank must be 2 or 4"); +} + +LogicalResult SpatWeightedVMMOp::verify() { + auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); + if (failed(matrixShapeOpt)) + return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op"); + auto matrixShape = *matrixShapeOpt; + auto vectorShape = getVector().getType().getShape(); + auto outputShape = getOutput().getType().getShape(); + + /* Accepted shape: + 1. vector: (1, N); matrix: (N, M); output: (1, M) + */ + if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) + return emitError("matrix, vector and output must have rank 2"); + + int64_t N = matrixShape[0]; + int64_t M = matrixShape[1]; + if (N <= 0 || M <= 0) + return emitError("matrix shape must be (N, M) with N > 0 and M > 0"); + + int64_t vector1 = vectorShape[0]; + int64_t vectorN = vectorShape[1]; + if (vectorN != N || vector1 != 1) + return emitError("vector shape must be (N, 1)"); + + int64_t output1 = outputShape[0]; + int64_t outputM = outputShape[1]; + if (outputM != M || output1 != 1) + return emitError("output shape must be (M, 1)"); + + return success(); +} + +LogicalResult SpatVAddOp::verify() { + // At least two operands + if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) + return failure(); + + return OpTrait::impl::verifySameOperandsAndResultType(*this); +} + +LogicalResult SpatVMaxOp::verify() { + // At least two operands + if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) + return failure(); + + return OpTrait::impl::verifySameOperandsAndResultType(*this); +} + +LogicalResult SpatImgConcatOp::verify() { + auto imgShape = 
mlir::cast(getType()); + size_t img_w = GET_IMAGE_WIDTH(imgShape); + size_t img_h = GET_IMAGE_HEIGHT(imgShape); + size_t img_c = GET_IMAGE_CHANNEL(imgShape); + + size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); + size_t channelTileRest = img_c % crossbarSize; + + auto operands = getOperands(); + + // Check number of operands + if (img_w * img_h * channelTiles != operands.size()) + return emitError("Number of operands does not match output image size"); + + // For each output pixel, check that the inputTiles have a correct shape + for (size_t x = 0; x < img_w; x++) { + for (size_t y = 0; y < img_h; y++) { + size_t channel_counts = 0; + for (size_t t = 0; t < channelTiles; t++) { + auto inputShape = mlir::cast(getInputTile(x, y, t).getType()); + if (!inputShape) + return emitError("Invalid input type, must be ShapedType"); + + // N == W == H == 1 + if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1) + return emitError("Invalid input shape: N,W,H must all be 1"); + + size_t inputChannels = GET_IMAGE_CHANNEL(inputShape); + + // Check the number of channels in this tile are correct: + // - CASE1: last tile of pixel, if there is some rest it must match that + // - CASE2: common case, the channel count is exactly the crossbarSize + if (t == channelTiles - 1 && channelTileRest != 0) { + if (inputChannels != channelTileRest) + return emitError("Invalid channel count for last tile of pixel"); + } + else { + if (inputChannels != crossbarSize) + return emitError("Invalid channel count for some pixel tile"); + } + + channel_counts += inputChannels; + } + + if (channel_counts != img_c) + emitError("Invalid number of channels for some pixel"); + } + } + + return success(); +} + +LogicalResult SpatWeightedCompute::verify() { + // Check that it has a terminator, it is a yieldOp, and it has a single + // operand with the same type as the result + auto& block = getBody().front(); + auto yieldOp = 
dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return emitError("ComputeOp must have a single yield operation"); + + auto resultTypes = getResultTypes(); + auto yieldTypes = yieldOp->getOperandTypes(); + if (resultTypes.size() != yieldTypes.size()) { + return emitError("ComputeOp must have same number of results as yieldOp " + "operands"); + } + + for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { + auto resultType = std::get<0>(it); + auto yieldType = std::get<1>(it); + + // Same type and compatible shape + if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) { + return emitError("ComputeOp output must be of the same type as yieldOp " + "operand"); + } + + // Same encoding + if (auto resultRankedType = dyn_cast(resultType)) { + if (auto yieldRankedType = dyn_cast(yieldType)) { + if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) { + return emitError("ComputeOp output must have the same encoding as " + "yieldOp operand"); + } + } + else { + return emitError("ComputeOp output has an encoding while yieldOp " + "operand does not have one"); + } + } + else { + // If result does not have an encoding, yield shouldn't either + if (auto yieldRankedType = dyn_cast(yieldType)) { + return emitError("ComputeOp output must not have an encoding if " + "yieldOp operand has one"); + } + } + } + + // Check that each block argument is used + for (auto arg : block.getArguments()) + if (arg.use_empty()) + return emitError("ComputeOp block argument is not used"); + + return success(); +} + +Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) { + auto operands = getOperands(); + auto imgShape = mlir::cast(getType()); + size_t img_w = GET_IMAGE_WIDTH(imgShape); + size_t img_h = GET_IMAGE_HEIGHT(imgShape); + size_t img_c = GET_IMAGE_CHANNEL(imgShape); + + size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); + + assert(tile < channelTiles); + assert(x < img_w); + assert(y < 
img_h); + + return operands[tile + x * channelTiles + y * img_w * channelTiles]; +} + +} // namespace spatial +} // namespace onnx_mlir + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc" diff --git a/src/PIM/Dialect/Spatial/SpatialOps.hpp b/src/PIM/Dialect/Spatial/SpatialOps.hpp new file mode 100644 index 0000000..2aca69e --- /dev/null +++ b/src/PIM/Dialect/Spatial/SpatialOps.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Types.h" + +/// Include the auto-generated header files containing the declarations +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.hpp.inc" + +#define GET_OP_CLASSES +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc" diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp new file mode 100644 index 0000000..9daeadc --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -0,0 +1,493 @@ +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include 
"mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" + +#include + +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +using namespace bufferization; + +namespace onnx_mlir { +namespace spatial { + +memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) { + auto resultShape = cast(resultType); + auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType()); + + // Alloc an output memref + return rewriter.create(loc, memrefResultType); +} + +const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); + +llvm::FailureOr getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) { + + // This function requires the existence of ChannelNewOp and the other + // Receive/Send operation. However, during bufferization, the first of the + // Receive/Send operation that is processed gets removed. 
As such, we need to + // "precompute" the coreId needed for the other op, and save it as attribute + auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME); + if (precomputedOtherCoreId) + return cast(precomputedOtherCoreId).getInt(); + + auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter); + if (failed(notOpUserOpt)) + return failure(); + Operation* notOpUser = *notOpUserOpt; + + // Save the coreId for this op into the other op as attribute + auto opCoreIdAttr = cast(op->getParentOp()).getCoreIdAttr(); + notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr); + + return cast(notOpUser->getParentOp()).getCoreId(); +} + +struct WComputeOpInterface : BufferizableOpInterface::ExternalModel { + + // Input tensor to the compute OP are always read into its local memory + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } + + // Input tensor to the compute OP are _never_ written into its local memory + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // In general, no tensor is aliased with any other tensor in the compute OP + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + // TODO: Is it an empty list or a list of "UNKNOWN" values? 
+ return {}; + } + + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + // Bufferize its block + + auto& block = op->getRegion(0).front(); + + return bufferizeBlockSignature(&block, rewriter, options, state); + } +}; + +/* + * This can be used for operations that have a single argument, which is a + * variadic of tensors, and a single output with the same shape + * Example: VAdd, VSub, VExp + */ +template +struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel { + + // Input tensors to the OP are always read + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } + + // Input tensors to the OP are _never_ written + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // In general, no tensor is aliased with any other tensor in the OP + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; + } + + // Cast tensor values into memref values + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + + // Turn Tensor Operands into Memref Operands + SmallVector memrefOperands; + memrefOperands.reserve(op->getNumOperands()); + for (auto operand : op->getOperands()) { + auto memref = getBuffer(rewriter, operand, options, state); + if (failed(memref)) + return failure(); + memrefOperands.push_back(*memref); + } + + // TODO: Support addition with more than 2 operands + if (memrefOperands.size() > 2) { + op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs " + "with 1 or 2 operands, for now."); + return failure(); + } + + // Alloc an output memref + Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); + 
memrefOperands.push_back(outputTensor); + + Value newValue = rewriter.create(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes(); + + replaceOpWithBufferizedValues(rewriter, op, newValue); + + return success(); + } +}; + +template +struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel { + + // Input tensors to the OP are always read + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } + + // Input tensors to the OP are _never_ written + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // In general, no tensor is aliased with any other tensor in the OP + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; + } + + // Cast tensor value into memref value + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state); + if (failed(memrefOperandOpt)) + return failure(); + auto memrefOperand = *memrefOperandOpt; + + // Alloc an output memref + Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); + + Value newValue = + rewriter + .create( + op->getLoc(), outputTensor.getType(), cast(op).getWeightIndexAttr(), memrefOperand, outputTensor) + .getOutRes(); + + replaceOpWithBufferizedValues(rewriter, op, newValue); + + return success(); + } +}; + +struct ChannelReceiveOpInterface +: BufferizableOpInterface::ExternalModel { + + // Input value is the channel (not read/written, its more of an attribute) + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // See above + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { 
return false; } + + // See above + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + // TODO: Is it an empty list or a list of "UNKNOWN" values? + return {}; + } + + /* + * Turn the channel receive to pim.recv + */ + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + + auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); + + auto numElements = cast(outputTensor.getType()).getNumElements(); + auto elementSize = cast(outputTensor.getType()).getElementTypeBitWidth() / 8; + + auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter); + if (failed(srcCoreId)) + return failure(); + + Value newValue = rewriter + .create(op->getLoc(), + outputTensor.getType(), + outputTensor, + rewriter.getI32IntegerAttr(numElements * elementSize), + rewriter.getI32IntegerAttr(srcCoreId.value())) + .getOut(); + + replaceOpWithBufferizedValues(rewriter, op, newValue); + + return success(); + } +}; + +struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel { + + // First input is channel (not read/writter) second input is Tensor to send, + // which is read + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return opOperand.getOperandNumber() == 2; + } + + // See above (both non-written) + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // See above + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + // TODO: Is it an empty list or a list of "UNKNOWN" values? 
+ return {}; + } + + /* + * Turn the channel send to pim.send + */ + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto srcTensor = op->getOperand(1); + + auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); + if (failed(srcTensorOpt)) + return failure(); + auto srcMemRef = *srcTensorOpt; + + auto numElements = cast(srcTensor.getType()).getNumElements(); + auto elementSize = cast(srcTensor.getType()).getElementTypeBitWidth() / 8; + + auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter); + if (failed(dstCoreId)) + return failure(); + + replaceOpWithNewBufferizedOp(rewriter, + op, + srcMemRef, + rewriter.getI32IntegerAttr(numElements * elementSize), + rewriter.getI32IntegerAttr(dstCoreId.value())); + + return success(); + } +}; + +struct ChannelBroadcastReceiveOpInterface +: BufferizableOpInterface::ExternalModel { + + // Input value is the channel (not read/written, its more of an attribute) + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // See above + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // See above + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + // TODO: Is it an empty list or a list of "UNKNOWN" values? 
+ return {}; + } + + /* + * Turn the channel receive to pim.load using by creating a new global buffer + */ + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + + auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); + + auto outputSize = cast(outputTensor.getType()).getNumElements(); + + auto channelNewOp = op->getOperand(0).getDefiningOp(); + if (!channelNewOp) { + op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand"); + return failure(); + } + + // The first 'broadcast' operation creates the buffer just after the + // channelNewOp, while the other 'broadcast' operation need to find this + // buffer allocation just after the channelNewOp + Value bufferAllocation; + if (auto allocOpAfterChannel = dyn_cast(channelNewOp->getNextNode())) { + // Buffer already allocated, load from this buffer + bufferAllocation = allocOpAfterChannel; + } + else { + // Buffer was not allocated previously, allocate it after channelNewOp + rewriter.setInsertionPointAfter(channelNewOp); + bufferAllocation = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); + } + + rewriter.setInsertionPoint(op); + auto memCopyHostToDevOp = rewriter.create(op->getLoc(), + outputTensor.getType(), + outputTensor, + bufferAllocation, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(outputSize)); + + replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst()); + + return success(); + } +}; + +struct ChannelBroadcastSendOpInterface +: BufferizableOpInterface::ExternalModel { + + // First input is channel (not read/writter) second input is Tensor to send, + // which is read + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return opOperand.getOperandNumber() == 2; + } + + // See above (both non-written) + bool 
bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + // See above + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + // TODO: Is it an empty list or a list of "UNKNOWN" values? + return {}; + } + + /* + * Turn the channel send to pim.send + */ + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + auto srcTensor = op->getOperand(1); + + auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); + if (failed(srcTensorOpt)) + return failure(); + auto srcMemRef = *srcTensorOpt; + + auto channelNewOp = op->getOperand(0).getDefiningOp(); + if (!channelNewOp) { + op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand"); + return failure(); + } + + // The first 'broadcast' operation creates the buffer just after the + // channelNewOp, while the other 'broadcast' operation need to find this + // buffer allocation just after the channelNewOp + Value bufferAllocation; + if (auto allocOpAfterChannel = dyn_cast(channelNewOp->getNextNode())) { + // Buffer already allocated, load from this buffer + bufferAllocation = allocOpAfterChannel; + } + else { + // Buffer was not allocated previously, allocate it after channelNewOp + rewriter.setInsertionPointAfter(channelNewOp); + bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter); + } + + rewriter.setInsertionPoint(op); + replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef}); + return success(); + } +}; + +struct VAddOpInterfaceFromTemplate +: VariadicArgumentElementWiseOpInterface {}; + +struct WVMMOpInterface : WeightedMultiplicationsOpInterface {}; + +struct WMVMOpInterface : WeightedMultiplicationsOpInterface {}; + +struct SumOpInterface : VariadicArgumentElementWiseOpInterface {}; + +struct VSDivOpInterface : 
VariadicArgumentElementWiseOpInterface {}; + +struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface {}; + +// Create a new bufferizable op interface for the apply filters operation. +struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel { + + // One operand ($input) is read from. All other inputs are only written to. + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + + // Operand 0: $input + // Operand 1: $outBuf + // Operand 2: $accumBuf + return opOperand.getOperandNumber() == 0; + } + + // One input ($accumBuf) is written to. All other inputs are only read. + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + + // Operand 0: $input + // Operand 1: $outBuf + // Operand 2: $accumBuf + return opOperand.getOperandNumber() == 2; + } + + // No operands are aliased with any other operands. + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; + } + + // Bufferize the operation. + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + + // Get the input tensor buffer. + auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state); + + if (failed(inputBuffer)) + return failure(); + + // Create a new buffer for the output tensor. + auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); + + // Create a new buffer for the accumulation buffer. + // To do this, create a new allocation operation. Size must be axbx1x1, + // where axbxcxd is the size of the output tensor. Since the shape is + // different, we can't immediately use createEmptyFromType, we first need to + // create the shape of the accumulation buffer. + auto accumShape = llvm::to_vector<4>(cast(op->getResult(0).getType()).getShape()); + + // Set the last two dimensions to 1. 
+ accumShape[accumShape.size() - 1] = 1; + accumShape[accumShape.size() - 2] = 1; + + auto accumType = MemRefType::get(accumShape, cast(op->getResult(0).getType()).getElementType()); + + auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter); + + // Bufferize the operation. + auto weightIndices = cast(op).getWeightIndicesAttr(); + auto xKernelPositions = cast(op).getXKernelPositionsAttr(); + auto yKernelPositions = cast(op).getYKernelPositionsAttr(); + + Value bufferized = rewriter.create(op->getLoc(), + outputTensor.getType(), + weightIndices, + xKernelPositions, + yKernelPositions, + *inputBuffer, + outputTensor, + accumBuffer); + + // Replace the operation with the bufferized value. + replaceOpWithBufferizedValues(rewriter, op, bufferized); + + return success(); + } +}; + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { + registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) { + SpatWeightedCompute::attachInterface(*ctx); + SpatVAddOp::attachInterface(*ctx); + SpatWeightedVMMOp::attachInterface(*ctx); + SpatWeightedMVMOp::attachInterface(*ctx); + SpatSumOp::attachInterface(*ctx); + SpatVSDivOp::attachInterface(*ctx); + SpatVMaxOp::attachInterface(*ctx); + SpatChannelReceiveOp::attachInterface(*ctx); + SpatChannelSendOp::attachInterface(*ctx); + SpatChannelBroadcastReceiveOp::attachInterface(*ctx); + SpatChannelBroadcastSendOp::attachInterface(*ctx); + SpatApplyFiltersOp::attachInterface(*ctx); + }); +} + +struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface {}; + +struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface {}; + +void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { + registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) { + ONNXReluOp::attachInterface(*ctx); + ONNXExpOp::attachInterface(*ctx); + }); +} + +} // namespace spatial +} // namespace onnx_mlir diff --git 
a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp new file mode 100644 index 0000000..455bbe7 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "mlir/IR/DialectRegistry.h" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry); + +void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry); + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Pass/CountInstructionPass.cpp b/src/PIM/Pass/CountInstructionPass.cpp new file mode 100644 index 0000000..edb72e2 --- /dev/null +++ b/src/PIM/Pass/CountInstructionPass.cpp @@ -0,0 +1,67 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Compiler/CompilerUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +struct CountInstructionPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass) + + StringRef getArgument() const override { return "count-instruction-pass"; } + + StringRef getDescription() const override { + return "Count instructions for each core/compute in the module"; + } + + // Make sure that we have a valid default constructor and copy + // constructor to make sure that the options are initialized properly. 
+ CountInstructionPass() {} + CountInstructionPass(const CountInstructionPass &pass) + : PassWrapper>() {} + void runOnOperation() final { + ModuleOp module = getOperation(); + + func::FuncOp func = *module.getOps().begin(); + + unsigned totalInstructionCount = 0; + + unsigned computeId = 0; + for (auto computeOp : func.getOps()) { + unsigned instructionCount = 0; + instructionCount += computeOp.getBody().front().getOperations().size(); + llvm::outs() << "Compute " << computeId << ": " << instructionCount + << " instructions\n"; + totalInstructionCount += instructionCount; + computeId++; + } + + unsigned coreId = 0; + for (auto coreOp : func.getOps()) { + unsigned instructionCount = 0; + instructionCount += coreOp.getBody().front().getOperations().size(); + llvm::outs() << "Core " << coreId << ": " << instructionCount + << " instructions\n"; + totalInstructionCount += instructionCount; + coreId++; + } + + llvm::outs() << "Total instruction count: " << totalInstructionCount + << "\n"; + } +}; + +} // namespace + +std::unique_ptr createCountInstructionPass() { + return std::make_unique(); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Pass/MessagePass.cpp b/src/PIM/Pass/MessagePass.cpp new file mode 100644 index 0000000..1fd61ce --- /dev/null +++ b/src/PIM/Pass/MessagePass.cpp @@ -0,0 +1,37 @@ +#include "mlir/Pass/Pass.h" +#include "src/Compiler/CompilerUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +struct MessagePass : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MessagePass) + + StringRef getArgument() const override { return "message-pass"; } + + StringRef getDescription() const override { + return "Lower ONNX ops to Spatial ops."; + } + + // Make sure that we have a valid default constructor and copy + // constructor to make sure that the options are initialized properly. 
+ MessagePass(std::string message) : message(message) {} + MessagePass(const MessagePass &pass) + : PassWrapper>() {} + void runOnOperation() final { showCompilePhase(message); } + +private: + std::string message; +}; + +} // namespace + +std::unique_ptr createMessagePass(std::string message) { + return std::make_unique(message); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/Pass/PimPasses.hpp b/src/PIM/Pass/PimPasses.hpp new file mode 100644 index 0000000..9b1c0f1 --- /dev/null +++ b/src/PIM/Pass/PimPasses.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "mlir/Pass/Pass.h" +#include + +using namespace mlir; + +namespace onnx_mlir { + +std::unique_ptr createONNXToSpatialPass(); + +std::unique_ptr createSpatialToGraphvizPass(); + +std::unique_ptr createSpatialToPIMPass(); + +std::unique_ptr createBufferizePimPass(); + +std::unique_ptr createMessagePass(std::string message); + +std::unique_ptr createCountInstructionPass(); + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp new file mode 100644 index 0000000..8c5bcd4 --- /dev/null +++ b/src/PIM/PimAccelerator.cpp @@ -0,0 +1,110 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" +#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include 
"src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Accelerators/PIM/PimAccelerator.hpp" + +#include + +#define DEBUG_TYPE "PimAccelerator" + +namespace onnx_mlir { +namespace accel { + +Accelerator *createPIM() { return PimAccelerator::getInstance(); } + +PimAccelerator *PimAccelerator::instance = nullptr; + +PimAccelerator *PimAccelerator::getInstance() { + if (instance == nullptr) + instance = new PimAccelerator(); + return instance; +} + +PimAccelerator::PimAccelerator() : Accelerator(Accelerator::Kind::PIM) { + LLVM_DEBUG(llvm::dbgs() << "Creating a PIM accelerator\n"); + acceleratorTargets.push_back(this); +}; + +PimAccelerator::~PimAccelerator() { delete instance; } + +uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; } + +void PimAccelerator::addPasses(mlir::OwningOpRef &module, + mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget, + std::string outputNameNoExt) const { + LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n"); + addPassesPim(module, pm, emissionTarget, outputNameNoExt); +} + +void PimAccelerator::registerDialects(mlir::DialectRegistry ®istry) const { + LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n"); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerBufferizableOpInterfaceExternalModels(registry); + mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + spatial::registerBufferizableOpInterfaceExternalModels(registry); + spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); + 
pim::registerBufferizableOpInterfaceExternalModels(registry); +} + +void PimAccelerator::registerPasses(int optLevel) const { + LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); + // Register here all the passes that could be used + mlir::registerPass(createONNXToSpatialPass); + mlir::registerPass(createSpatialToGraphvizPass); + mlir::registerPass(createSpatialToPIMPass); + mlir::registerPass(createBufferizePimPass); +} + +void PimAccelerator::configurePasses() const { + LLVM_DEBUG(llvm::dbgs() << "Configuring passes for PIM accelerator\n"); + // TODO: This does nothing for now. +} + +mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType( + const mlir::TensorType tensorType) const { + // Do not convert tensor types to memref types. + return nullptr; +} + +void PimAccelerator::conversionTargetONNXToKrnl( + mlir::ConversionTarget &target) const { + target.addLegalDialect(); +} + +void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns, + mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const { + // TODO: Add patterns for conversion +} + +void PimAccelerator::conversionTargetKrnlToLLVM( + mlir::ConversionTarget &target) const {} + +void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns, + mlir::LLVMTypeConverter &typeConverter, mlir::MLIRContext *ctx) const { + // We should not need this, since we offload it all to PIM. +} + +} // namespace accel +} // namespace onnx_mlir diff --git a/src/PIM/PimAccelerator.hpp b/src/PIM/PimAccelerator.hpp new file mode 100644 index 0000000..f96cb60 --- /dev/null +++ b/src/PIM/PimAccelerator.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include "mlir/IR/BuiltinTypes.h" +#include "src/Accelerators/Accelerator.hpp" + +namespace onnx_mlir { +namespace accel { + +/// Singleton class to construct PIM accelerator. 
+class PimAccelerator final : public Accelerator {
+private:
+  static PimAccelerator *instance;
+  PimAccelerator();
+
+public:
+  /// Singleton should not be clonable or assignable.
+  PimAccelerator(PimAccelerator &) = delete;
+  void operator=(const PimAccelerator &) = delete;
+
+  ~PimAccelerator();
+
+  /// Creates an instance on the first invocation. Subsequent invocations
+  /// return the existing instance.
+  static PimAccelerator *getInstance();
+
+  /// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc.
+  static bool classof(const Accelerator *accel) {
+    return accel->getKind() == Accelerator::Kind::PIM;
+  }
+  static bool classof(const PimAccelerator *) { return true; }
+
+  uint64_t getVersionNumber() const final;
+
+  //===--------------------------------------------------------------------===//
+  // Hooks for onnx-mlir driver
+  // NOTE(review): this banner originally duplicated the "onnx-mlir-opt
+  // driver" banner below; addPasses belongs to the onnx-mlir driver hooks.
+  //===--------------------------------------------------------------------===//
+  virtual void addPasses(mlir::OwningOpRef &module,
+      mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
+      std::string outputNameNoExt) const final;
+  //===--------------------------------------------------------------------===//
+  // Hooks for onnx-mlir-opt driver
+  //===--------------------------------------------------------------------===//
+  virtual void registerDialects(mlir::DialectRegistry &registry) const final;
+  virtual void registerPasses(int optLevel) const final;
+  //===--------------------------------------------------------------------===//
+  // Hooks for both onnx-mlir and onnx-mlir-opt drivers
+  //===--------------------------------------------------------------------===//
+  virtual void configurePasses() const final;
+  //===--------------------------------------------------------------------===//
+  // Hooks for onnx-to-krnl pass
+  //===--------------------------------------------------------------------===//
+  virtual mlir::MemRefType convertTensorTypeToMemRefType(
+      const mlir::TensorType tensorType) const
final; + virtual void conversionTargetONNXToKrnl( + mlir::ConversionTarget &target) const final; + virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns, + mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const final; + + //===--------------------------------------------------------------------===// + // Hooks for krnl-to-llvm pass + //===--------------------------------------------------------------------===// + virtual void conversionTargetKrnlToLLVM( + mlir::ConversionTarget &target) const final; + virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns, + mlir::LLVMTypeConverter &typeConverter, + mlir::MLIRContext *ctx) const final; +}; + +} // namespace accel +} // namespace onnx_mlir diff --git a/src/PIM/Transforms/PimBufferizationPass.cpp b/src/PIM/Transforms/PimBufferizationPass.cpp new file mode 100644 index 0000000..054d312 --- /dev/null +++ b/src/PIM/Transforms/PimBufferizationPass.cpp @@ -0,0 +1,87 @@ +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/Support/raw_os_ostream.h" + +#include +#include + +#include "Compiler/PimCodeGen.hpp" +#include "PimBufferizationPass.hpp" +#include "src/Compiler/CompilerOptions.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace pim { + +void PimBufferizationPass::runOnOperation() { + auto moduleOp = getOperation(); + + // Do One-Shot-Bufferization + bufferization::OneShotBufferizationOptions options; + options.allowUnknownOps = true; + bufferization::BufferizationState state; + if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { + moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); + signalPassFailure(); + } + + // Remove toTensor operations + 
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) { + toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer()); + toTensorOp.erase(); + }); + + // Change main function return types from tensors to memrefs + func::FuncOp funcOp; + for (Operation& op : moduleOp.getBody()->getOperations()) + if ((funcOp = dyn_cast(&op))) + break; + auto oldFuncType = funcOp.getFunctionType(); + SmallVector newResults; + bool changed = false; + for (Type type : oldFuncType.getResults()) + if (auto tensorType = dyn_cast(type)) { + newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType())); + changed = true; + } + else + newResults.push_back(type); + if (changed) + funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults)); + + annotateWeightsMemrefs(moduleOp, funcOp); + + // Dump to file for debug + ModuleOp module = getOperation(); + std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); + std::filesystem::create_directory(outputDir); + std::fstream file(outputDir + "/pim_buf.mlir", std::ios::out); + llvm::raw_os_ostream os(file); + os << *module; + os.flush(); + file.close(); +} + +void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { + MLIRContext* ctx = funcOp.getContext(); + funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { + bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa(user); }); + if (isAlwaysWeight) { + auto globalMemrefOp = moduleOp.lookupSymbol(getGlobalOp.getName()); + assert("Weights must be constants" && globalMemrefOp.getConstant()); + getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx)); + globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx)); + } + }); +} + +} // namespace pim + +} // namespace onnx_mlir diff --git a/src/PIM/Transforms/PimBufferizationPass.hpp b/src/PIM/Transforms/PimBufferizationPass.hpp new file mode 100644 index 0000000..3462b48 --- 
/dev/null +++ b/src/PIM/Transforms/PimBufferizationPass.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "mlir/Pass/Pass.h" + +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerOptions.hpp" + +namespace onnx_mlir { + +namespace pim { + +struct PimBufferizationPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) + StringRef getArgument() const override { return "bufferize-pim"; } + StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; } + + PimBufferizationPass() = default; + PimBufferizationPass(const PimBufferizationPass& pass) {} + + void runOnOperation() final; + +private: + void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const; +}; + +} // namespace pim + +std::unique_ptr createBufferizePimPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/test/PIM/CMakeLists.txt b/test/PIM/CMakeLists.txt new file mode 100644 index 0000000..e69de29