add PIM accelerator

This commit is contained in:
NiccoloN
2026-02-24 15:09:18 +01:00
parent b24a0df8d7
commit a6e928bdd7
67 changed files with 9109 additions and 1 deletions

View File

@@ -3,4 +3,31 @@ cmake_minimum_required(VERSION 3.20.0)
project(raptor) project(raptor)
# Add symlink to PIM as accelerator in onnx-mlir
function(raptor_ensure_symlink link_path target_path)
get_filename_component(link_parent "${link_path}" DIRECTORY)
if(NOT EXISTS "${link_parent}")
message(FATAL_ERROR "Directory not found: ${link_parent}")
endif()
if(NOT EXISTS "${link_path}")
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
file(CREATE_LINK
"${target_path}"
"${link_path}"
SYMBOLIC
)
endif()
endfunction()
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
)
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
)
add_subdirectory(onnx-mlir) add_subdirectory(onnx-mlir)

46
src/PIM/CMakeLists.txt Normal file
View File

@@ -0,0 +1,46 @@
set(PIM_ENABLED 1 BOOL PARENT_SCOPE)
set(PIM_SRC_ROOT "${CMAKE_CURRENT_SOURCE_DIR}")
set(PIM_BIN_ROOT "${CMAKE_CURRENT_BINARY_DIR}")
set(PIM_LIBRARY_PATH ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
set(PIM_RUNTIME_PATH ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
add_subdirectory(Dialect)
add_subdirectory(Compiler)
add_subdirectory(Conversion)
add_subdirectory(Common)
add_onnx_mlir_library(OMPIMAccel
PimAccelerator.cpp
Transforms/PimBufferizationPass.cpp
Pass/MessagePass.cpp
Pass/CountInstructionPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
LINK_LIBS PUBLIC
onnx
OMAccelerator
OMPimCompilerUtils
OMCompilerUtils
OMONNXOps
SpatialOps
PimOps
OMONNXToSpatial
OMSpatialToGraphviz
OMSpatialToPIM
OMPIMCommon
)

View File

@@ -0,0 +1,19 @@
add_onnx_mlir_library(OMPIMCommon
PIMCommon.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
LINK_LIBS PUBLIC
onnx
OMPimCompilerUtils
SpatialOps
PimOps
)

View File

@@ -0,0 +1,67 @@
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
llvm::FailureOr<Operation *> getOtherEndOfChannel(
Operation *op, bool opIsReceive, RewriterBase &rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError(
"User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation *notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
} else if (secondUser == op) {
notOpUser = firstUser;
} else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
} else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,16 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/StringRef.h"
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME =
"pim.constant.should_allocate";
namespace onnx_mlir {
llvm::FailureOr<mlir::Operation *> getOtherEndOfChannel(
mlir::Operation *op, bool opIsReceive, mlir::RewriterBase &rewriter);
} // namespace onnx_mlir

View File

@@ -0,0 +1,44 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
template <typename T>
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
map.erase(result);
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
map.erase(arg);
}
};
template <typename T>
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
NotErasableValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
}
};

View File

@@ -0,0 +1,44 @@
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
add_onnx_mlir_library(OMPimCompilerOptions
PimCompilerOptions.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
LINK_LIBS PUBLIC
${OMLibs}
OMCompilerOptions
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)
add_onnx_mlir_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimCodeGen.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
LINK_LIBS PUBLIC
${OMLibs}
OMCompilerUtils
OMPimCompilerOptions
OMCompilerPasses
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)

View File

@@ -0,0 +1,704 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
namespace onnx_mlir {
MemEntry* PimMemory::gatherMemEntry(Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, value).first;
}
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size;
// Alignment
if (size_t remainder = firstAvailableAddress % minAlignment)
firstAvailableAddress += minAlignment - remainder;
globalMemEntriesMap[value] = memEntry;
}
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
// More than one SSA value per single global constant:
// Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
llvm::SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!getGlobalOp->hasAttr("weightAlways")) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
auto iter = globalConstants.find(globalMemrefOp);
if (iter == globalConstants.end())
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
else {
MemEntry memEntry = *iter->second;
globalMemEntriesMap[getGlobalOp] = memEntry;
}
}
});
for (Value arg : funcOp.getArguments())
gatherMemEntry(arg);
allocateCore(funcOp);
}
void PimMemory::allocateCore(Operation* op) {
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
for (auto& [memEntry, value] : memEntries)
allocateMemoryForValue(value, memEntry);
}
MemEntry PimMemory::getMemEntry(Value value) const {
auto iter = globalMemEntriesMap.find(value);
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
return iter->second;
}
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second;
}
size_t PimAcceleratorMemory::getValueAddress(Value value) const {
while (true) {
auto definingOp = value.getDefiningOp();
if (!definingOp)
break;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
if (!tiedOperand)
break;
value = tiedOperand->get();
}
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto source = subviewDefiningOp.getSource();
auto srcShape = source.getType().getShape();
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
auto subviewSizes = subviewDefiningOp.getStaticSizes();
auto subviewStrides = subviewDefiningOp.getStaticStrides();
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
value = source;
}
else
break;
}
return memEntriesMap.at(value).address;
}
llvm::json::Object PimCodeGen::createSetImmediate(size_t targetRegister, size_t immediate) {
llvm::json::Object returnValue;
returnValue["op"] = "sldi";
returnValue["rd"] = targetRegister;
returnValue["imm"] = immediate;
return returnValue;
}
llvm::json::Object PimCodeGen::createEmptyOffset() {
llvm::json::Object returnValue;
returnValue["offset_select"] = 0;
returnValue["offset_value"] = 0;
return returnValue;
}
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) {
llvm::json::Object setRegisterJson = createSetImmediate(registerNumber, immediate);
coreFileStream << llvm::json::Value(std::move(setRegisterJson)) << ',';
}
void PimCodeGen::createRd(size_t rdAddress, size_t rdOffset) {
// rd on register 0
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
}
void PimCodeGen::createRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) {
// rd on register 0
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
// rs1 on register 1
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
}
void PimCodeGen::createRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) {
// rd on register 0
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
// rs1 on register 1
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
// rs2 on register 2
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
}
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) {
auto deviceDst = loadOp.getDeviceDst();
auto hostSrc = loadOp.getHostSrc();
auto deviceDstOffset = loadOp.getDeviceDstOffset();
auto hostSrcOffset = loadOp.getHostSrcOffset();
auto size = loadOp.getSize();
auto deviceDstAlloc = memory.getValueAddress(deviceDst);
auto hostSrcAlloc = memory.getValueAddress(hostSrc);
// Set load rd register (reg 0)
createRdRs1(deviceDstAlloc, deviceDstOffset, hostSrcAlloc, hostSrcOffset);
llvm::json::Object loadOpJson;
loadOpJson["op"] = "ld";
loadOpJson["rd"] = 0;
loadOpJson["rs1"] = 1;
loadOpJson["size"] = size;
loadOpJson["offset"] = createEmptyOffset();
coreFileStream << llvm::json::Value(std::move(loadOpJson)) << ',';
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) {
auto hostDst = storeOp.getHostDst();
auto deviceSrc = storeOp.getDeviceSrc();
auto hostDstOffset = storeOp.getHostDstOffset();
auto deviceSrcOffset = storeOp.getDeviceSrcOffset();
auto size = storeOp.getSize();
auto deviceSrcAlloc = memory.getValueAddress(deviceSrc);
auto hostDstAlloc = memory.getValueAddress(hostDst);
// Set load rd register (reg 0)
createRdRs1(hostDstAlloc, hostDstOffset, deviceSrcAlloc, deviceSrcOffset);
llvm::json::Object storeOpJson;
storeOpJson["op"] = "st";
storeOpJson["rd"] = 0;
storeOpJson["rs1"] = 1;
storeOpJson["size"] = size;
storeOpJson["offset"] = createEmptyOffset();
coreFileStream << llvm::json::Value(std::move(storeOpJson)) << ',';
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
auto outBufAlloc = memory.getValueAddress(mvmLikeOp.getOutBuf());
auto vectorAlloc = memory.getValueAddress(mvmLikeOp.getVectorInput());
createRdRs1(outBufAlloc, 0, vectorAlloc, 0);
llvm::json::Object mvmOpJson;
mvmOpJson["op"] = "mvmul";
mvmOpJson["rd"] = 0;
mvmOpJson["rs1"] = 1;
mvmOpJson["group"] = mvmId;
mvmOpJson["relu"] = 0;
mvmOpJson["mbiw"] = 8;
coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ',';
// TODO: save weights somewhere (if transposeMatrix=true, then transpose the
// weight matrix)
}
void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) {
auto outBuff = memory.getValueAddress(applyFiltersOp.getOutBuf());
auto inBuff = memory.getValueAddress(applyFiltersOp.getInput());
auto accumBuff = memory.getValueAddress(applyFiltersOp.getAccumBuf());
// Get weight indices from the operation attribute.
auto weightIndices = applyFiltersOp.getWeightIndices();
// Get shape of the input tensor.
auto inputType = cast<MemRefType>(applyFiltersOp.getInput().getType());
auto outputType = cast<MemRefType>(applyFiltersOp.getOutBuf().getType());
auto in_shape = inputType.getShape();
auto out_shape = outputType.getShape();
// Extract the relevant dimensions.
size_t in_channels = in_shape[1]; // Number of input channels.
size_t out_channels = out_shape[1]; // Number of output channels.
size_t dim2 = in_shape.size() > 2 ? in_shape[2] : 1; // Image width.
size_t dim3 = in_shape.size() > 3 ? in_shape[3] : 1; // Image height.
// Iterate through pixels.
for (size_t out_y = 0; out_y < dim3; out_y++) {
for (size_t out_x = 0; out_x < dim2; out_x++) {
// For each crossbar, perform the MVMUL operation.
size_t weightIndex = 0;
for (Attribute weight : weightIndices) {
// --------------------------------------
// --- STEP 1: Perform MVUL operation ---
// --------------------------------------
// Get the weight matrix ID for this position.
auto weightId = cast<IntegerAttr>(weight).getInt();
size_t xKer = cast<IntegerAttr>(applyFiltersOp.getXKernelPositions()[weightIndex]).getInt();
size_t yKer = cast<IntegerAttr>(applyFiltersOp.getYKernelPositions()[weightIndex]).getInt();
weightIndex++;
if (out_x + xKer >= dim2 || out_y + yKer >= dim3)
continue;
// Calculate the offset for the input (and output) tensor.
size_t output_offset = (out_y * dim2 + out_x) * 32 * out_channels;
size_t input_offset = ((out_y + yKer) * dim2 + (out_x + xKer)) * 32 * in_channels;
// Read from the input tensor and store the partial result in the
// accumulator buffer, if this is not the first weight matrix.
// Note that rs1 is the input tensor, and rd is the output tensor.
// TODO: This order of arguments is confusing, check if the correct
// order is being used in the WMVUL operation. The order below is
// correct.
if (weightIndices[0] != weight) {
createRdRs1(accumBuff, 0, inBuff, input_offset);
}
else {
// Otherwise store directly in the output buffer.
createRdRs1(outBuff, output_offset, inBuff, input_offset);
}
// Create the MVMUL JSON object
llvm::json::Object mvmOpJson;
mvmOpJson["op"] = "mvmul";
mvmOpJson["rd"] = 0;
mvmOpJson["rs1"] = 1;
mvmOpJson["group"] = weightId;
mvmOpJson["relu"] = 0;
mvmOpJson["mbiw"] = 8;
// Write the JSON to the output stream
coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ',';
// --------------------------------------
// --- STEP 2: Perform VADD operation ---
// --------------------------------------
// If this is the first weight matrix, we don't need to perform a VADD.
if (weightIndices[0] == weight)
continue;
// We now need to sum the value in the accumulator buffer with the value
// in the output buffer, and store the result in the output buffer.
createRdRs1Rs2(outBuff, output_offset, accumBuff, 0, outBuff, output_offset);
llvm::json::Object vaddOpJson;
vaddOpJson["op"] = "vvadd";
vaddOpJson["rd"] = 0;
vaddOpJson["rs1"] = 1;
vaddOpJson["rs2"] = 2;
vaddOpJson["offset"] = createEmptyOffset();
coreFileStream << llvm::json::Value(std::move(vaddOpJson)) << ',';
}
}
}
}
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) {
auto outBufAlloc = memory.getValueAddress(vaddOp.getOutBuf());
auto rs1BufferOp = memory.getValueAddress(vaddOp.getA());
auto rs2BufferOp = memory.getValueAddress(vaddOp.getB());
createRdRs1Rs2(outBufAlloc, 0, rs1BufferOp, 0, rs2BufferOp, 0);
// Get the size of the output buffer.
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType());
auto out_shape = outputType.getShape();
// Multiply all dimension lengths to get the total number of elements.
size_t totalElements = 1;
for (size_t i = 0; i < out_shape.size(); i++)
totalElements *= out_shape[i];
auto elementSize = vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8;
llvm::json::Object mvmOpJson;
mvmOpJson["op"] = "vvadd";
mvmOpJson["rd"] = 0;
mvmOpJson["rs1"] = 1;
mvmOpJson["rs2"] = 2;
mvmOpJson["offset"] = createEmptyOffset();
mvmOpJson["len"] = totalElements * elementSize;
coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ',';
}
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) {
auto outBufAlloc = memory.getValueAddress(vmaxOp.getOutBuf());
auto rs1BufferOp = memory.getValueAddress(vmaxOp.getA());
auto rs2BufferOp = memory.getValueAddress(vmaxOp.getB());
createRdRs1Rs2(outBufAlloc, 0, rs1BufferOp, 0, rs2BufferOp, 0);
llvm::json::Object mvmOpJson;
mvmOpJson["op"] = "vvmax";
mvmOpJson["rd"] = 0;
mvmOpJson["rs1"] = 1;
mvmOpJson["rs2"] = 2;
mvmOpJson["offset"] = createEmptyOffset();
coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ',';
}
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) {
auto outBufAlloc = memory.getValueAddress(vreluOp.getOutBuf());
auto rs1BufferOp = memory.getValueAddress(vreluOp.getA());
createRdRs1(outBufAlloc, 0, rs1BufferOp, 0);
llvm::json::Object mvmOpJson;
mvmOpJson["op"] = "vrelu";
mvmOpJson["rd"] = 0;
mvmOpJson["rs1"] = 1;
mvmOpJson["offset"] = createEmptyOffset();
coreFileStream << llvm::json::Value(std::move(mvmOpJson)) << ',';
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) {
auto destAlloc = memory.getValueAddress(receiveOp.getDst());
createRd(destAlloc, /* dest_offset = */ 0);
llvm::json::Object recvOpJson;
recvOpJson["op"] = "recv";
recvOpJson["rd"] = 0;
recvOpJson["core"] = receiveOp.getSrcCoreId();
recvOpJson["size"] = receiveOp.getSize();
recvOpJson["offset"] = createEmptyOffset();
coreFileStream << llvm::json::Value(std::move(recvOpJson)) << ',';
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) {
auto srcAlloc = memory.getValueAddress(sendOp.getSrc());
// Technically a RS1 register, but its just a name..
createRd(srcAlloc, /* dest_offset = */ 0);
llvm::json::Object sendOpJson;
sendOpJson["op"] = "send";
sendOpJson["rd"] = 0;
sendOpJson["core"] = sendOp.getTargetCoreId();
sendOpJson["size"] = sendOp.getSize();
sendOpJson["offset"] = createEmptyOffset();
coreFileStream << llvm::json::Value(std::move(sendOpJson)) << ',';
}
size_t getMatrixSize(ShapedType matrixShape) {
if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4)
assert(false && "Unsupported matrix shape");
return std::max(matrixShape.getDimSize(0), matrixShape.getDimSize(1));
}
std::string getMemorySizeAsString(size_t size) {
if (size > 1024 * 1024 * 1024)
return std::to_string(size / 1024 / 1024 / 1024) + " GB";
if (size > 1024 * 1024)
return std::to_string(size / 1024 / 1024) + " MB";
if (size > 1024)
return std::to_string(size / 1024) + " KB";
return std::to_string(size) + " Bytes";
}
int compileModuleToPIMJSON(const OwningOpRef<ModuleOp>& moduleOpRef, std::string& outputDirPath) {
ModuleOp moduleOp = moduleOpRef.get();
if (pimEmissionTarget != EmitPimCodegen) {
moduleOp.dump();
return CompilerSuccess;
}
if (!outputDirPath.empty()) {
if (auto error = llvm::sys::fs::create_directory(outputDirPath)) {
llvm::errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess;
}
}
// For each core, specify the number of crossbar per array group
// This implementation always assigns one crossbar per group
llvm::json::Object xbarsPerArrayGroup;
auto funcOps = moduleOp.getOps<func::FuncOp>();
assert(!funcOps.empty() && "No function found in the module");
auto funcOp = *funcOps.begin();
PimAcceleratorMemory memory;
memory.hostMem.allocateHost(moduleOp, funcOp);
// Write memory binary file
auto memoryFilePath = outputDirPath + "/memory.bin";
std::error_code errorCode;
llvm::raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, llvm::sys::fs::OF_None);
if (errorCode) {
llvm::errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
// Zero-initialized buffer
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
// Write global values at their allocated addresses
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (getGlobalOp->hasAttr("weightAlways"))
return;
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
if (!globalOp)
return;
auto initialValue = globalOp.getInitialValue();
if (!initialValue)
return;
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr)
return;
auto memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
auto rawData = denseAttr.getRawData();
std::memcpy(memoryBuffer.data() + memEntry.address, rawData.data(), std::min(rawData.size(), memEntry.size));
});
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
size_t coreCount = 0;
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
auto coreId = coreOp.getCoreId();
coreCount++;
std::error_code errorCode;
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
llvm::raw_fd_ostream coreFileStream(outputCorePath, errorCode);
if (errorCode) {
llvm::errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
coreFileStream << '[';
auto coreNameString = "core" + std::to_string(coreId);
PimCodeGen coreCodeGen(memory, coreFileStream);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
size_t processedOperations = 0;
for (auto& op : coreOp.getBody().front()) {
if (isa<memref::AllocOp>(op))
continue;
if (isa<pim::PimHaltOp>(op))
continue;
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
coreCodeGen.codeGenLoadOp(loadOp);
}
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
coreCodeGen.codeGenStoreOp(storeOp);
}
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
}
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op)) {
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
}
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op)) {
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
}
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op)) {
coreCodeGen.codeGenVAddOp(vaddOp);
}
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op)) {
coreCodeGen.codeGenVMaxOp(vmaxOp);
}
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op)) {
coreCodeGen.codeGenVReluOp(vreluOp);
}
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op)) {
coreCodeGen.codeGenReceiveOp(receiveOp);
}
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op)) {
coreCodeGen.codeGenSendOp(sendOp);
}
else if (auto sumOp = dyn_cast<pim::PimSumOp>(op)) {
// TODO: Implement somehow?
op.emitWarning("Sum operation is not supported");
continue;
}
else if (auto vsDivOp = dyn_cast<pim::PimVSDivOp>(op)) {
// TODO: Implement somehow?
op.emitWarning("VSDiv operation is not supported");
continue;
}
else if (auto vexpOp = dyn_cast<pim::PimVExpOp>(op)) {
// TODO: Implement somehow?
op.emitWarning("VExp operation is not supported");
continue;
}
else if (isa<memref::SubViewOp>(op)) {
continue;
}
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
return CompilerFailure;
}
processedOperations++;
}
assert(processedOperations > 0);
// Remove trailing comma
coreFileStream.seek(coreFileStream.tell() - 1);
coreFileStream << ']';
coreFileStream.close();
// Create output directory for this core's crossbar weights
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = llvm::sys::fs::create_directory(coreWeightsDirPath)) {
llvm::errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess;
}
int64_t xbarSize = crossbarSize.getValue();
size_t weightIndex = 0;
llvm::json::Array xbarsPerGroup;
for (auto weight : coreOp.getWeights()) {
xbarsPerGroup.push_back(weightIndex);
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr());
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
auto elementType = type.getElementType();
size_t elementByteWidth = elementType.getIntOrFloatBitWidth() / 8;
// Write crossbar weights as binary, padded to crossbarSize x crossbarSize
auto weightFilePath = coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin";
llvm::raw_fd_ostream weightFileStream(weightFilePath, errorCode, llvm::sys::fs::OF_None);
if (errorCode) {
llvm::errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
weightIndex++;
}
xbarsPerArrayGroup[coreNameString] = std::move(xbarsPerGroup);
}
// Step 3: Write configuration to JSON
llvm::json::Object configJson;
configJson["core_cnt"] = coreCount;
// TODO: Should this be based on the floating point type used in the model?
//// The 2 following values determine the bitwidth of the vectors' elements:
//// bitwidth = adc_count * cell_precision
// Number of ADC for MVM units
configJson["adc_count"] = 16;
// Bit precision of each ADC
configJson["cell_precision"] = 2;
//// Crossbar configuration
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
// Store the crossbar sizes
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
// Store the memory layout of inputs and outputs
llvm::json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
llvm::json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
// Step 4: Write config JSON
std::string openOutputErrorMsg;
auto configPath = outputDirPath + "/config.json";
std::error_code EC;
llvm::raw_fd_ostream jsonOS(configPath, EC);
if (EC) {
llvm::errs() << "Error while opening config file: " << EC.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << llvm::json::Value(std::move(configJson)) << '\n';
jsonOS.close();
showCompilePhase("Code generated into " + configPath);
return CompilerSuccess;
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,97 @@
#pragma once
#include "llvm/Support/JSON.h"
#include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp"
namespace onnx_mlir {
struct MemEntry {
size_t address;
size_t size;
};
class PimMemory {
SmallVector<std::pair<MemEntry, Value>, 32> memEntries;
llvm::SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(Value value);
void allocateMemoryForValue(Value value, MemEntry& memEntry);
public:
PimMemory(llvm::SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp);
void allocateCore(Operation* op);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(Value value) const ;
};
class PimAcceleratorMemory {
public:
llvm::SmallDenseMap<Value, MemEntry, 32> memEntriesMap;
PimMemory hostMem;
private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap) {}
PimMemory getOrCreateDeviceMem(size_t id);
size_t getValueAddress(Value value) const;
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream;
public:
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {}
llvm::json::Object createSetImmediate(size_t targetRegister, size_t immediate);
llvm::json::Object createEmptyOffset();
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate);
void createRd(size_t rdAddress, size_t rdOffset);
void createRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset);
void createRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset);
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp);
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp);
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
void codeGenReceiveOp(pim::PimReceiveOp receiveOp);
void codeGenSendOp(pim::PimSendOp sendOp);
void codeGenVAddOp(pim::PimVAddOp vaddOp);
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp);
void codeGenVReluOp(pim::PimVReluOp vreluOp);
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp);
};
} // namespace onnx_mlir

View File

@@ -0,0 +1,56 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===------------------------- PimCompilerOptions.cpp --------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Compiler Options for PIM
//
//===----------------------------------------------------------------------===//
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir {
llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::desc("[Optional] Choose PIM-related target to emit "
"(once selected it will cancel the other targets):"),
llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")),
llvm::cl::values(clEnumVal(EmitPim, "Lower model to PIM IR")),
llvm::cl::values(
clEnumVal(EmitPimBufferized, "Lower model to PIM IR and bufferize it")),
llvm::cl::values(clEnumVal(EmitPimCodegen, "Lower model to PIM IR and "
"generate code for PIM")),
llvm::cl::init(EmitPimCodegen), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in "
"bufferized PIM IR)"),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
llvm::cl::desc("Use experimental implementation for convolution"),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t> crossbarSize("crossbar-size",
llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t> crossbarCountInCore("crossbar-count",
llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum "
"amount of cores."),
llvm::cl::init(-1));
llvm::cl::opt<bool> ignoreConcatError("ignore-concat-error",
llvm::cl::desc(
"Ignore ConcatOp corner case: do not assert and do a simplification"),
llvm::cl::init(false));
} // namespace onnx_mlir

View File

@@ -0,0 +1,42 @@
#pragma once
#include "llvm/Support/CommandLine.h"
#define INSTRUMENTSTAGE_ENUM_PIM
#define INSTRUMENTSTAGE_CL_ENUM_PIM
#define PROFILEIR_CL_ENUM_PIM
#define OPTREPORT_ENUM_PIM
#define OPTREPORT_CL_ENUM_PIM
namespace onnx_mlir {
typedef enum {
EmitSpatial = 0,
EmitPim = 1,
EmitPimBufferized = 2,
EmitPimCodegen = 3
} PimEmissionTargetType;
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<onnx_mlir::PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> exportCrossbarWeights;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
// This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the
// wanted tile is generated by two separate operands of the concat. If this is
// set to false, this corner case will assert an error. If this is set to true,
// a simplification is performed and only the tile from the first operand is
// taken.
extern llvm::cl::opt<bool> ignoreConcatError;
} // namespace onnx_mlir

View File

@@ -0,0 +1,56 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/Passes.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "llvm/Support/JSON.h"
#include <cassert>
#include <cstddef>
#define DEBUG_TYPE "PimCompilerUtils"
using namespace mlir;
using namespace onnx_mlir;
namespace onnx_mlir {
void addPassesPim(OwningOpRef<ModuleOp>& module,
PassManager& pm,
EmissionTargetType& emissionTarget,
std::string outputNameNoExt) {
if (pimOnlyCodegen) {
// Skip all the lowering passes and directly generate code for PIM.
return;
}
if (emissionTarget >= EmitONNXIR)
addONNXToMLIRPasses(pm, /*target CPU*/ false);
if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("ONNX lowered to SPATIAL"));
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPIMPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("SPATIAL lowered to PIM"));
}
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createBufferizePimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("PIM bufferized"));
}
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,19 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
namespace onnx_mlir {
void addPassesPim(mlir::OwningOpRef<mlir::ModuleOp>& module,
mlir::PassManager& pm,
EmissionTargetType& emissionTarget,
std::string outputNameNoExt);
int compileModuleToPIMJSON(const mlir::OwningOpRef<mlir::ModuleOp>& moduleOpRef,
std::string& outputDirName);
} // namespace onnx_mlir

View File

@@ -0,0 +1,3 @@
add_subdirectory(ONNXToSpatial)
add_subdirectory(SpatialToGraphviz)
add_subdirectory(SpatialToPIM)

View File

@@ -0,0 +1,34 @@
set(LLVM_TARGET_DEFINITIONS ONNXToSpatial.td)
mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial
Math/Gemm.cpp
Math/Conv.cpp
Math/ExperimentalConv.cpp
Math/ExperimentalGemm.cpp
NN/Pooling.cpp
NN/ExperimentalPooling.cpp
NN/ReduceMean.cpp
Tensor/ONNXConcatToTensorConcat.cpp
Tensor/RemoveUnusedHelperOps.cpp
Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp
ONNXToSpatialPass.hpp
ONNXToSpatialPass.cpp
ONNXToSpatialCommon.cpp
DEPENDS
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
OMCompilerOptions
OMPimCompilerOptions
OMONNXOps
SpatialOps
OMPIMCommon
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -0,0 +1,624 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>
using namespace mlir;
using namespace std;
namespace onnx_mlir {
// NOTE:
// This might be useful to re-implement this considering for loops.
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
/**
* @brief A momentary representation of a core, to be used within the tiling of
* a convolution operation.
*/
class Core {
public:
Core(const size_t coreId, ConversionPatternRewriter &rewriter)
: coreId(coreId), rewriter(rewriter) {}
/**
* @brief Add a MVM operation to the core.
*
* @param inputTile The input tile to the MVM operation.
* @param xbarIndex The index of the crossbar weight to use.
* @param outputTileId The id of the output tile.
* @param mvmOutType The result's shape.
* @return Value The result of the MVM operation.
*/
Value addMVM(
Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
// Use the inputTile as the reference location for the MVM operation.
Location loc = inputTile.getLoc();
// Move the insertion point to the end of the block.
rewriter.setInsertionPointToEnd(block.get());
// Add the inputTile to the block arguments, and to the operands.
Value operand = operandMap.lookupOrNull(inputTile);
if (not operand) {
operand = block->addArgument(inputTile.getType(), loc);
operands.push_back(inputTile);
operandMap.map(inputTile, operand);
}
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
// is correct.
// Construct the MVM operation
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(
loc, mvmOutType, xbarIndex, operand);
// Since we are within the same core and no computation can happen in
// paralllel, we can just apply a linear reduction in case we have multiple
// MVM operations for the same outputTile.
auto lastMVM = outputTileToMVM.find(outputTileId);
// If an entry for this outputTile already exists, apply reduction.
if (lastMVM != outputTileToMVM.end()) {
// MVM results should have the same type for reduction.
assert(lastMVM->second.getType() == result.getType());
result = rewriter.create<spatial::SpatVAddOp>(
loc, mvmOutType, lastMVM->second, result);
}
outputTileToMVM[outputTileId] = result;
return result;
}
/**
* @brief Mark a result as remappable, and return a shared pointer to it.
*
* This function marks a result as remappable, and returns a shared pointer to
* it. We need to keep track of these values to generate the YieldOp at a
* later stage.
*
* @param result A result to track, for later remapping.
* @return shared_ptr<Value> A shared pointer to the result.
*/
shared_ptr<Value> makeResultRemappable(Value result) {
// Verify that the result is present in the block.
assert(result.getDefiningOp()->getBlock() == block.get());
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
resultsToRemap.push_back(remappableResult);
results.push_back(result);
return remappableResult;
}
/**
* @brief Add a remappable operand to the core, to merge partial results
* inter-core.
*
* @param remappableOperand The operand to add.
* @return Value The block argument representing the operand.
*/
Value addRemappableOperand(std::shared_ptr<Value> operand) {
// Check that the operand is not already there.
assert(not operandMap.contains(*operand));
Value argument = block->addArgument(operand->getType(), operand->getLoc());
remappableOperands.push_back(operand);
return argument;
}
/**
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
*
* @param loc The location of the operation.
* @return spatial::SpatWeightedCompute
*/
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
// Get the shape of the results.
SmallVector<Type> resultTypes;
for (const auto &value : results) {
resultTypes.push_back(value.getType());
}
// Create the WComputeOp, with non-remappable operands only.
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(
loc, resultTypes, xbarWeights, operands);
// Add the body to the WComputeOp.
Block *releasedBlock = block.release();
wcomputeOp.getBody().push_back(releasedBlock);
// Add the `yieldOp` at the end, with the results.
rewriter.setInsertionPointToEnd(releasedBlock);
rewriter.create<spatial::SpatYieldOp>(loc, results);
return wcomputeOp;
}
/**
* @brief Remap the results to the WComputeOp results.
*/
void remapResults() {
// Remap all the results to the WComputeOp results.
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
for (size_t i = 0; i < resultsToRemap.size(); i++) {
*resultsToRemap[i] = wcomputeOp.getResult(i);
}
}
void addRemappedOperands() {
// Insert the remappableOperands (which were remapped in
// `addRemappableOperand` of another Core)
for (auto remappedValue : remappableOperands) {
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
}
// Update the wcomputeOp operandSegmentSize
incrementWeightedComputeInputsSegmentSize(
wcomputeOp, static_cast<int>(remappableOperands.size()));
}
size_t addXbarWeight(Value weight) {
assert(!isXbarsFull());
xbarWeights.push_back(weight);
return xbarWeights.size() - 1;
}
bool isXbarsFull() {
assert(xbarWeights.size() <= crossbarCountInCore);
return xbarWeights.size() == crossbarCountInCore;
}
bool isCoreEmpty() { return block->empty(); }
void dump() {
// Print the coreId
llvm::outs() << "Core " << coreId << ":\n";
// Print the weights
llvm::outs() << "Xbar Weights:\n";
for (auto weight : xbarWeights) {
weight.dump();
}
// Print the operands
llvm::outs() << "Operands:\n";
for (auto operand : operands) {
llvm::outs() << operand << "\n";
}
// Dump the body block
for (auto &op : block->getOperations()) {
op.dump();
}
// Print the results
llvm::outs() << "Results:\n";
for (auto result : results) {
llvm::outs() << result << "\n";
}
}
const size_t coreId;
private:
ConversionPatternRewriter &rewriter;
// Should these be set<Value> instead? But I need to keep the order
vector<Value> operands;
vector<std::shared_ptr<Value>> remappableOperands;
vector<Value> results;
vector<std::shared_ptr<Value>> resultsToRemap;
// Maps from input tiles to the block operand
IRMapping operandMap;
// Map from outputTileId to MVM operation producing it
unordered_map<size_t, Value> outputTileToMVM;
vector<Value> xbarWeights;
unique_ptr<mlir::Block> block = make_unique<Block>();
spatial::SpatWeightedCompute wcomputeOp;
};
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {}
struct Producer_t {
Value value;
shared_ptr<Core> core;
};
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor,
ConversionPatternRewriter &rewriter) const final {
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());
size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
auto padUnpackError =
unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value()) {
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
}
// TODO: Pad value at beginning and end of each dimension could be
// different. We should handle this case.
// MapOperations mapOperation = MapOperations::None;
//
// // If we have just one user, and it is an activation funcion (or more in
// // general a mapping operation) just inline it in the computeOps
// auto firstUserOp = *conv->getUsers().begin();
// if (conv->hasOneUse()) {
// mapOperation = mlirOpToMapOperationEnum(firstUserOp);
//
// if (mapOperation == MapOperations::ONNXSoftmaxOp) {
// return rewriter.notifyMatchFailure(
// conv, "Softmax not supported as activation for convolutions.");
// }
// }
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape);
Location loc = conv.getLoc();
size_t inputTileCount =
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
size_t outputTileCount =
ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
// Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(inputTileCount,
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(convAdaptor.getX(), inputTiles,
inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
if (resolveErrorOpt.has_value()) {
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
}
SmallVector<OpFoldResult> strides =
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets =
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult>{
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
// Tile the weight tensor
// Weight tiles need to be indexed by:
// a. Filter Tile
// b. Channel Tile
// c. Kernel `x` position
// d. Kernel `y` position
// For example: weightTiles[filterTile][channelTile][x][y]
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
outputTileCount,
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
sizes = {rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
for (size_t i = 0; i < outputTileCount; i++) {
if (i == outputTileCount - 1 && outputTileRemainder != 0) {
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
}
sizes[1] = rewriter.getIndexAttr(crossbarSize);
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
for (size_t j = 0; j < inputTileCount; j++) {
if (j == inputTileCount - 1 && inputTileRemainder != 0) {
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
}
for (size_t x = 0; x < krn_w; x++) {
for (size_t y = 0; y < krn_h; y++) {
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
weightTiles[i][j][x][y] = rewriter.create<tensor::ExtractSliceOp>(
loc, convAdaptor.getW(), offsets, sizes, strides);
}
}
}
}
/* Distribute the computation among many compute cores
* Try to compute in-core the computation for each output tile, and reduce
* over as few cores as possible
*/
// Tile the output tensor
// Output tiles need to be indexed by:
// a. Filter Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[filterTile][x][y]
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
outputTileCount,
SmallVector<SmallVector<shared_ptr<Value>>>(
output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
size_t replicationFactor;
if (!conv->hasAttr(REPLICATION_ATTR_NAME)) {
replicationFactor = 1;
} else {
replicationFactor =
conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
}
// producers[outTile][out_x][out_y][producerIndex]
vector<vector<vector<vector<Producer_t>>>> producers =
vector<vector<vector<vector<Producer_t>>>>(outputTileCount,
vector<vector<vector<Producer_t>>>(output_w,
vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
// Schedule in cores
size_t coreId = 0;
vector<shared_ptr<Core>> curCores(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++) {
curCores[i] = make_shared<Core>(coreId++, rewriter);
}
vector<shared_ptr<Core>> cores;
const size_t replicationSliceSize =
ceilIntegerDivide(input_w, replicationFactor);
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
RankedTensorType mvmOutType =
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1},
bShape.getElementType());
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
if (outTile == outputTileCount - 1 && outputTileRemainder != 0) {
mvmOutType = mvmOutType.clone(
{1, static_cast<long>(outputTileRemainder), 1, 1});
}
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
vector<size_t> xbarIndexes(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++) {
xbarIndexes[i] = curCores[i]->addXbarWeight(
weightTiles[outTile][inTile][krn_x][krn_y]);
}
size_t out_x = 0;
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
size_t out_y = 0;
// I use `replicationFactor` cores. I divide the input_w into
// `replicationFactor` slices, and each slice is distributed to a
// core. `coreIndex` is the index of the core that will be used
// for this slice
size_t coreIndex = in_x / replicationSliceSize;
assert(coreIndex < replicationFactor);
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
// Adjust the input based on the kernel
int actual_in_x = in_x - ((int)krn_w / 2) + krn_x * dilation_x;
int actual_in_y = in_y - ((int)krn_h / 2) + krn_y * dilation_y;
// Check if we are within the input image
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x,
actual_in_y, pad_x, pad_y)
.failed()) {
out_y++;
continue;
}
size_t outTileId =
outTile * output_w * output_h + out_x * output_h + out_y;
auto mvm = curCores[coreIndex]->addMVM(
inputTiles[inTile][actual_in_x][actual_in_y],
xbarIndexes[coreIndex], outTileId, mvmOutType);
producers[outTile][out_x][out_y].push_back(
{mvm, curCores[coreIndex]});
out_y++;
}
out_x++;
}
// Computations for these crossbars are done, check if the cores
// crossbars are fully used. If full, swap with new core
for (size_t i = 0; i < replicationFactor; i++) {
if (curCores[i]->isXbarsFull()) {
cores.emplace_back(std::move(curCores[i]));
curCores[i] = make_shared<Core>(coreId++, rewriter);
}
}
}
}
}
}
for (auto &curCore : curCores) {
if (curCore->isCoreEmpty() == false) {
cores.emplace_back(std::move(curCore));
}
}
curCores.clear();
// Now, do the reduction of each output pixel tile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
for (size_t out_x = 0; out_x < output_w; out_x++) {
for (size_t out_y = 0; out_y < output_h; out_y++) {
// First, check if some producers are within the same core. If this is
// true, `Core::addMVM` have already done the reduction within-core.
// This means that we only need to consider the last producer for that
// core.
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
for (auto producer : producers[outTile][out_x][out_y]) {
withinCoreReducedProducers[producer.core->coreId] = producer;
}
// Now, we need to apply inter-core reduction
// Base case with one producer
if (withinCoreReducedProducers.size() == 1) {
// TODO: Add the bias and apply mapping (if present)
auto singleProducer = withinCoreReducedProducers.begin()->second;
// Use last producer as the final result
auto reducedValue =
singleProducer.core->makeResultRemappable(singleProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
continue;
}
// TODO: This is a linear reduction, not a tree reduction. We can do
// better: a tree reduction would make more computations happen in
// parallel.
Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
auto it = withinCoreReducedProducers.begin();
it++;
while (it != withinCoreReducedProducers.end()) {
Producer_t curProducer = it->second;
shared_ptr<Core> core1;
shared_ptr<Core> core2;
Value core1Value;
Value core2Value;
auto lastProducerCoreId = lastProducer.core->coreId;
auto curProducerCoreId = curProducer.core->coreId;
assert(lastProducerCoreId != curProducerCoreId &&
"We should have already applied within-core reduction, how "
"could we have same cores here?");
// Sort the cores by coreId
if (curProducerCoreId < lastProducerCoreId) {
core1 = curProducer.core;
core1Value = curProducer.value;
core2 = lastProducer.core;
core2Value = lastProducer.value;
} else {
core1 = lastProducer.core;
core1Value = lastProducer.value;
core2 = curProducer.core;
core2Value = curProducer.value;
}
auto newCoreRes = core1->makeResultRemappable(core1Value);
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
rewriter.setInsertionPointAfterValue(core2Value);
Value vaddRes =
rewriter.create<spatial::SpatVAddOp>(core2Value.getLoc(),
core2Value.getType(), core2Value, secondCoreBlockArg);
lastProducer = {vaddRes, core2};
it++;
}
// TODO: Add the bias and apply mapping (if present)
// Use last producer as the final result
auto reducedValue =
lastProducer.core->makeResultRemappable(lastProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
}
}
}
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
rewriter.setInsertionPointAfter(conv);
spatial::SpatWeightedCompute lastWComputeOp;
for (auto &core : cores) {
lastWComputeOp = core->createWComputeOp(loc);
core->remapResults();
rewriter.setInsertionPointAfter(lastWComputeOp);
}
for (auto &core : cores) {
core->addRemappedOperands();
}
// Set the insertion point after the last WComputeOp.
rewriter.setInsertionPointAfter(lastWComputeOp);
SmallVector<Value> tilesToConcat;
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
for (size_t outX = 0; outX < output_h; outX++)
for (size_t outY = 0; outY < output_w; outY++)
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(
loc, conv.getY().getType(), tilesToConcat);
// Value outputImage =
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
// If no mapping (activation) was applied, just replace ConvOp
// if (mapOperation == MapOperations::None) {
// rewriter.replaceOp(conv, outputImage);
// } else {
// // If mapping was applied, erase ConvOp and replace the mapping op
// rewriter.eraseOp(conv);
// rewriter.replaceOp(firstUserOp, outputImage);
// }
return success();
}
};
void populateTilingConvOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,430 @@
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cstddef>
#include <unistd.h>
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
* @brief A pattern to tile the convolution operation into a series of compute
* units, each one of which applies filters to a subset of the input
* tensor. Results are also reduced and concatenated to form the final
* output tensor.
*/
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ExperimentalONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor,
ConversionPatternRewriter &rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
// --------------------------------- //
// To get each crossbar's weights, we need to slice the weights tensor.
// - Along the input tiles.
// - Along the output tiles.
// - Along the filter x position.
// - Along the filter y position.
ShapedType inputType = cast<ShapedType>(convAdaptor.getX().getType());
ShapedType outputType = cast<ShapedType>(conv.getY().getType());
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
// TODO: Address bigger batches.
assert(GET_IMAGE_N(inputType) == 1 && "Batch size must be 1"
"for convolution.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 &&
"Replication is not yet supported for convolution.");
// TODO: Address bias addition.
ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize);
ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize);
size_t kernelWidth = GET_KERNEL_WIDTH(weightsType);
size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType);
// Assert that the kernel is square.
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
// -------------------------------- //
// --- SLICE THE WEIGHTS TENSOR --- //
// -------------------------------- //
// The core idea of this stage is classifying the weights by input and
// output tile. This is because we want the applyFilters operations to be
// tile agnostic, to keep the subsequent lowering stages as simple as
// possible. This data structure does this weight classification:
// - The outer map is indexed by input tile.
// - The inner map is indexed by output tile.
// - The SmallVector contains the weights for the filter.
map<long, map<long, SmallVector<Value>>> weightsGroups;
// During all slicing operations within this stage, we'll use the same
// strides for all dimensions.
SmallVector<OpFoldResult> slicingStrides(4, rewriter.getIndexAttr(1));
ldiv_t itc = inputTileCount;
ldiv_t otc = outputTileCount;
// - Slicing along the input tiles.
// - Slicing along the output tiles.
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
// The loop above also sets the crossbar's used width and height,
// checking if we're at the last crossbar and if it's incomplete.
long outputTile = ot;
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(filterX),
/* 3 */ rewriter.getIndexAttr(filterY)};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes,
slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
}
}
}
}
// TODO: Tree reduction for compute reduction should be implemented.
// -------------------------------- //
// --- CREATE ALL COMPUTE UNITS --- //
// -------------------------------- //
// Keep track of input slicing operations to avoid duplication across
// all compute units (global slices).
map<long, Value> globalSlices;
// Keep track of all partial compute results.
map<long, Value> globalPartialResults;
// Use a weight subdivider to extract groups of weights for each compute
// unit. We'll keep extracting groups until no more weights are left.
WeightSubdivider weightSubdivider(weightsGroups);
while (!weightSubdivider.isEmpty()) {
// -------------------------------- //
// --- BEGIN A NEW COMPUTE UNIT --- //
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups =
weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
// ------------------------------ //
// --- SLICE THE INPUT TENSOR --- //
// ------------------------------ //
// Note each tile's index in the compute unit arguments.
map<long, size_t> inputTileIndices;
map<long, size_t> outputTileIndices;
map<long, size_t> reductionTileIndices; // Incoming partial results.
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights) {
computeWeights.push_back(weight);
}
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end()) {
continue;
}
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
// new one.
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
computeOperands.push_back(globalSlices[group.inputTile]);
localSlices[group.inputTile] = globalSlices[group.inputTile];
continue;
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(0),
/* 3 */ rewriter.getIndexAttr(0)};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot
? inputTileCount.rem
: crossbarSize;
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes,
slicingStrides);
computeOperands.push_back(extractSliceOp);
// Update slicing maps.
globalSlices[group.inputTile] = extractSliceOp;
localSlices[group.inputTile] = extractSliceOp;
// Update the input tile index.
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
}
// ------------------------------- //
// --- PREPARE THE OUTPUT TYPE --- //
// ------------------------------- //
// Fill the compute output's type by looking at the output tiles.
SmallVector<Type> computeOutputType;
for (TaggedWeights group : weightsGroups) {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) !=
outputTileIndices.end()) {
continue;
}
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) !=
globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot
? outputTileCount.rem
: crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
auto elementType =
dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
computeOutputType.push_back(
RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
// ----------------------------- //
// --- FILL THE COMPUTE UNIT --- //
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute =
rewriter.create<spatial::SpatWeightedCompute>(conv.getLoc(),
computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands) {
block->addArgument(operand.getType(), conv->getLoc());
}
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices) {
localPartialResults[reductionTileIndex.first] =
block->getArgument(reductionTileIndex.second);
}
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType =
computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument =
block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i) {
weightIndices.push_back(group.startingCrossbarIndex + i);
}
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
for (auto weight : group.weights) {
// Assert that the weight is an extract_slice operation.
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
assert(extractSliceOp && "Weight is not an extract_slice operation.");
// Get the filter x and y positions from the extract_slice operation.
auto offsets = extractSliceOp.getStaticOffsets();
xKerPos.push_back(offsets[2]);
yKerPos.push_back(offsets[3]);
}
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result =
rewriter.create<spatial::SpatApplyFiltersOp>(conv.getLoc(), outputType,
weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) !=
localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(conv.getLoc(),
result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
localPartialResults[group.outputTile] = result;
}
// Add a yield operation to the block by concatenating the partial
// results.
SmallVector<Value> applyFiltersResults;
for (size_t i = 0; i < computeOutputType.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
// Get that tile's partial result and add it to the list.
applyFiltersResults.push_back(localPartialResults[outputTile]);
}
// Create the yield operation with the given results.
rewriter.create<spatial::SpatYieldOp>(conv.getLoc(), applyFiltersResults);
// Update the global partial results map.
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
globalPartialResults[outputTile] = currentCompute.getResult(i);
}
// Move the rewrite cursor out of the block.
rewriter.setInsertionPointAfter(currentCompute);
}
// ------------------------------ //
// --- CONCATENATE THE OUTPUT --- //
// ------------------------------ //
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0);
++i) {
outputValues.push_back(globalPartialResults[i]);
}
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 &&
"No output values were generated for the convolution.");
// If the conv's user is a ReLU...
if (conv->hasOneUse()) {
Operation *user = *conv->getUsers().begin();
if (auto relu = dyn_cast<ONNXReluOp>(user)) {
// ...then we can just replace the ReLU with the concatenation.
rewriter.replaceOp(relu,
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
// And erase the convolution.
rewriter.eraseOp(conv);
return success();
}
}
// Return the final output.
rewriter.replaceOp(conv,
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
return success();
}
};
/**
* @brief Populate the tiling pattern for a convolution operation.
*
* @param patterns The pattern set to populate.
* @param ctx The MLIR context.
*/
void populateExperimentalTilingConvOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ExperimentalONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,400 @@
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include <cstdlib>
using namespace mlir;
using namespace std;
namespace onnx_mlir {
struct ExperimentalGemmConversionPattern
: public OpConversionPattern<ONNXGemmOp> {
ExperimentalGemmConversionPattern(MLIRContext *ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
// --------------------------------- //
// To get each crossbar's weights, we need to slice the weights tensor.
// - Along the input tiles.
// - Along the output tiles.
// - Along the filter x position.
// - Along the filter y position.
ShapedType inputType = cast<ShapedType>(adaptor.getA().getType());
ShapedType outputType = cast<ShapedType>(gemmOp.getY().getType());
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
// TODO: Address bigger batches.
assert(inputType.getShape()[0] == 1 &&
"Only batch size of 1 is supported for GEMM.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 &&
"Replication is not yet supported for GEMM.");
// TODO: Address bias addition.
assert(inputType.getShape()[1] == matrixType.getShape()[0] &&
"Input tile size must match the matrix's row size.");
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
size_t kernelWidth = 1;
size_t kernelHeight = 1;
// Assert that the kernel is square.
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
// -------------------------------- //
// --- SLICE THE WEIGHTS TENSOR --- //
// -------------------------------- //
// The core idea of this stage is classifying the weights by input and
// output tile. This is because we want the applyFilters operations to be
// tile agnostic, to keep the subsequent lowering stages as simple as
// possible. This data structure does this weight classification:
// - The outer map is indexed by input tile.
// - The inner map is indexed by output tile.
// - The SmallVector contains the weights for the filter.
map<long, map<long, SmallVector<Value>>> weightsGroups;
// During all slicing operations within this stage, we'll use the same
// strides for all dimensions.
SmallVector<OpFoldResult> slicingStrides(2, rewriter.getIndexAttr(1));
ldiv_t itc = inputTileCount;
ldiv_t otc = outputTileCount;
// - Slicing along the input tiles.
// - Slicing along the output tiles.
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
// The loop above also sets the crossbar's used width and height,
// checking if we're at the last crossbar and if it's incomplete.
long outputTile = ot;
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ /* rewriter.getIndexAttr(1), */
/* 3 */ /* rewriter.getIndexAttr(1) */};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(filterX), */
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes,
slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
}
}
}
}
// TODO: Tree reduction for compute reduction should be implemented.
// -------------------------------- //
// --- CREATE ALL COMPUTE UNITS --- //
// -------------------------------- //
// Keep track of input slicing operations to avoid duplication across
// all compute units (global slices).
map<long, Value> globalSlices;
// Keep track of all partial compute results.
map<long, Value> globalPartialResults;
// Use a weight subdivider to extract groups of weights for each compute
// unit. We'll keep extracting groups until no more weights are left.
WeightSubdivider weightSubdivider(weightsGroups);
while (!weightSubdivider.isEmpty()) {
// -------------------------------- //
// --- BEGIN A NEW COMPUTE UNIT --- //
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups =
weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
// ------------------------------ //
// --- SLICE THE INPUT TENSOR --- //
// ------------------------------ //
// Note each tile's index in the compute unit arguments.
map<long, size_t> inputTileIndices;
map<long, size_t> outputTileIndices;
map<long, size_t> reductionTileIndices; // Incoming partial results.
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights) {
computeWeights.push_back(weight);
}
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end()) {
continue;
}
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
// new one.
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
computeOperands.push_back(globalSlices[group.inputTile]);
localSlices[group.inputTile] = globalSlices[group.inputTile];
continue;
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(0), */
/* 3 */ /* rewriter.getIndexAttr(0) */};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot
? inputTileCount.rem
: crossbarSize;
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
// Create the slice extraction operation.
auto extractSliceOp =
rewriter.create<tensor::ExtractSliceOp>(gemmOp.getLoc(),
adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
// Update slicing maps.
globalSlices[group.inputTile] = extractSliceOp;
localSlices[group.inputTile] = extractSliceOp;
// Update the input tile index.
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
}
// ------------------------------- //
// --- PREPARE THE OUTPUT TYPE --- //
// ------------------------------- //
// Fill the compute output's type by looking at the output tiles.
SmallVector<Type> computeOutputType;
for (TaggedWeights group : weightsGroups) {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) !=
outputTileIndices.end()) {
continue;
}
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) !=
globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot
? outputTileCount.rem
: crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType())
.getElementType();
computeOutputType.push_back(
RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
// ----------------------------- //
// --- FILL THE COMPUTE UNIT --- //
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute =
rewriter.create<spatial::SpatWeightedCompute>(gemmOp.getLoc(),
computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands) {
block->addArgument(operand.getType(), gemmOp->getLoc());
}
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices) {
localPartialResults[reductionTileIndex.first] =
block->getArgument(reductionTileIndex.second);
}
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType =
computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument =
block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i) {
weightIndices.push_back(group.startingCrossbarIndex + i);
}
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
for (auto weight : group.weights) {
// Assert that the weight is an extract_slice operation.
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
assert(extractSliceOp && "Weight is not an extract_slice operation.");
// Get the filter x and y positions from the extract_slice operation.
xKerPos.push_back(0);
yKerPos.push_back(0);
}
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(gemmOp.getLoc(),
outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr,
blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) !=
localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(gemmOp.getLoc(),
result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
localPartialResults[group.outputTile] = result;
}
// Add a yield operation to the block by concatenating the partial
// results.
SmallVector<Value> applyFiltersResults;
for (size_t i = 0; i < computeOutputType.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
// Get that tile's partial result and add it to the list.
applyFiltersResults.push_back(localPartialResults[outputTile]);
}
// Create the yield operation with the given results.
rewriter.create<spatial::SpatYieldOp>(gemmOp.getLoc(), applyFiltersResults);
// Update the global partial results map.
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
globalPartialResults[outputTile] = currentCompute.getResult(i);
}
// Move the rewrite cursor out of the block.
rewriter.setInsertionPointAfter(currentCompute);
}
// ------------------------------ //
// --- CONCATENATE THE OUTPUT --- //
// ------------------------------ //
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0);
++i) {
outputValues.push_back(globalPartialResults[i]);
}
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 &&
"No output values were generated for the GEMM operation.");
// Return the final output.
rewriter.replaceOp(gemmOp,
rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
return success();
}
};
void populateGemmToConvConversionPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ExperimentalGemmConversionPattern>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,317 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
ONNXGemmOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location gemmLoc = gemmOp.getLoc();
Value a = adaptor.getA();
Value b = adaptor.getB();
Value c = adaptor.getC();
Value out = gemmOp.getY();
float alpha = adaptor.getAlpha().convertToFloat();
float beta = adaptor.getBeta().convertToFloat();
bool transA = adaptor.getTransA();
bool transB = adaptor.getTransB();
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(out.getType());
RankedTensorType cType = nullptr;
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
assert("Only support 2 tensor for C" && cType.getRank() == 2);
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (transA) {
auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
}
if (transB) {
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
}
if (alpha != 1.0f) {
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue);
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor);
}
if (hasC && beta != 1.0f) {
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue);
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor);
}
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
auto bNumVSlices = aNumHSlices;
auto bLastVSliceSize = aLastHSliceSize;
auto cNumHSlices = bNumHSlices;
auto cLastHSliceSize = bLastHSliceSize;
auto outNumHSlices = cNumHSlices;
auto outLastHSliceSize = cLastHSliceSize;
const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
SmallVector<Value> cHSlices;
if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
if (hasC)
cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
RankedTensorType outHSliceType =
RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
RankedTensorType outLastHSliceType =
RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());
SmallVector<Value> outHSlices;
outHSlices.reserve(outNumHSlices);
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
RankedTensorType currOutHSliceType = outHSliceType;
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
currOutHSliceType = outLastHSliceType;
SmallVector<Value> partialResults;
partialResults.reserve(coresPerVSlice);
for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
SmallVector<Value> weights;
weights.reserve(aHSlices[coreId].size());
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
auto* computeBlock = new Block();
for (auto aHSlice : aHSlices[coreId])
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
auto computeArgs = computeBlock->getArguments();
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(computeArgs.size());
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
vmmOutputs.push_back(
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp);
partialResults.push_back(computeOp.getResult(0));
}
if (hasC) {
Value cHSlice = cHSlices[outSliceId];
partialResults.push_back(cHSlice);
}
auto reduceComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
auto* reduceBlock = new Block();
for (auto partialResult : partialResults)
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
reduceComputeOp.getBody().push_back(reduceBlock);
rewriter.setInsertionPointToStart(reduceBlock);
auto blockArgs = reduceBlock->getArguments();
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice);
rewriter.setInsertionPointAfter(reduceComputeOp);
outHSlices.push_back(reduceComputeOp.getResult(0));
}
rewriter.setInsertionPoint(gemmOp);
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices);
rewriter.replaceOp(gemmOp, concatOp);
return success();
}
private:
/**
* Resolves the ONNXExpOp from the use chain of the given start value.
*
* This function traverses the use chain of the start value until it finds an
* ONNXExpOp. It returns the value of the ONNXExpOp.
*
* @param startValue The starting value of the use chain.
* @return The value of the ONNXExpOp found in the use chain.
*/
static Value resolveONNXExpOpFromUseChain(Value startValue) {
Value walker = startValue;
while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
walker = walker.getDefiningOp()->getOperand(0);
assert(walker && walker.getDefiningOp()
&& "Unwinded the whole chain of operations while trying to "
"find ONNXExpOp, but did not find it");
}
// Make sure the dividend is actually produced by an ONNXExpOp
assert(llvm::isa<ONNXExpOp>(walker.getDefiningOp())
&& "Old output tile (softmax reducer) is not produced by an "
"ONNXExpOp");
return walker;
}
// Softmax is a special case, as it requires another reduction after the
// first one. In the cores, `applyReducePattern` already applied
// f(x) = exp(x) to each tile. This mean that now we just need to
// reduce-sum these tiles, and then divide each tile by the reduced sum,
// which is propagated back to the cores via a broadcast channel.
LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
Value& softmaxChannel,
ConversionPatternRewriter& rewriter,
SpatialReducer& reducer,
ONNXGemmOp& gemmOp,
Location& loc) const {
// TODO: Check case with one compute op
// Cast vector of Value into vector of ComputeOp
SmallVector<ComputeAndResNum> softmaxOpsToReduce =
llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) {
return std::make_pair(cast<spatial::SpatWeightedCompute>(computeAndResNum.first), computeAndResNum.second);
}));
RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr);
const TensorType scalarTensorType = tensorTypeBuilder;
reducer.applyReducePattern(
softmaxOpsToReduce,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); },
/* preprocess = */
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); },
[&](Value softmaxDivisor) {
// Signal that this is the compute with the softmax divisor
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());
// Broadcast the divisor to all the cores
rewriter.setInsertionPointAfterValue(softmaxDivisor);
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor);
/*
* softmaxDividend = onnx.exp (...)
* sum = spat.SumOp(softmaxDividend)
* [following can be repeated N times, thus walk the use chain]
* softmaxDivisor = spat.sadd(sum, ...)
*/
Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));
// Make sure the dividend is actually produced by an ONNXExpOp
assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
&& "Dividend of softmax reduction is not an ONNXExpOp");
// Do not divide here, divide after this
return softmaxDivisor;
});
// In all the cores, insert a ChannelRecvOp and divide the output tile by
// the reduced denominator.
outputOpsAndResNums.clear();
outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {
auto yieldOp = cast<spatial::SpatYieldOp>(computeToDivideOpAndResNum.first.getBody().front().getTerminator());
Value divisor;
// Check if this compute contains the softmax divisor: if so, find the
// ChannelBroadcastSendOp, otherwise receive the value from the channel
// using ChannelBroadcastReceiveOp
if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {
bool found = false;
for (auto broadcastOp :
computeToDivideOpAndResNum.first.getBody().front().getOps<spatial::SpatChannelBroadcastSendOp>()) {
assert(found == false
&& "More than one ChannelBroadcastSendOp in "
"compute? How is this possible?");
found = true;
divisor = broadcastOp.getData();
}
assert(found
&& "No ChannelBroadcastSendOp in compute where softmax "
"divisor was specified to be?");
}
else {
rewriter.setInsertionPoint(yieldOp);
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel);
}
// Walk the chain of operations until we find the ONNXExpOp: this is
// needed because some some may have a different amount of `VAddOp`s due
// to the tree reduction (e.g. some may have no VAddOp, some may have
// multiples)
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
rewriter.setInsertionPoint(yieldOp);
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor);
auto yieldOperandNum = yieldOp->getNumOperands();
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
}
return success();
}
};
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXGemmOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,327 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
using namespace mlir;
namespace onnx_mlir {
template <typename PoolOp>
bool hasPostProcessExperimentalPoolingWindow() {
return false;
}
template <>
bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter &rewriter,
Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(
ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp,
Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber =
countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor =
RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(
valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(
loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename ReductionOp>
Value reduceInputTiles(
SmallVector<Value> &inputTiles, ConversionPatternRewriter &rewriter) {
if (inputTiles.size() == 1) {
return inputTiles[0];
}
if (inputTiles.size() == 2) {
return rewriter.create<spatial::SpatVMaxOp>(inputTiles[0].getLoc(),
inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
}
SmallVector<Value> left(
inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
SmallVector<Value> right(
inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
return rewriter.create<ReductionOp>(
inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
ExperimentalPoolingBaseConverter(MLIRContext *ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET") {
return rewriter.notifyMatchFailure(
poolOp, "auto_pad != NOTSET is deprecated.");
}
size_t pad_x, pad_y;
auto padUnpackError =
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value()) {
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
}
Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize);
// Assert that the input is a tensor.ConcatOp.
auto concat = X.getDefiningOp<tensor::ConcatOp>();
if (!concat) {
return rewriter.notifyMatchFailure(
poolOp, "Expected input to be a tensor.ConcatOp");
}
// Create a [channel_tile][x][y] array to store the input tiles.
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
// For each argument of the tensor.ConcatOp, resolve the input tiles.
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
size_t tilingSize =
it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
/* 1 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(x),
/* 3 */ rewriter.getIndexAttr(y)};
SmallVector<OpFoldResult> sizes = {
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// Get the concat's operand that we want to slice.
Value concatInput = concat.getOperand(it);
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(
loc, concatInput, offsets, sizes, strides);
inputTiles[it][x][y] = slicedTile;
}
}
}
// Prepare the shape of the compute's output.
ldiv_t itc = tileCount;
SmallVector<Type> outputTileTypes;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */
cast<RankedTensorType>(inputTiles[it][0][0].getType())
.getShape()[1],
/* 2 */ 1,
/* 3 */ 1};
auto elementType =
dyn_cast<RankedTensorType>(xShape).getElementType();
outputTileTypes.push_back(
RankedTensorType::get(outputShapeArray, elementType));
}
}
}
// Create a plain value list of the input tiles.
SmallVector<Value> inputTilesList;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
inputTilesList.push_back(inputTiles[it][y][x]);
}
}
}
// Create a single compute to calculate the output.
auto computeOp = rewriter.create<spatial::SpatWeightedCompute>(
loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
// Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&computeOp.getRegion());
// Fill the block arguments and keep a reference to them.
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) +
x * (itc.quot + (itc.rem > 0)) + it;
inputTilesArgs[it][y][x] = block->addArgument(
computeOp->getOperand(tileIndex).getType(), loc);
}
}
}
// Begin writing in the block.
rewriter.setInsertionPointToStart(block);
// Go through all pooling blocks.
SmallVector<Value> outputTiles;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
size_t start_x = x * stride_x;
size_t start_y = y * stride_y;
size_t end_x = std::min(start_x + krn_w, input_w);
size_t end_y = std::min(start_y + krn_h, input_h);
SmallVector<Value> inputTilesToReduce;
for (size_t ky = start_y; ky < end_y; ++ky) {
for (size_t kx = start_x; kx < end_x; ++kx) {
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
}
}
auto reduceResult =
reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
// If the reduce op is add, we need to divide the result by the
// number of elements in the pooling window.
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
// Add a spat.const before the computeOp.
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
RankedTensorType::get({1}, rewriter.getF32Type()),
rewriter.getI64IntegerAttr(krn_w * krn_h),
rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
reduceResult = rewriter.create<spatial::SpatVSDivOp>(
loc, reduceResult.getType(), reduceResult, divisorValue);
}
outputTiles.push_back(reduceResult);
}
}
}
// Create a YieldOp to return the output tiles.
rewriter.create<spatial::SpatYieldOp>(loc, outputTiles);
// Set the rewrite cursor right after the computeOp.
rewriter.setInsertionPointAfter(computeOp);
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> computeOutput;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) +
x * (itc.quot + (itc.rem > 0)) + it;
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
}
}
}
// We'll now create spat.img.concat ops to concatenate the output tiles.
SmallVector<Value> outputTilesList;
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<Value> imgConcatTiles;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
imgConcatTiles.push_back(computeOutput[it][y][x]);
}
}
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */ (long)tilingSize,
/* 2 */ (long)output_w,
/* 3 */ (long)output_h};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(loc,
RankedTensorType::get(outputShapeArray, elementType),
imgConcatTiles));
}
// Create a new tensor.ConcatOp to concatenate the output tiles.
Value outputTensor =
rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
rewriter.replaceOp(poolOp, outputTensor);
return success();
}
};
void populateExperimentalPoolingTilingPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp,
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp,
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,452 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
using namespace mlir;
namespace onnx_mlir {
llvm::SmallPtrSet<Operation *, 16> oldComputeOpsReplaced;
Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
ConversionPatternRewriter &rewriter,
std::function<Value(const Value &, const Value &)> reduce,
std::function<Value(const Value &)> preprocess,
std::function<Value(const Value &)> postprocess) {
// Simple case: if we have only one input, just return it
if (valuesToReduce.size() == 1) {
return valuesToReduce[0];
}
if (preprocess) {
for (auto &valToReduce : valuesToReduce) {
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = preprocess(valToReduce);
}
}
// It is possible that `valuesToReduce` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation *, Value> lastValueForCompute;
for (auto &valToReduce : valuesToReduce) {
Operation *computeOp = valToReduce.getParentBlock()->getParentOp();
// if (valToReduce.getDefiningOp()) {
// // If the value is defined by an operation, we take the parent
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
// } else {
// // Otherwise it is a block argument,
// computeOp->getBlock()->getParentOp();
// }
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
auto it = lastValueForCompute.find(computeOp);
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
if (valToReduce.getDefiningOp()->isBeforeInBlock(
lastWithinComputeValue.getDefiningOp())) {
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
} else {
rewriter.setInsertionPointAfterValue(valToReduce);
}
valToReduce = reduce(lastWithinComputeValue, valToReduce);
lastValueForCompute[computeOp] = valToReduce;
}
lastValueForCompute[computeOp] = valToReduce;
}
// Now, reconstruct from the map the valuesToReduce list
valuesToReduce.clear();
valuesToReduce.reserve(lastValueForCompute.size());
for (auto &entry : lastValueForCompute) {
valuesToReduce.push_back(entry.second);
}
Location loc = valuesToReduce[0].getLoc();
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
// the valuesToReduce list which becomes half the size.
// - Repeat until there is only one input left.
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
while (valuesToReduceRef.size() > 1) {
SmallVector<Value> nextValuesToReduce;
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
auto firstValue = valuesToReduceRef[i];
auto secondValue = valuesToReduceRef[i + 1];
auto firstCompute = firstValue.getParentBlock()->getParentOp();
auto secondCompute = secondValue.getParentBlock()->getParentOp();
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
if (secondCompute->isBeforeInBlock(firstCompute)) {
std::swap(firstValue, secondValue);
std::swap(firstCompute, secondCompute);
}
// 1. Add a channel before the first computeOp
rewriter.setInsertionPoint(firstCompute);
auto channel = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
// 2. Add a sendOp after the first value
rewriter.setInsertionPointAfterValue(firstValue);
rewriter.create<spatial::SpatChannelSendOp>(loc, channel, firstValue);
// 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(
loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue);
Value reduced = reduce(receivedValue, secondValue);
nextValuesToReduce.push_back(reduced);
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (valuesToReduceRef.size() % 2 == 1) {
nextValuesToReduce.push_back(valuesToReduceRef.back());
}
// Replace the inputOps list with the new one.
valuesToReduceRef =
llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
}
assert(valuesToReduceRef.size() == 1 &&
"Internal error: expected a single input at this point.");
auto finalValue = valuesToReduceRef[0];
if (postprocess) {
rewriter.setInsertionPointAfterValue(finalValue);
finalValue = postprocess(finalValue);
}
return finalValue;
}
template <typename PoolOp>
bool hasPostProcessPoolingWindow() {
return false;
}
template <>
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessPoolingWindow(ConversionPatternRewriter &rewriter,
Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessPoolingWindow<ONNXAveragePoolOp>(
ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp,
Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber =
countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor =
RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(
valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(
loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
PoolingBaseConverter(MLIRContext *ctx) : OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET") {
return rewriter.notifyMatchFailure(
poolOp, "auto_pad != NOTSET is deprecated.");
}
size_t pad_x, pad_y;
auto padUnpackError =
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value()) {
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
}
Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t channelTileCount =
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
// 1: Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
// Suppose that the input tensor is produced by concatenating the results of
// many ComputeOps. Get the result tiles from these ComputeOps.
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(channelTileCount,
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(X, inputTiles, channelTileCount,
channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value()) {
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
}
// TODO: This requires a core for each input tile, which is not ideal. We
// can do better.
// If some input tiles come from the func.func operands, load
// them into a computeOp and yield them
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
if (auto extractSliceOp =
inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
tileLoc, extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(), extractSliceOp.getResult());
Block *tempComputeOpBlock = new Block();
tempComputeOp.getBody().push_back(tempComputeOpBlock);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(
extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock);
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
rewriter.setInsertionPointAfter(tempComputeOp);
inputTiles[t][x][y] = tempComputeOp.getResult(0);
}
}
}
}
// 2: Tile the output tensor
// Output tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[channelTile][x][y]
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(
output_w, SmallVector<Value>(output_h, nullptr)));
// List of values to pool for each output pixel
SmallVector<Value> valuesToPool;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
// Each output pixel tile is computed by pooling a window of input
// pixel tiles
valuesToPool.clear();
size_t tilesSkippedByPadding = 0;
auto [start_x, end_x] = kernel_get_start_and_end(
outX, input_w, krn_w, stride_x, dilation_x, pad_x);
auto [start_y, end_y] = kernel_get_start_and_end(
outY, input_h, krn_h, stride_y, dilation_y, pad_y);
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
if (failed(verifyWithinBoundsAndPaddings(
input_w, input_h, inX, inY, pad_x, pad_y))) {
tilesSkippedByPadding++;
continue;
}
Value inputTile = inputTiles[outTile][inX][inY];
Value valueToPool;
if (auto computeProducer =
inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
int resultNumber = getResultIndex(computeProducer, inputTile);
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(
computeProducer.getBody().front().getTerminator());
valueToPool = yieldInComputeOp.getOperand(resultNumber);
} else if (auto receiveProducer =
inputTile
.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
auto sendOpOpt =
getOtherEndOfChannel(receiveProducer, true, rewriter);
if (failed(sendOpOpt)) {
return rewriter.notifyMatchFailure(poolOp,
"ChannelReceiveOp does not have a matching "
"ChannelSendOp.");
}
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
valueToPool = sendOp.getData();
} else {
return rewriter.notifyMatchFailure(poolOp,
"Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp");
}
valuesToPool.push_back(valueToPool);
}
}
assert(valuesToPool.size() != 0 &&
"Pooling computed on zero tiles make no sense.");
// assert(computeOpsForPooling.size() != 1 &&
// "Pooling computed on one tiles make no sense??? Or maybe
// this " "should have been simplified earlier???");
std::function<Value(const Value &)> postProcessFn = nullptr;
if (hasPostProcessPoolingWindow<PoolOp>()) {
postProcessFn = [&](const Value prevFinalRes) {
return postProcessPoolingWindow(rewriter, loc, poolOp,
prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
};
}
Value reducedWithinCompute = applyReducePatternNew(
valuesToPool, rewriter,
[&](const Value lhs, const Value rhs) {
return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs);
},
nullptr, postProcessFn);
// Send this value through a channel, and receive it in the
// `func.func`. During lowering, we will need to "move it" into the
// users computeOps
auto computeOpOfReduced = cast<spatial::SpatWeightedCompute>(
reducedWithinCompute.getDefiningOp()->getParentOp());
// Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel = rewriter.create<spatial::SpatChannelNewOp>(
loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
rewriter.create<spatial::SpatChannelSendOp>(
loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(
loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue;
}
}
}
// TODO: outputTiles are not the results of the computeOps! We need to add
// them!
std::unordered_map<Operation *,
SmallVector<std::tuple<size_t, size_t, size_t, Value>>>
computeOpNeedingResults;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
auto outputTile = outputTiles[outTile][outX][outY];
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
if (!outputTileProducer) {
return rewriter.notifyMatchFailure(poolOp,
"Output tile for Pooling is not produced by a "
"WeightedComputeOp.");
}
computeOpNeedingResults[outputTileProducer].push_back(
std::make_tuple(outTile, outX, outY, outputTile));
}
}
}
Value outputImage =
createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
rewriter.replaceOp(poolOp, outputImage);
return success();
}
};
void populatePoolingTilingPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp,
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp,
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,90 @@
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
struct ReduceMeanConversionPattern
: public OpConversionPattern<ONNXReduceMeanV13Op> {
ReduceMeanConversionPattern(MLIRContext *ctx) : OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean,
ONNXReduceMeanV13OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Get the input tensor.
Value inputTensor = adaptor.getData();
auto inputTensorType = cast<RankedTensorType>(inputTensor.getType());
// This pattern will substitute the ONNXReduceMeanV13Op with a
// ONNXAveragePoolOp with the same input tensor and an appropriate kernel
// shape and strides.
// To get the stride and shape of the kernel, we need to read the tensor
// shape.
int image_height = inputTensorType.getShape()[2];
int image_width = inputTensorType.getShape()[3];
// Define the kernel shape and strides.
SmallVector<int64_t> kernelShapeVals = {image_height, image_width};
SmallVector<int64_t> stridesVals = {image_height, image_width};
SmallVector<int64_t> dilationsVals = {1, 1};
// Set the pads to 0.
SmallVector<int64_t> padsVals = {0, 0, 0, 0};
// Create the ArrayAttrs
auto kernelShape = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto strides = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto dilations = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto pads = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
// Create the resulting tensor type.
auto resultType = RankedTensorType::get(
/*shape=*/{inputTensorType.getShape()[0], inputTensorType.getShape()[1],
1, 1},
/*elementType=*/inputTensorType.getElementType());
// Create the ONNXAveragePoolOp.
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
resultType, inputTensor, /*auto_pad=*/"NOTSET",
/*ceil_mode=*/0, /*count_include_pad=*/1, dilations,
/*kernel_shape=*/kernelShape,
/*pads=*/pads, /*strides=*/strides);
// Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
rewriter.replaceOp(reduceMean, averagePool.getResult());
return success();
}
};
void populateReduceMeanConversionPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ReduceMeanConversionPattern>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,79 @@
#ifndef ONNX_TO_SPATIAL
#define ONNX_TO_SPATIAL
#ifndef OP_BASE
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
#endif // OP_BASE
def onnxToArithConstantOp : Pat<
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
(Arith_ConstantOp $value)
>;
//===----------------------------------------------------------------------===//
// ONNXMatMulOp to ONNXGemmOp patterns
//===----------------------------------------------------------------------===//
def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
(ONNXGemmOp $A, $B, $C,
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
>;
def matMulToGemmPattern : Pat<
(ONNXMatMulOp:$matmulres $A, $B),
(
ONNXGemmOp $A, $B,
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
>;
//===----------------------------------------------------------------------===//
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
//===----------------------------------------------------------------------===//
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias.
def convAddToConvWithBiasPatternLeft : Pat<
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
def convAddToConvWithBiasPatternRight : Pat<
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
//===----------------------------------------------------------------------===//
// Operation to ignore (i.e. remove)
//===----------------------------------------------------------------------===//
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
def removeLRNPattern : Pat<
(ONNXLRNOp $A, $_, $_, $_, $_),
(replaceWithOperationOfValue $A)
>;
def HaveSameStaticShape: Constraint<
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
"Two tensors have the same static shape">;
def removeFlattenSameShapePattern : Pat<
(ONNXFlattenOp:$flattenOp $A, $axis),
(replaceWithOperationOfValue $A),
[(HaveSameStaticShape $flattenOp, $A)]
>; // Add closing parenthesis here
#endif // ONNX_TO_SPATIAL

View File

@@ -0,0 +1,499 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>
#include <utility>
#include "ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size());
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (const auto size : shape)
sizes.push_back(rewriter.getIndexAttr(size));
sizes[axis] = rewriter.getIndexAttr(sliceSize);
long length = shape[axis];
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
SmallVector<Value> slices;
slices.reserve(numSlices);
for (int64_t i = 0; i < numSlices; i++) {
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
if (i == numSlices - 1 && lastSliceSize != 0)
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
Value slice = rewriter.create<tensor::ExtractSliceOp>(loc, tensorToSlice, offsets, sizes, strides);
slices.push_back(slice);
}
return slices;
}
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(vectorToSlice);
assert("Not a vector" && isVectorShape(shape));
size_t axis = shape[0] != 1 ? 0 : 1;
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
}
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
size_t coreId = sliceId / crossbarCountInCore;
slicesPerCore[coreId].push_back(slices[sliceId]);
}
return slicesPerCore;
}
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
size_t numHSlices = hSlices.size();
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
Value hSlice = hSlices[hSliceId];
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
size_t coreId = vSliceId / crossbarCountInCore;
Value vSlice = vSlices[vSliceId];
tiles[hSliceId][coreId].push_back(vSlice);
}
}
return tiles;
}
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = rewriter.create<tensor::ExtractOp>(loc, scalarToBroadcast, index).getResult();
return rewriter.create<tensor::SplatOp>(loc, type, elementValue);
}
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
switch (mapOp) {
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
case MapOperations::ONNXSoftmaxOp: return rewriter.create<ONNXSoftmaxOp>(input.getLoc(), input.getType(), input);
case MapOperations::ONNXReluOp: return rewriter.create<ONNXReluOp>(input.getLoc(), input.getType(), input);
case MapOperations::ONNXLeakyReluOp: return rewriter.create<ONNXLeakyReluOp>(input.getLoc(), input.getType(), input);
case MapOperations::ONNXExpOp: return rewriter.create<ONNXExpOp>(input.getLoc(), input.getType(), input);
}
}
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2) {
if (auto unpackedStrides = valuesArray) {
value1 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[0]).getInt();
value2 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[1]).getInt();
}
else {
value1 = 1;
value2 = 1;
}
}
std::optional<llvm::Twine>
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y) {
if (valuesArray.has_value()) {
auto pads = mlir::ArrayAttr(*valuesArray);
if (pads.size() != 4)
return "pads must have 4 elements.";
pad_x = cast<IntegerAttr>(pads[2]).getInt();
pad_y = cast<IntegerAttr>(pads[3]).getInt();
}
else {
// Default padding is 0 unless specified otherwise.
// https://onnx.ai/onnx/operators/onnx__Conv.html
pad_x = pad_y = 0;
}
return std::nullopt;
}
void tileImageTensorByChannel(Value imageTensor,
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
size_t tileSize,
ConversionPatternRewriter& rewriter) {
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
size_t input_h = GET_IMAGE_HEIGHT(imageShape);
size_t input_w = GET_IMAGE_WIDTH(imageShape);
size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize);
size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Location loc = imageTensor.getLoc();
for (size_t i = 0; i < tileCount; i++) {
if (i == tileCount - 1 && tileRest != 0)
sizes[1] = rewriter.getIndexAttr(tileRest);
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
offsets[1] = rewriter.getIndexAttr(i * tileSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
tiles[i][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, imageTensor, offsets, sizes, strides);
}
}
}
}
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
ConversionPatternRewriter& rewriter,
Location& loc,
Type outputType) {
// Populate the outputTiles for the concat in the given order:
// 1. Start top left pixel
// 2. Continue on its right pixel till the end of the row
// 3. Restart on the next row
size_t outputTileCount = outputTiles.size();
size_t output_w = outputTiles[0].size();
size_t output_h = outputTiles[0][0].size();
SmallVector<Value> tilesToConcat;
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
for (size_t outX = 0; outX < output_h; outX++)
for (size_t outY = 0; outY < output_w; outY++)
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
return rewriter.create<spatial::SpatImgConcatOp>(loc, outputType, tilesToConcat);
}
LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y) {
if (inX < 0) {
assert((size_t) (-inX) <= pad_x && "verifyWithinBoundsAndPaddings: Negative x value out of padding");
return failure();
}
if (inY < 0) {
assert((size_t) (-inY) <= pad_y && "verifyWithinBoundsAndPaddings: Negative y value out of padding");
return failure();
}
if ((size_t) inX >= input_w || (size_t) inY >= input_h) {
assert((size_t) inX < input_w + pad_x && "verifyWithinBoundsAndPaddings: Positive x out of bounds");
assert((size_t) inY < input_h + pad_y && "verifyWithinBoundsAndPaddings: Positive y out of bounds");
return failure();
}
return success();
}
Value createExtractSliceImg(Value valToSlice,
size_t x,
size_t y,
size_t t,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
if (t == channelTileCount - 1 && channelTileRest != 0)
sizes[1] = rewriter.getIndexAttr(channelTileRest);
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
return rewriter.create<tensor::ExtractSliceOp>(valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
}
Value indexImgValue(Value v,
size_t x,
size_t y,
size_t t,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
ConversionPatternRewriter& rewriter) {
auto newV = rewriter.getRemappedValue(v);
if (newV)
v = newV;
if (!v.getDefiningOp())
return createExtractSliceImg(v, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (auto computeOp = v.getDefiningOp<spatial::SpatWeightedCompute>()) {
// We found the computeOp that produces the tile we want, just return this
// value.
// TODO: Should we assert that x,y,t are zero?
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: WeightedComputeOp tile indeces should be zero");
return v;
}
if (auto receiveOp = v.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
// This is a receiveOp, just return its value which will be resolved later
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: receiveOp tile indeces should be zero");
return v;
}
if (auto imgConcatOp = v.getDefiningOp<spatial::SpatImgConcatOp>()) {
auto imgConcatInput = imgConcatOp.getInputTile(x, y, t);
// TODO: Is this correct?
// Above we already index exactly the tile we want, so `x=y=t=0` in
// recursive call
return indexImgValue(imgConcatInput, 0, 0, 0, channelTileCount, channelTileRest, input_w, input_h, rewriter);
}
if (auto tensorConcatOp = v.getDefiningOp<tensor::ConcatOp>()) {
// This can be recursive.
// First, get the input tensors of the tensor.concatOp
// Then, find the input tensor that contains the tile we want
// Finally, recursive call asking for the tile
auto concatAxis = tensorConcatOp.getDim();
assert(concatAxis != 0 && "Expecting to concat on channel/x/y axis");
assert(concatAxis == 1 && "TODO: Make sure this works and makes sense for other axis.");
SmallVector<size_t, 4> indexDims = {1, t * crossbarSize, x, y};
// Find the input tensor that contains the tile we want
size_t currentTile = 0;
for (auto concatInput : tensorConcatOp.getInputs()) {
auto concatInputShape = cast<ShapedType>(concatInput.getType());
assert(concatInputShape.getRank() == 4 && "Expecting an image tensor");
auto concatInputSizeOnAxis = concatInputShape.getDimSize(concatAxis);
if (currentTile + concatInputSizeOnAxis > indexDims[concatAxis]) {
// This input tensor contains the tile we want
indexDims[concatAxis] -= currentTile;
if (indexDims[1] % crossbarSize != 0) {
assert(ignoreConcatError
&& "TODO: Handle non-tile aligned tensor, or set "
"--ignore-concat-error=true");
}
return indexImgValue(concatInput,
indexDims[2],
indexDims[3],
indexDims[1] / crossbarSize,
channelTileCount,
channelTileRest,
input_w,
input_h,
rewriter);
}
currentTile += concatInputSizeOnAxis;
}
assert(false
&& "Could not find the input tensor that contains the tile "
"within tensor.ConcatOp");
}
v.dump();
assert(false && "indexImgValue: unsupported operation");
}
void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Location loc = wholeInputTensor.getLoc();
for (size_t t = 0; t < channelTileCount; t++) {
if (t == channelTileCount - 1 && channelTileRest != 0)
sizes[1] = rewriter.getIndexAttr(channelTileRest);
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
inputTiles[t][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, wholeInputTensor, offsets, sizes, strides);
}
}
}
}
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
ConversionPatternRewriter& rewriter) {
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
inputTiles[t][x][y] =
indexImgValue(wholeInputTensor, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
}
}
}
return std::nullopt;
}
LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
const size_t inputTilesCount,
const size_t lastInputTileDimension,
TensorType inputShape,
TensorType outputShape,
Value reshapeInput,
ConversionPatternRewriter& rewriter) {
// Only support reshape between an image and a vector (i.e. flatten)
if (inputShape.getRank() != 4 || outputShape.getRank() != 2) {
return rewriter.notifyMatchFailure(reshapeInput.getDefiningOp(),
"resolveVecInputTiles only supports reshapes from 4D to 2D tensors");
}
/*
* From a 4D tensor <N, C, W, H> to a 2D tensor <N, C*H*W>
*/
auto N = inputShape.getDimSize(0);
auto C = inputShape.getDimSize(1);
auto H = inputShape.getDimSize(2);
auto W = inputShape.getDimSize(3);
assert(N == 1 && "Only support N = 1 for image tensors");
for (size_t i = 0; i < inputTilesCount; i++) {
auto c = (i / (H * W)) % C;
// TODO: Is this correct? Or should I invert h and w?
auto w = (i / H) % W;
auto h = i % H;
Value curTile = indexImgValue(reshapeInput, w, h, c, inputTilesCount, lastInputTileDimension, W, H, rewriter);
// Assert the shape of the tile, and reshape it
auto curTileShape = cast<TensorType>(curTile.getType());
assert(curTileShape.getRank() == 4 && "We just reshaped an image tensor, why rank != 4?");
assert(curTileShape.getDimSize(0) == 1 && "We just reshaped an image tensor with N = 1, why is it now != 1?");
assert(curTileShape.getDimSize(2) == 1 && "We should have just looked up a single pixel why W != 1?");
assert(curTileShape.getDimSize(3) == 1 && "We should have just looked up a single pixel why H != 1?");
// Reshape this pixel tensor into a vector, for compatibility with the
// rest
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
Value shapeTensor =
rewriter.create<arith::ConstantOp>(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
size_t coreIndex = i / crossbarCountInCore;
inputTiles[coreIndex].push_back(reshapedCurTile);
}
return success();
}
std::pair<size_t, size_t> kernel_get_start_and_end(
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad) {
int64_t firstValid = std::ceil(static_cast<float>(pad) / dilation) * dilation - pad;
int64_t start = std::max(firstValid, out_pos * stride - pad);
int64_t end = std::min(input_width, out_pos * stride + (krn_width - 1) * dilation + 1 - pad);
assert(start >= 0 && "Start position must be non-negative.");
assert(end >= 0 && "End position must be non-negative.");
return std::make_pair(start, end);
}
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment) {
auto oldSegmentSizes = wcomputeOp->getAttrOfType<DenseI32ArrayAttr>(wcomputeOp.getOperandSegmentSizesAttrName());
auto newSegmentSizes =
DenseI32ArrayAttr::get(wcomputeOp->getContext(), {oldSegmentSizes[0], oldSegmentSizes[1] + increment});
wcomputeOp->setAttr(wcomputeOp.getOperandSegmentSizesAttrName(), newSegmentSizes);
}
int getResultIndex(Operation* op, Value v) {
int resultNumber = -1;
for (auto result : op->getResults()) {
if (result == v) {
resultNumber = result.getResultNumber();
break;
}
}
assert(resultNumber >= 0 && "Value not found in given operation's results.");
return resultNumber;
}
}; // namespace onnx_mlir

View File

@@ -0,0 +1,262 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "llvm/Support/LogicalResult.h"
#define DEFINE_MAP_OP(opname) opname,
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1)
#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0)
#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0)
using namespace mlir;
namespace onnx_mlir {
const StringRef REPLICATION_ATTR_NAME = "replication_factor";
using HSliceId = size_t;
using CoreId = size_t;
enum class MapOperations {
None,
ONNXSoftmaxOp,
ONNXReluOp,
ONNXLeakyReluOp,
ONNXExpOp
};
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(const ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(const ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(const ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(const ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(const ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(const Value tensor) { return cast<RankedTensorType>(tensor.getType()).getShape(); }
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc);
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc);
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter);
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input);
/**
* Unpacks an optional pair vector into two size_t values.
*
* @param valuesArray The optional `mlir::ArrayAttr` containing the pair of
* values.
* @param value1 The reference to the first `size_t` variable to store the
* unpacked value.
* @param value2 The reference to the second `size_t` variable to store the
* unpacked value.
*/
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2);
/**
* Unpacks the optional pads vector.
*
* @param valuesArray The optional array attribute containing the values.
* @param pad_x The output variable to store the value of pad_x.
* @param pad_y The output variable to store the value of pad_y.
* @param rewriter The rewriter to notify failure
*
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
*/
std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
/**
* Tiles the image tensor by channel.
*
* This function takes an image tensor and tiles it into smaller tiles based on
* the channel dimension. The size of each tile is specified by the tileSize
* parameter.
*
* @param imageTensor The input image tensor (NxCxWxH) to be tiled.
* @param tiles The output tiles vector to store the tiled image tensors.
* @param tileSize The size of each tile.
* @param rewriter The ConversionPatternRewriter used for creating operations.
*/
void tileImageTensorByChannel(Value imageTensor,
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
size_t tileSize,
ConversionPatternRewriter& rewriter);
/**
* Creates an ImgConcatOp based on the given tiles.
*
* This function takes a 3-dimensional vector `outputTiles` representing the
* tiles to concatenate. The tiles are indexed by [tile][x][y].
*
* @param outputTiles The tiles to concatenate.
* @param rewriter The ConversionPatternRewriter used for creating the
* ImgConcatOp.
* @param loc The location of the operation.
* @param outputType The type of the output tensor.
*
* @return The created ImgConcatOp.
*/
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
ConversionPatternRewriter& rewriter,
Location& loc,
Type outputType);
/**
* @brief Verifies if the given input coordinates and padding values are within
* the bounds of the input tensor.
*
* @param input_w The width of the input tensor.
* @param input_h The height of the input tensor.
* @param inX The X-coordinate of the input.
* @param inY The Y-coordinate of the input.
* @param pad_x The padding value in the X-direction.
* @param pad_y The padding value in the Y-direction.
* @return LogicalResult Returns success if the coordinates and padding are
* within bounds, failure otherwise.
*/
LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
/**
* Resolves the tiling of the input tensor into smaller tiles.
*
* This function takes a whole input tensor and tiles it into smaller tiles
* using the provided parameters. The resulting tiles are stored in the
* `inputTiles` vector.
* Input tiles need to be indexed by:
* a. Channel Tile
* b. Pixel `x` position
* c. Pixel `y` position
* For example: inputTiles[channelTile][x][y]
*
* @param wholeInputTensor The whole input tensor to be tiled.
* @param inputTiles A vector of vectors of vectors of Values representing the
* tiles of the input tensor. The outermost vector represents
* the channels, the middle vector represents the rows, and
* the innermost vector represents the columns of the tiles.
* @param channelTileCount The number of tiles for the `channel` axis.
* @param channelTileRest The size of the last channelTile. Set as 0 if tiles
* fit exactly
* @param input_w The width of the input tensor.
* @param input_h The height of the input tensor.
* @param rewriter The ConversionPatternRewriter used for creating operations.
*
* @return std::optional<llvm::Twine> An error message if the input tensor could
* not be resolved into tiles.
*/
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
mlir::ConversionPatternRewriter& rewriter);
/**
* Computes the boundaries of an image kernel application.
*
* @param out_pos The position of the output element.
* @param input_width The width of the input image.
* @param krn_width The width of the kernel.
* @param stride The stride value.
* @param dilation The dilation value.
* @param pad The padding value.
* @return A pair of size_t values representing the start and end positions of
* the kernel application.
*/
std::pair<size_t, size_t> kernel_get_start_and_end(
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad);
/**
* @brief Increment the `operandSegmentSizes` in the WeightedCompute operation
* for the `inputs` operand.
*
* This function increments the size of the `inputs` operand segment in the
* `operandSegmentSizes` of the given WeightedCompute operation by the specified
* increment. This is necessary when new operands are programmatically added to
* the WeightedCompute operation.
*
* @param wcomputeOp The WeightedCompute operation whose `operandSegmentSizes`
* is to be incremented.
* @param increment The value by which to increment the `inputs` operand segment
* size.
*/
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment);
/**
* @brief Finds the result index of the given operation that produces the
* specified value.
*
* This function takes an operation and a value, and returns the index of the
* result of the operation that corresponds to the given value.
*
* @param op Operation whose result index is to be found.
* @param v The value for which the result index is to be determined.
* @return The index of the result of the operation that produces the specified
* value.
*/
int getResultIndex(Operation* op, Value v);
}; // namespace onnx_mlir

View File

@@ -0,0 +1,131 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "ONNXToSpatialPass.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
ModuleOp module = getOperation();
MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx);
mergeActivationPatterns.add<onnxToArithConstantOp>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(module);
func::FuncOp funcOp = *module.getOps<func::FuncOp>().begin();
if (annotateReplication(funcOp, rewriter).failed()) {
llvm::dbgs() << "Failed during annotation for replication analysis\n";
signalPassFailure();
return;
}
ConversionTarget target(*ctx);
target.addLegalDialect<ONNXDialect, SpatialDialect, tensor::TensorDialect, arith::ArithDialect, tosa::TosaDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
target.addIllegalOp<ONNXAveragePoolOp>();
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx);
if (useExperimentalConvImpl) {
populateExperimentalTilingConvOpPattern(patterns, ctx);
populateExperimentalPoolingTilingPattern(patterns, ctx);
populateGemmToConvConversionPattern(patterns, ctx);
}
else {
populateTilingConvOpPattern(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx);
populateTilingGemmOpPattern(patterns, ctx);
}
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx);
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
return;
}
// Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) {
int computeOpsCount = 0;
for (auto& op : funcOp.getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
signalPassFailure();
return;
}
}
// Remove trailing "helper ops" i.e. concat,img_concat,reshape.
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(funcOp);
// Dump to file for debug
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
std::filesystem::create_directory(outputDir);
std::fstream file(outputDir + "/spatial.mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *module;
os.flush();
file.close();
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<SpatWeightedCompute>(user); });
if (isAlwaysWeight)
constantOp->setAttr("weightAlways", UnitAttr::get(ctx));
});
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,34 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
using namespace mlir;
extern bool haveSameStaticShape(Value lhs, Value rhs);
namespace spatial {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace spatial
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<spatial::ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,40 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateTilingGemmOpPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateTilingConvOpPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populatePoolingTilingPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateDistributeReducePattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateFoldComputePattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateONNXConcatToTensorConcatPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateRemoveUnusedHelperOpsPatterns(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateReduceMeanConversionPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
// Experimental patterns.
void populateExperimentalTilingConvOpPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateGemmToConvConversionPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateExperimentalPoolingTilingPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
} // namespace onnx_mlir

View File

@@ -0,0 +1,31 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext *ctx) : OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
ONNXConcatOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis();
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
return success();
}
};
void populateONNXConcatToTensorConcatPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,34 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : public OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {}
void initialize() { this->setHasBoundedRewriteRecursion(); }
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
if (op.getResult().use_empty()) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
};
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,119 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include <queue>
using namespace mlir;
namespace onnx_mlir {
/**
* @brief Structure that describes the replication of a convolution operation,
* along the image height axis.
*/
struct ConvReplication {
ONNXConvOp convOp; // Convolution operation
size_t input_w; // Width of the input image
size_t replicationFactor; // Replication factor on the image height axis
size_t coresNeededPerReplica; // Number of cores needed for each replica
friend bool operator<(const ConvReplication& a, const ConvReplication& b) {
return a.input_w / a.replicationFactor < b.input_w / b.replicationFactor;
}
ConvReplication(ONNXConvOp convOp, size_t input_w, size_t replicationFactor, size_t coresNeededPerReplica)
: convOp(convOp),
input_w(input_w),
replicationFactor(replicationFactor),
coresNeededPerReplica(coresNeededPerReplica) {}
};
LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter) {
if (coresCount == -1) {
// No need for annotation, implicitly set replication to 1
return success();
}
std::priority_queue<struct ConvReplication> convOpsReplicationQueue;
size_t minimumCores = 0;
for (auto& op : funcOp.getFunctionBody().begin()->getOperations()) {
if (auto convOp = dyn_cast<ONNXConvOp>(op)) {
// Convolution layer
Value X = convOp.getX(), W = convOp.getW();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
ShapedType wShape = mlir::cast<ShapedType>(W.getType());
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape);
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());
auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
auto neededCores = ceilIntegerDivide(neededXbars, crossbarCountInCore.getValue());
minimumCores += neededCores;
convOpsReplicationQueue.emplace(convOp, input_w, 1, neededCores);
}
else if (auto gemmOp = dyn_cast<ONNXGemmOp>(op)) {
// Fully connected layer
auto matrixTensorShape = cast<ShapedType>(gemmOp.getB().getType());
auto inputSize = matrixTensorShape.getDimSize(0);
auto outputSize = matrixTensorShape.getDimSize(1);
if (gemmOp.getTransB())
std::swap(inputSize, outputSize);
const size_t inputTilesCount = ceilIntegerDivide(inputSize, crossbarSize.getValue());
const size_t outputTilesCount = ceilIntegerDivide(outputSize, crossbarSize.getValue());
// Each output tile is computed by `coresPerOutputTile` cores. The
// entire input is given to each of these cores.
const size_t coresPerOutputTile = ceilIntegerDivide(inputTilesCount, crossbarCountInCore.getValue());
auto neededCores = coresPerOutputTile * outputTilesCount;
minimumCores += neededCores;
}
}
if (static_cast<size_t>(coresCount) < minimumCores) {
return funcOp->emitError("Not enough cores for this network: ")
<< minimumCores << " cores needed, but only " << static_cast<size_t>(coresCount) << " available.";
}
size_t availableCores = static_cast<size_t>(coresCount) - minimumCores;
// Consume all the elements in the queue
while (!convOpsReplicationQueue.empty()) {
auto convOpReplication = convOpsReplicationQueue.top();
convOpsReplicationQueue.pop();
// Check if we can replicate this convolution (e.g. we have enough cores)
if (availableCores > convOpReplication.coresNeededPerReplica * (convOpReplication.replicationFactor + 1)) {
// We can replicate this convolution: increment replicationFactor and put
// back in queue
availableCores -= convOpReplication.coresNeededPerReplica;
convOpReplication.replicationFactor++;
convOpsReplicationQueue.push(convOpReplication);
}
else {
// Cannot replicate this convolution anymore, annotate the operation
// with the replication factor
convOpReplication.convOp->setAttr(REPLICATION_ATTR_NAME,
rewriter.getI64IntegerAttr(convOpReplication.replicationFactor));
}
}
return success();
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,11 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
mlir::LogicalResult annotateReplication(
mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter);
} // namespace onnx_mlir

View File

@@ -0,0 +1,382 @@
#include "SpatialReducer.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <unordered_map>
#include <utility>
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)
namespace onnx_mlir {
llvm::SmallPtrSet<Operation *, 16>
onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
ResNum SpatialReducer::applyResultProcessing(
ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value &)> processFun,
ConversionPatternRewriter &rewriter) {
assert(processFun);
auto computeOp = GET_COMP(computeOpAndResNum);
auto resultNum = GET_RES_NUM(computeOpAndResNum);
spatial::SpatYieldOp yieldOp =
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value result = yieldOp->getOperand(resultNum);
rewriter.setInsertionPointAfterValue(result);
Value processedResult = processFun(result);
if (processedResult == result) {
// Sometimes we want processedResult to return the same value but do
// something else with it (e.g. in softmax we want to broadcast the value
// using a channel). In this case, we can just return the same value.
return resultNum;
}
yieldOp->insertOperands(yieldOp->getNumOperands(), processedResult);
return yieldOp.getNumOperands() - 1;
}
OpAndResNum SpatialReducer::applyReducePattern(
SmallVector<ComputeAndResNum> &computeOpsAndResNum,
std::function<Value(const Value &, const Value &)> reduce,
std::function<Value(const Value &)> preprocess,
std::function<Value(const Value &)> postprocess) {
if (preprocess) {
for (auto &computeOpAndResNum : computeOpsAndResNum) {
GET_RES_NUM(computeOpAndResNum) =
applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
}
}
// It is possible that `computeOpsAndResNum` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation *, Value> lastValueForCompute;
for (auto &computeOpAndResNum : computeOpsAndResNum) {
auto computeOp = GET_COMP(computeOpAndResNum);
auto yieldOp =
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value valueWithinCompute =
yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto it = lastValueForCompute.find(computeOp.getOperation());
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
assert(valueWithinCompute.getDefiningOp() &&
lastWithinComputeValue.getDefiningOp());
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(
lastWithinComputeValue.getDefiningOp())) {
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
} else {
rewriter.setInsertionPointAfterValue(valueWithinCompute);
}
valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
}
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
}
// Now, reconstruct from the map the computeOpsAndResNum list
computeOpsAndResNum.clear();
computeOpsAndResNum.reserve(lastValueForCompute.size());
for (auto &entry : lastValueForCompute) {
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
auto valueWithinCompute = entry.second;
// We check if `valueWithinCompute` is already used by the yieldOp, in that
// case no need to add it
auto yieldOp =
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
bool yieldOpUseFound = false;
for (auto &use : valueWithinCompute.getUses()) {
if (use.getOwner() == yieldOp.getOperation()) {
// If the value is already used by the yieldOp, we can just use it
computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
yieldOpUseFound = true;
break;
}
}
if (yieldOpUseFound) {
continue;
}
// If this result is not used within a yieldOp, then add it
auto resultNum = yieldOp->getNumOperands();
yieldOp->insertOperands(resultNum, valueWithinCompute);
computeOpsAndResNum.push_back({computeOp, resultNum});
}
Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
// the computeOpsAndResNum list which becomes half the size.
// - Repeat until there is only one input left.
llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
while (computeOpsRef.size() > 1) {
SmallVector<ComputeAndResNum> nextComputeOps;
nextComputeOps.reserve(computeOpsRef.size() / 2);
for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
auto [firstCompute, firstResultNum] = computeOpsRef[i];
auto [secondCompute, secondResultNum] = computeOpsRef[i + 1];
if (secondCompute->isBeforeInBlock(firstCompute)) {
std::swap(firstCompute, secondCompute);
std::swap(firstResultNum, secondResultNum);
}
// We do not immediately alter the computeOps results/operands, instead we
// do it in a delayed manner, to avoid invalidating the references to the
// computeOps (which must be replaced by a cloned ComputeOp when changing
// the number of results)
// See below `reducerChanges.push_back` and `finalizeReduceUpdates`
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(
firstCompute.getBody().front().getTerminator());
// Add a new operand to the block of the second computeOp
Block &secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument(
yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
auto secondComputeWeightsNum =
secondCompute->getAttrOfType<DenseI32ArrayAttr>(
secondCompute.getOperandSegmentSizesAttrName())[0];
auto secondComputeOperandNum =
secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
// Take the "former-result" from the second computeOp
spatial::SpatYieldOp secondYield =
cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
Value formerRes2 = secondYield.getOperand(secondResultNum);
// Apply reduction operation
rewriter.setInsertionPoint(secondYield);
Value reduced = reduce(formerRes2, formerRes1);
// Unfortunately, it is not possible to update the result in place,
// because we may have already referenced it by <computeOp, resultNum>
// outside of this function, thus replacing it would invalidate the
// reference. Therefore, we need to append a new result to the yieldOp,
// and then at a later stage update the computeOp accordingly.
// Add `reduced` to the second yieldOp
auto secondYieldOperandNum = secondYield.getNumOperands();
secondYield->insertOperands(secondYieldOperandNum, reduced);
secondResultNum = secondYieldOperandNum;
// We should also add an entry for updating the results of the last
// operation (the one which never becomes a `firstCompute`): because it is
// not tracked by reducerChanges as `fromOp`
reducerChanges.push_back({firstCompute.getOperation(), firstResultNum,
secondCompute.getOperation(), secondComputeOperandNum});
nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (computeOpsRef.size() % 2 == 1) {
nextComputeOps.push_back(computeOpsRef.back());
}
// Replace the inputOps list with the new one.
computeOpsRef =
llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
}
assert(computeOpsRef.size() == 1 &&
"Internal error: expected a single input at this point.");
auto finalComputeAndResNum = computeOpsRef[0];
// Force the update of the results of this computeOp, when finalizing
computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));
if (postprocess) {
GET_RES_NUM(finalComputeAndResNum) =
applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
}
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(),
GET_RES_NUM(finalComputeAndResNum));
}
void SpatialReducer::finalizeReduceUpdates() {
assert(reducesFinalized == false && "Cannot finalize two times.");
reducesFinalized = true;
// First, add the results to the computeOps
for (auto &reduceChange : reducerChanges) {
updateResultsOfCompute(reduceChange.fromOp);
}
for (auto &c : computeOpNeedingResUpdate) {
updateResultsOfCompute(c.getOperation());
}
for (auto &reducerChange : this->reducerChanges) {
auto fromOp = reducerChange.fromOp;
auto toOp = reducerChange.toOp;
auto fromOpResNum = reducerChange.fromOpResNum;
auto toOpOperandNum = reducerChange.toOpOperandNum;
auto fromComputeOp = opToReplacedCompute[fromOp];
assert(fromComputeOp && "fromOp should have been mapped before!");
// toComputeOp could be the existing pointer, or we have to remap it with
// `opToReplacedCompute`
auto toComputeOp = opToReplacedCompute[toOp];
if (!toComputeOp) {
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
}
assert(toComputeOp != fromComputeOp &&
"Oops should have caught this earlier!");
assert(toComputeOp->getNumOperands() == toOpOperandNum &&
"toOpOperandNum should be the last operand of toComputeOp, are the "
"operations in the right order?");
// Add the new operand to `toComputeOp`
auto fromResult = fromComputeOp.getResult(fromOpResNum);
toComputeOp->insertOperands(toOpOperandNum, fromResult);
incrementWeightedComputeInputsSegmentSize(toComputeOp, 1);
}
}
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum &opAndResNum) {
assert(reducesFinalized &&
"Cannot create resolve values before finalizing the reduce updates.");
Operation *opToCast;
auto it = opToReplacedCompute.find(opAndResNum.first);
if (it != opToReplacedCompute.end()) {
opToCast = it->second;
} else {
opToCast = opAndResNum.first;
}
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
return computeOp.getResult(opAndResNum.second);
}
void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
// If we have already replaced the fromOp, we do not need to do it again
return;
}
auto oldComputeOp = cast<spatial::SpatWeightedCompute>(computeOp);
auto oldComputeOpNum = oldComputeOp->getNumOperands();
auto yieldOp =
cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
// No result was added, just add itself to the map
opToReplacedCompute[oldComputeOp.getOperation()] = oldComputeOp;
return;
}
// Add the results by inspecting its YieldOp
auto newResultTypes = yieldOp.getOperandTypes();
// Create a new ComputeOp with the new result type, but same operands
rewriter.setInsertionPoint(oldComputeOp);
auto newComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(oldComputeOp->getLoc(),
newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
auto newComputeOpNum = newComputeOp->getNumOperands();
assert(oldComputeOpNum == newComputeOpNum);
// Since we replaced the old ComputeOp with a new one, we need to replace
// all its results' uses
for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
Value oldResult = oldComputeOp.getResult(i);
Value newResult = newComputeOp.getResult(i);
// Replace the uses, except the uses of the compute ops which got deleted
// previously
rewriter.replaceAllUsesExcept(oldResult, newResult, oldComputeOpsReplaced);
}
// Finally, erase the old computeOp and update the map
opToReplacedCompute[oldComputeOp.getOperation()] = newComputeOp;
oldComputeOpsReplaced.insert(oldComputeOp.getOperation());
rewriter.setInsertionPoint(oldComputeOp);
rewriter.eraseOp(oldComputeOp);
}
Value SpatialReducer::createImgConcatOp(
SmallVector<SmallVector<SmallVector<OpAndResNum>>> &outputTiles,
Location &loc, Type outputType) {
assert(reducesFinalized &&
"Cannot create ImgConcatOp before finalizing the reduce updates.");
// outputTiles are indexed like this: [channelTile][x][y]
auto tilesCount = outputTiles.size();
auto width = outputTiles[0].size();
auto height = outputTiles[0][0].size();
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(tilesCount,
SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
for (size_t t = 0; t < tilesCount; t++)
for (size_t x = 0; x < width; x++)
for (size_t y = 0; y < height; y++)
remappedOutputTiles[t][x][y] =
resolveValueFromOpAndResNum(outputTiles[t][x][y]);
return ::onnx_mlir::createImgConcatOp(
remappedOutputTiles, rewriter, loc, outputType);
}
OpAndResNum SpatialReducer::applyAddMapReduction(
SmallVector<ComputeAndResNum> &computeOps,
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp) {
std::function<Value(const Value &)> postprocessing = nullptr;
if (mapOp != MapOperations::None) {
postprocessing = [&](const Value a) {
Value mapOperand = a;
if (biasTile) {
mapOperand = rewriter.create<spatial::SpatVAddOp>(
a.getLoc(), a.getType(), a, biasTile);
}
return createMapOperation(rewriter, mapOp, mapOperand);
};
}
return this->applyReducePattern(
computeOps,
[&](Value a, Value b) {
return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
},
/* preprocess = */ nullptr, postprocessing);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,83 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
namespace onnx_mlir {
using ResNum = unsigned int;
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
struct SpatialReducerChange {
Operation *fromOp;
unsigned int fromOpResNum;
Operation *toOp;
unsigned int toOpOperandNum;
};
using OpAndResNum = std::pair<Operation *, ResNum>;
class SpatialReducer {
public:
SpatialReducer(ConversionPatternRewriter &rewriter) : rewriter(rewriter) {}
OpAndResNum applyReducePattern(
SmallVector<ComputeAndResNum> &computeOpsAndResNum,
std::function<Value(const Value &, const Value &)> reduce,
std::function<Value(const Value &)> preprocess,
std::function<Value(const Value &)> postprocess);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum> &computeOps,
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp);
void finalizeReduceUpdates();
~SpatialReducer() {
if (!reducesFinalized) {
finalizeReduceUpdates();
}
}
Value createImgConcatOp(
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>
&outputTiles,
Location &loc, Type outputType);
Value resolveValueFromOpAndResNum(OpAndResNum &opAndResNum);
private:
[[nodiscard("computeOp result number gets updated")]] ResNum
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value &)> processFun,
ConversionPatternRewriter &rewriter);
/**
* @brief Update the results of a ComputeOp.
*
* This function updates the results of a ComputeOp by taking a look at the
operands of its yieldOp.
* If the ComputeOp was replaced, it updates `opToReplacedCompute` with the
replaced ComputeOp.
*
* @param computeOp The ComputeOp to update the results of.
*/
void updateResultsOfCompute(Operation *computeOp);
ConversionPatternRewriter &rewriter;
bool reducesFinalized = false;
// List of changes to be applied after the reduction is finalized
SmallVector<SpatialReducerChange, 4> reducerChanges;
// List of computeOps that need to be replaced with new results
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
std::unordered_map<Operation *, spatial::SpatWeightedCompute> opToReplacedCompute;
static llvm::SmallPtrSet<Operation *, 16> oldComputeOpsReplaced;
};
} // namespace onnx_mlir

View File

@@ -0,0 +1,53 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include <cassert>
namespace onnx_mlir {
WeightSubdivider::WeightSubdivider(
map<long, map<long, SmallVector<Value>>> weights)
: weights(std::move(weights)) {}
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
TaggedWeights WeightSubdivider::popGroup(size_t amount) {
assert(!weights.empty() && "No weights to extract.");
auto it = weights.begin();
SmallVector<Value> &values = it->second.begin()->second;
long inputTile = it->first;
long outputTile = it->second.begin()->first;
size_t n = std::min(amount, values.size());
crossbarsUsed += n;
SmallVector<Value> result;
result.assign(values.begin(), values.begin() + n);
if (n < values.size()) {
values.erase(values.begin(), values.begin() + n);
} else {
it->second.erase(outputTile);
if (it->second.empty()) {
weights.erase(inputTile);
}
}
return {inputTile, outputTile, crossbarsUsed - n, result};
}
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
crossbarsUsed = 0;
SmallVector<TaggedWeights> result;
size_t remaining = n;
while (remaining > 0 && !weights.empty()) {
auto group = popGroup(remaining);
result.push_back(group);
remaining -= group.weights.size();
}
return result;
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,46 @@
#pragma once
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <map>
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
* @brief A helper struct to store a group of weights.
*
*/
struct TaggedWeights {
long inputTile;
long outputTile;
size_t startingCrossbarIndex;
SmallVector<Value> weights;
};
/**
* @brief A helper class to subdivide weights into groups.
*
* Weights are stored as a map of maps of SmallVectors. The outer map is indexed
* by input tile, the inner map is indexed by output tile, and the SmallVector
* contains the weights for the filter. This class allows us to extract groups
* of weights from the map until we've extracted a certain number of elements,
* namely as many as we need to fill a compute unit.
*/
class WeightSubdivider {
private:
map<long, map<long, SmallVector<Value>>> weights;
size_t crossbarsUsed = 0;
TaggedWeights popGroup(size_t amount);
public:
WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights);
bool isEmpty() const;
SmallVector<TaggedWeights> popGroups(size_t n);
};
} // namespace onnx_mlir

View File

@@ -0,0 +1,14 @@
add_onnx_mlir_rewriter(SpatialToGraphviz)
add_onnx_mlir_library(OMSpatialToGraphviz
SpatialToGraphviz.cpp
LINK_LIBS PUBLIC
OMCompilerOptions
OMPIMCommon
OMONNXOps
SpatialOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -0,0 +1,283 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h"
#define FORMAT_OPERATION(op) \
'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) \
llvm::format("Arg_%p_%u", computeOpPointer, argumentNum)
using namespace mlir;
namespace onnx_mlir {
namespace {
struct SpatialToGraphvizPass
: public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass)
StringRef getArgument() const override {
return "convert-spatial-to-graphviz";
}
StringRef getDescription() const override {
return "Lower ONNX ops to Spatial ops.";
}
SpatialToGraphvizPass(raw_ostream &os = llvm::errs()) : os(os) {}
SpatialToGraphvizPass(const SpatialToGraphvizPass &pass)
: SpatialToGraphvizPass(pass.os) {}
void runOnOperation() final;
private:
raw_ostream &os;
/**
* Draws the subgraph for a given spatial::SpatWeightedCompute, including:
* 1. Input nodes (block arguments)
* 2. Operations
* 3. Edges between yield (output) and its users
*
* @param op The spatial::SpatWeightedCompute to draw the subgraph for.
* @param computeNum The number of the compute operation.
*/
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute"
<< computeNum << "\";\n"
<< "\t\tstyle=filled;\n"
<< "\t\tcolor=lightblue;\n";
Block &block = op.getBody().front();
// Inputs
size_t inputNum = 0;
for (BlockArgument &input : block.getArguments()) {
auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum);
os << "\t\t" << fromOp << " [label=\"Arg" << inputNum
<< "\",shape=box];\n";
for (auto userOp : input.getUsers()) {
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
}
inputNum++;
}
// Iterate operations
for (auto &childOp : block.getOperations()) {
os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\""
<< childOp.getName() << "\"];\n";
drawEdgesFromOpToItsUsers(&childOp);
}
os << "\t}\n";
// Draw edges from the yield to the users of this computeOp
Operation *yieldOp = block.getTerminator();
if (!isa<spatial::SpatYieldOp>(yieldOp)) {
yieldOp->emitError("Terminator of block must be YieldOp ???");
signalPassFailure();
return;
}
for (auto computeOpResult : op->getResults()) {
for (auto &computeOpUse : computeOpResult.getUses()) {
auto toOp = FORMAT_ARGUMENT(
computeOpUse.getOwner(), computeOpUse.getOperandNumber());
os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n";
}
}
}
/**
* @brief Draws the subgraph for a concatOp.
*
* This function draws a subgraph for a concatOp. The subgraph consists of a
* node for each input of the concatOp, as well as an output node. Edges are
* created from the output node to each user of the concatOp.
*
* @param concatOp The concatOp for which the subgraph is drawn.
* @param concatOpNum The number of the concatOp.
*/
void drawConcatOpSubgraph(Operation *concatOp, size_t concatOpNum) {
os << "\tsubgraph clusterconcat" << concatOpNum
<< " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
<< "\t\tstyle=filled;\n"
<< "\t\tcolor=orange;\n";
// Inputs
size_t inputNum = 0;
for (Value input : concatOp->getOperands()) {
auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum);
os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n";
for (auto userOp : input.getUsers()) {
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
}
inputNum++;
}
// Output
os << "\t\t" << FORMAT_OPERATION(concatOp) << " [label=Out];\n";
os << "\t}\n";
// Edges from output to users
for (auto &computeOpUse : concatOp->getResult(0).getUses()) {
os << "\t" << FORMAT_OPERATION(concatOp) << " -> "
<< FORMAT_ARGUMENT(
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
<< ";\n";
}
}
/**
* Draws the ExtractSliceOp in the graph visualization.
*
* This function takes a tensor::ExtractSliceOp and adds the corresponding
* node and edges to the graph visualization. It creates a node with the
* label as the static offsets attribute of the sliceOp, and connects it to
* the compute operations that use the result of the sliceOp.
*
* @param sliceOp The tensor::ExtractSliceOp to be drawn in the graph
* visualization.
*/
void drawExtractSliceOp(tensor::ExtractSliceOp sliceOp) {
auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0);
os << "\t" << nodeId << " [label=\"Slice: ";
sliceOp.getStaticOffsetsAttr().print(os);
os << "\",color=lawngreen];\n";
for (auto &computeOpUse : sliceOp.getResult().getUses()) {
os << "\t" << nodeId << " -> "
<< FORMAT_ARGUMENT(
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
<< ";\n";
}
}
void drawBiasTileOp(tensor::ExtractSliceOp sliceOp) {
auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0);
os << "\t" << nodeId << " [label=\"Bias: ";
sliceOp.getStaticOffsetsAttr().print(os);
os << "\",color=lightpink];\n";
for (auto user : sliceOp.getResult().getUsers()) {
os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n";
}
}
/**
* Draws edges from the given operation to its users.
*
* @param fromOp The operation from which the edges are drawn.
*/
void drawEdgesFromOpToItsUsers(mlir::Operation *fromOp) {
for (auto result : fromOp->getResults()) {
for (auto userOp : result.getUsers()) {
os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> "
<< FORMAT_OPERATION(userOp) << ";\n";
}
}
}
/**
* Draws input node and edges for the given `funcOp`.
*
* @param funcOp The `funcOp` for which to draw input nodes and edges.
*/
void drawInputNodesAndEdges(func::FuncOp &funcOp) {
os << "\tinput [label=\"Module Input\",color=green];\n";
size_t funcOpArgNum = 0;
for (BlockArgument &arg : funcOp.getArguments()) {
for (auto &useOp : arg.getUses()) {
os << "\tinput -> "
<< FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber())
<< "[label=" << funcOpArgNum << "];\n";
}
funcOpArgNum++;
}
}
};
void SpatialToGraphvizPass::runOnOperation() {
ModuleOp module = getOperation();
// Get the first OP, must be a FuncOp
func::FuncOp func = *module.getOps<func::FuncOp>().begin();
if (!func) {
module->emitError("No FuncOp found in the begin of module");
signalPassFailure();
}
os << "digraph G {\n"
<< "\tnode [style=filled,color=white];\n";
size_t computeNum = 0;
size_t concatNum = 0;
// Iterate over the ComputeOps within FuncOp:
// 1. Print their subgraph
// 2. Print the edges from its inputs to its outputs
for (Operation &op : func.getOps()) {
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
drawComputeOpSubgraph(computeOp, computeNum++);
} else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
drawConcatOpSubgraph(concatOp, concatNum++);
} else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
drawConcatOpSubgraph(imgConcatOp, concatNum++);
} else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
if (producerOp) {
// Skip extractSliceOp if producer is constant weights (ONNXConstantOp)
if (llvm::isa<ONNXConstantOp>(producerOp)) {
continue;
}
// If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect
// directly to its user, which is not a ComputeOp argument.
if (llvm::isa<tosa::ReshapeOp>(producerOp)) {
drawBiasTileOp(extractSliceOp);
continue;
}
}
drawExtractSliceOp(extractSliceOp);
}
}
// Draw input node, and edges to it users
drawInputNodesAndEdges(func);
// Draw output node (use the return Operation - argument number=0 - as nodeId)
auto returnOp = func.getBody().front().getTerminator();
os << '\t' << FORMAT_ARGUMENT(returnOp, 0)
<< " [label=\"Module Output\",color=green];\n";
os << "}\n";
}
} // namespace
std::unique_ptr<Pass> createSpatialToGraphvizPass() {
return std::make_unique<SpatialToGraphvizPass>();
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,21 @@
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPIMIncGen)
add_onnx_mlir_library(OMSpatialToPIM
SpatialToPIMPass.hpp
SpatialToPIMPass.cpp
SpatialToPIMCommon.cpp
DEPENDS
SpatialToPIMIncGen
LINK_LIBS PUBLIC
OMCompilerOptions
OMPIMCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -0,0 +1,28 @@
#ifndef SPATIAL_TO_PIM
#define SPATIAL_TO_PIM
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
#endif // OP_BASE
def spatToPimVMMOp : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimMVMOp : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVAddOp : Pat<
(SpatVAddOp:$srcOpRes $a, $b),
(PimVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
#endif // SPATIAL_TO_PIM

View File

@@ -0,0 +1,97 @@
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstddef>
#include "SpatialToPIMCommon.hpp"
using namespace llvm;
namespace onnx_mlir {
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/*
EXAMPLE RUN:
[1, 10, 3, 4] inputShape
[0, 2, 1, 3] offsets
acc = 1
---
ret = 3
acc = 4
---
ret = 3 + 4 * 1 = 7
acc = 12
---
ret = 7 + 12 * 2 = 31
acc = 120
---
ret = 31 + 120 * 0 = 31
acc = 120
*/
size_t returnValue = 0;
auto sliceOffsets = sliceOp.getStaticOffsets();
auto inputDimSizes = inputShape.getShape();
assert(sliceOffsets.size() == inputDimSizes.size());
size_t accumulatedDimensionSize = 1;
// Reverse iterate the two vectors
for (auto it : reverse(zip(sliceOffsets, inputDimSizes))) {
auto curSliceOffset = std::get<0>(it);
auto curInputDimSize = std::get<1>(it);
returnValue += accumulatedDimensionSize * curSliceOffset;
accumulatedDimensionSize *= curInputDimSize;
}
return returnValue;
}
Operation* getEarliestUserWithinBlock(Value value) {
auto users = value.getUsers();
assert(!users.empty());
Operation* earliestUser = *users.begin();
for (auto curUser : users)
if (curUser->isBeforeInBlock(earliestUser))
earliestUser = curUser;
return earliestUser;
}
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) {
auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> {
return {operand, std::distance(operand.use_begin(), operand.use_end())};
});
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1);
Value result = operation->getResult(0);
auto resultType = result.getType();
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands =
make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; });
auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end())
return *bestOperand;
auto resultShapedType = cast<ShapedType>(resultType);
rewriter.setInsertionPoint(operation);
return rewriter.create<tensor::EmptyOp>(
operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,108 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
*
* The static offsets represent the starting position of the slice in each
* dimension, while the static tensor input gives its dimension size.
*
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
* calculated.
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape);
template <class T>
size_t rangeLength(const iterator_range<T> range) {
return std::distance(range.begin(), range.end());
}
/**
* Retrieves the earliest operation that uses the given value within the value's
* block.
*
* @param value The value for which to find the earliest user operation.
* @return The earliest user operation that uses the given value within the
* current block.
*/
Operation* getEarliestUserWithinBlock(Value value);
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation);
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation);
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
const ArrayRef<int64_t> offsets,
const ArrayRef<int64_t> sizes,
const ArrayRef<int64_t> strides) {
// Check that all strides are 1
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
// Check offsets from right to left:
// The first offset_n at position n different from 0:
// - limits all sizes to the left to 1
// - limits size_n to dimension_n - offset_n
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
// Check sizes from right to left:
// The first size_n at position n different from shape_n limits all sizes to the left to 1
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) {
return rewriter.create<tensor::EmptyOp>(loc, shapedType.getShape(), shapedType.getElementType());
}
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,491 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_os_ostream.h"
#include <cassert>
#include <filesystem>
#include <fstream>
#include <string>
#include <utility>
#include "SpatialToPIMPass.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void SpatialToPIMPass::runOnOperation() {
coreId = 0;
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
if (!funcOp)
llvm_unreachable("No FuncOp found in the begin of module");
IRRewriter rewriter(&getContext());
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addResultBuffer(returnOp, rewriter);
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
operationsToRemove.push_back(receiveOp);
runOnReceiveOp(receiveOp, rewriter);
}
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
operationsToRemove.push_back(computeOp);
runOnComputeOp(computeOp, rewriter);
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnOpOperands(returnOp, rewriter);
// Remove all ComputeOps
for (auto opToRemove : llvm::reverse(operationsToRemove)) {
if (!opToRemove->use_empty()) {
opToRemove->dump();
for (auto user : opToRemove->getUsers())
user->dump();
assert(false && "opToRemove should be unused at this point");
}
rewriter.eraseOp(opToRemove);
}
// Dump to file for debug
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
std::filesystem::create_directory(outputDir);
std::fstream file(outputDir + "/pim.mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
if (computeOp.getNumResults() != yieldOp.getNumOperands())
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
// If this result has no uses, then just skip it
if (result.use_empty())
continue;
auto yieldType = cast<TensorType>(yieldValue.getType());
/*
* Here we assume that ReturnOp are only reachable by the following patterns:
*
* 1)
* %0 = spat.compute([...])
* [%0 has one user, which is a ConcatOp]
* %1 = tensor.concat(%0)
* [%1 has one user, which is a ReturnOp]
* return %1
*
* 2)
* %0 = spat.compute([...])
* [%0 has one user, which is a ReturnOp]
* return %0
*
* If the IR is like 2), then we can store the tensor to the output global memory location
*/
auto resultUses = result.getUses();
auto numResultUses = rangeLength(resultUses);
if (numResultUses == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t offset = 0;
size_t numElements = yieldType.getNumElements();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
// Store to global memory
Value outputTensor = outputTensors[resultIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>(loc,
outputTensor.getType(),
outputTensor,
yieldValue,
rewriter.getI32IntegerAttr(offset),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(numElements * elementSize));
continue;
}
if (isa<tensor::ConcatOp>(resultUser) || isa<spatial::SpatImgConcatOp>(resultUser)) {
auto concatOp = resultUser;
auto concatValue = concatOp->getResult(0);
auto concatUses = concatValue.getUses();
auto numConcatUses = rangeLength(concatUses);
if (numConcatUses == 1) {
OpOperand& concatUse = *concatUses.begin();
Operation* concatUser = concatUse.getOwner();
if (isa<func::ReturnOp>(concatUser)) {
size_t concatIndexInReturn = concatUse.getOperandNumber();
size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
size_t offset = 0;
for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
offset += cast<ShapedType>(operand.getType()).getNumElements() * cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
// Store to global memory
Value outputTensor = outputTensors[concatIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>(
loc,
outputTensor.getType(),
outputTensor,
yieldValue,
rewriter.getI32IntegerAttr(offset),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize));
continue;
}
}
}
}
// If this pattern was not found, then create a channel and send the value
// 1. Create a new ChannelOp
rewriter.setInsertionPoint(computeOp);
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
// 2. Receive value through the channel
// If this result is used by more than one user, then use a "Broadcast"
// channel operation. However, there is a special case: we have a single
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
// will detect this case and update `useBroadcastOp` accordingly.
bool useBroadcastOp = (numResultUses > 1);
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
// 3. Send the value through the channel
rewriter.setInsertionPointAfterValue(yieldValue);
if (useBroadcastOp)
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue);
else
rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue);
}
// Use `HaltOp` instead of `YieldOp`
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
// Replace `spat.compute` with `pim.core`
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
auto& coreOpBlocks = coreOp.getBody().getBlocks();
block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock);
rewriter.create<PimHaltOp>(computeOp.getLoc());
}
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return;
auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
if (!dpsDefiningOp)
return;
auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
if (!tiedOperand)
return;
Value tiedValue = tiedOperand->get();
assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
tiedValue.setType(newType);
self(tiedValue, newType, self);
};
funcOp.walk([&](PimVMMOp vmmOp) {
auto outTensorOperand = vmmOp.getOutBuf();
auto resultTensor = vmmOp.getOutRes();
auto outShape = getTensorShape(outTensorOperand);
assert(isHVectorShape(outShape));
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
outTensorOperand.setType(newType);
resultTensor.setType(newType);
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
IntegerAttr oneAttr = rewriter.getIndexAttr(1);
IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
rewriter.setInsertionPointAfter(vmmOp);
auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
}
});
}
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
}
}
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(valueToReplace.getType());
Type elementType = tensorType.getElementType();
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType);
auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>(
loc,
tensorType,
deviceTensor,
hostTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
};
// Replace input tensors with memRefs
SmallVector<bufferization::ToTensorOp, 8> inputTensors;
for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
BlockArgument tensorArg = funcOp.getArgument(i);
DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front();
rewriter.setInsertionPoint(&block.front());
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp);
funcOp.eraseArgument(i);
}
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
unsigned numComputeWeights = computeOp.getWeights().size();
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
TypedValue<TensorType> tensorSource;
int64_t elementsOffset = 0;
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
assert("Extracting slice non-contiguous in memory"
&& isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));
for (size_t i = 0; i < sliceOffsets.size(); i++) {
int64_t partialOffset = sliceOffsets[i];
if (partialOffset != 0)
for (size_t j = i + 1; j < sourceShape.size(); j++)
partialOffset *= sourceShape[j];
elementsOffset += partialOffset;
}
computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
sliceOpsToRemove.insert(sliceOp);
}
else
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
// Compute results must be transferred through channels via send/receive
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
continue;
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
}
}
for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp);
}
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter) {
auto& computeBlock = computeOp.getRegion().front();
//(remember that WeightedCompute have weights as first operands, however these
// weights are not included in the block arguments. Thus, when indexing the
// block argument we need to remove the weights count)
auto computeWeightsCount = computeOp.getWeights().size();
auto blockArg = computeBlock.getArgument(argIndex - computeWeightsCount);
// Receive the tensor just before the first use of the value
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
Value receivedValue;
if (useBroadcastOp)
receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel);
else
receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel);
blockArg.replaceAllUsesWith(receivedValue);
}
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter) {
auto sourceOpUses = channelSourceOp.getUses();
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
if (useBroadcastOp == false) {
// if useBroadcastOp is false, then sourceOp must have only one user
assert(rangeLength(sourceOpUses) == 1);
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
auto reshapeOpUses = reshapeOp.getOutput().getUses();
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
if (reshapeOpUsesCount > 1)
useBroadcastOp = true;
}
}
for (auto& resultUse : sourceOpUses) {
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
if (computeUser) {
replaceBlockArgumentWithRecvOp(
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
continue;
}
if (!computeUser) {
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
if (!reshapeOp) {
resultUse.getOwner()->dump();
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
}
// The tensorType now becomes the one of the reshapeOp
channelTensorType = reshapeOp.getResult().getType();
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
if (!computeUser)
llvm_unreachable("ReshapeOp users must be ComputeOps");
replaceBlockArgumentWithRecvOp(
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
}
// Remove the reshapeOp, so that the sourceOp has no users
operationsToRemove.push_back(reshapeOp);
}
}
}
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
for (auto it : llvm::enumerate(returnOp.getOperands())) {
Operation* returnOperand = it.value().getDefiningOp();
size_t orderWithinReturn = it.index();
rewriter.modifyOpInPlace(returnOp,
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
// If the operand is a concatenation operation and the returnOp was the only
// user of the returnOperand, we can safely remove it
if (isAConcatOp(returnOperand)) {
auto returnOperandUses = it.value().getUses();
if (rangeLength(returnOperandUses) == 0)
rewriter.eraseOp(returnOperand);
}
}
}
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
auto sendOpOpt = getOtherEndOfChannel(receiveOp, true, rewriter);
if (failed(sendOpOpt))
llvm_unreachable("ChannelReceiveOp has no matching SendOp");
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
auto tensorType = receiveOp.getType();
Value receiveRes = receiveOp.getResult();
// Check if the receiveOp value has more than one user
auto receiveUses = receiveRes.getUses();
auto receiveUsesCount = rangeLength(receiveUses);
assert(receiveUsesCount > 0);
bool useBroadcastOp = receiveUsesCount > 1;
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
if (useBroadcastOp) {
// When receiving, we actually noticed that the value has more than one
// user. This means that we need to get the replace the original SendOp with
// a BroadcastSendOp
rewriter.setInsertionPoint(sendOp);
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
}
}
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,60 @@
#pragma once
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace pim {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace pim
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<pim::SpatialToPIMPass>(); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,12 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
namespace spatial {
// TODO: Add here eventual patterns
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,2 @@
add_subdirectory(PIM)
add_subdirectory(Spatial)

View File

@@ -0,0 +1,15 @@
add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td)
add_onnx_mlir_library(PimOps
PimOps.cpp
Transforms/PimBufferizableOpInterface.cpp
DEPENDS
OMPimIncGen
LINK_LIBS PUBLIC
OMMlirDialects
MLIRIR
)

345
src/PIM/Dialect/PIM/Pim.td Normal file
View File

@@ -0,0 +1,345 @@
#ifndef PIM_DIALECT_H
#define PIM_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
def PimDialect : Dialect {
let name = "pim";
let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars";
let cppNamespace = "::onnx_mlir::pim";
}
// Base class for Pim dialect operations. This operation inherits from the
// base `Op` class in OpBase.td, and provides:
// * The parent dialect of the operation.
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class PimOp<string mnemonic, list<Trait> traits = []> :
Op<PimDialect, mnemonic, traits>;
def PimTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
//===----------------------------------------------------------------------===//
// Communication operations
//===----------------------------------------------------------------------===//
def PimSendOp: PimOp<"send", []> {
let arguments = (ins
PimTensor: $src,
I32Attr: $size,
I32Attr: $targetCoreId
);
let assemblyFormat = [{
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
}];
}
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
let arguments = (ins
PimTensor: $dst,
I32Attr: $size,
I32Attr: $srcCoreId
);
let results = (outs
PimTensor: $out
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
}];
}
//===----------------------------------------------------------------------===//
// Core operations
//===----------------------------------------------------------------------===//
def PimCoreOp: PimOp<"core", [SingleBlock]> {
let regions = (region SizedRegion<1>:$body);
let arguments = (ins
Variadic<PimTensor>:$weights,
I32Attr: $coreId
);
let assemblyFormat = [{
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
}];
}
//===----------------------------------------------------------------------===//
// Memory Operations
//===----------------------------------------------------------------------===//
def PimConstantOp: PimOp<"constant", []> {
let description = [{
Allocate a constant value in global memory
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
PimTensor: $out
);
}
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from host memory into device memory
}];
let arguments = (ins
PimTensor: $deviceDst,
PimTensor: $hostSrc,
I32Attr: $deviceDstOffset,
I32Attr: $hostSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $deviceDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceDstMutable();
}
}];
let assemblyFormat = [{
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
}];
}
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from device memory into host memory
}];
let arguments = (ins
PimTensor: $hostDst,
PimTensor: $deviceSrc,
I32Attr: $hostDstOffset,
I32Attr: $deviceSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $hostDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getHostDstMutable();
}
}];
let assemblyFormat = [{
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
}];
}
//===----------------------------------------------------------------------===//
// Core.Compute operations
//===----------------------------------------------------------------------===//
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{
Vector-matrix multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
}
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
let description = [{
Matrix-vector multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
}
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
let description = [{
Element-wise addition: c = a + b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise max: c = max(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Apply filters to a tensor
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
PimTensor: $input,
PimTensor: $outBuf,
PimTensor: $accumBuf
);
let results = (outs
PimTensor: $outRes
);
let assemblyFormat = [{
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
}];
}
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Sum all elements into a single one
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
}];
let arguments = (ins
PimTensor: $dividend,
PimTensor: $divisor,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise ReLU: c = max(a, 0)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise exp: c = exp(a)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimHaltOp: PimOp<"halt", [Terminator]> {
let description = [{
Halts the execution of the core
}];
let assemblyFormat = [{
attr-dict
}];
}
#endif // PIM_DIALECT_H

View File

@@ -0,0 +1,49 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void PimDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
>();
}
#define POPULATE_DEPENDENCIES(OP_NAME) \
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
}
POPULATE_DEPENDENCIES(PimVMaxOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp)
} // namespace pim
} // namespace onnx_mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"

View File

@@ -0,0 +1,18 @@
#pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"

View File

@@ -0,0 +1,172 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir {
namespace pim {
struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
auto hostSrc = memCopyHostToDevOp.getHostSrc();
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
if (failed(deviceDstOpt))
return failure();
auto deviceDstMemRef = *deviceDstOpt;
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
if (failed(hostSrcOpt))
return failure();
auto hostSrcMemRef = *hostSrcOpt;
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
memCopyHostToDevOp,
deviceDstMemRef.getType(),
deviceDstMemRef,
hostSrcMemRef,
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
memCopyHostToDevOp.getHostSrcOffsetAttr(),
memCopyHostToDevOp.getSizeAttr());
return success();
}
};
struct MemCopyDevToHostOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
auto globalDst = memCopyDevToHostOp.getHostDst();
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
if (failed(globalDstOpt))
return failure();
auto globalDstMemRef = *globalDstOpt;
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
if (failed(localSrcOpt))
return failure();
auto localSrcMemRef = *localSrcOpt;
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
memCopyDevToHostOp,
globalDstMemRef.getType(),
globalDstMemRef,
localSrcMemRef,
memCopyDevToHostOp.getHostDstOffsetAttr(),
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
memCopyDevToHostOp.getSizeAttr());
return success();
}
};
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const {
auto vmmOp = cast<PimVMMOp>(op);
Value readVal = uRead->get();
Value writeVal = uWrite->get();
if (writeVal != vmmOp.getOutBuf())
return false;
if (readVal == vmmOp.getVectorInput())
if (state.areEquivalentBufferizedValues(readVal, writeVal))
return true;
return false;
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto vmmOp = cast<PimVMMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
if (failed(vectorInputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
return success();
}
};
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto mvmOp = cast<PimMVMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
if (failed(vectorInputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
return success();
}
};
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
const AnalysisState& state,
ArrayRef<OpOperand*> opOperands) const {
return true;
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto vaddOp = cast<PimVAddOp>(op);
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
if (failed(aOpt))
return failure();
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
if (failed(bOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
return success();
}
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
});
}
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,14 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,15 @@
add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)
add_onnx_mlir_library(SpatialOps
SpatialOps.cpp
Transforms/SpatialBufferizableOpInterface.cpp
DEPENDS
OMSpatialIncGen
LINK_LIBS PUBLIC
MLIRIR
OMMlirDialects
)

View File

@@ -0,0 +1,355 @@
#ifndef SPATIAL_DIALECT_H
#define SPATIAL_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
def SpatialDialect : Dialect {
let name = "spat";
let summary = "Dialect designed for deep learning computation in a spatial architecture";
let cppNamespace = "::onnx_mlir::spatial";
let useDefaultTypePrinterParser = 1;
}
class SpatOp<string mnemonic, list<Trait> traits = []> :
Op<SpatialDialect, mnemonic, traits>;
// TODO maybe remove and use AnyRankedTensor directly
def SpatTensor:
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<SpatialDialect, name, traits> {
let mnemonic = typeMnemonic;
}
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
let summary = "Virtual channel type";
}
def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute operation, with constant weights already attached";
let arguments = (ins
Variadic<SpatTensor>:$weights,
Variadic<SpatTensor>:$inputs
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let assemblyFormat = [{
`[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
}];
}
def SpatYieldOp: SpatOp<"yield", [Terminator]> {
let arguments = (ins
Variadic<SpatTensor>:$outputs
);
let assemblyFormat = [{
$outputs attr-dict `:` type($outputs)
}];
}
//===----------------------------------------------------------------------===//
// Data movement operations
//===----------------------------------------------------------------------===//
def SpatChannelNewOp: SpatOp<"channel_new", []> {
let results = (outs
SpatChannelType:$new_channel
);
let builders = [
OpBuilder<(ins ), [{
$_state.addTypes(SpatChannelType());
}]>
];
let assemblyFormat = [{
attr-dict
}];
}
def SpatChannelSendOp: SpatOp<"channel_send", []> {
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
);
let assemblyFormat = [{
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
}];
}
def SpatChannelReceiveOp: SpatOp<"channel_receive", []> {
let arguments = (ins
SpatChannelType: $channel
);
let results = (outs
SpatTensor: $data
);
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
}];
}
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
let arguments = (ins
SpatChannelType: $channel,
SpatTensor: $data
);
}
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
let arguments = (ins
SpatChannelType: $channel
);
let results = (outs
SpatTensor: $data
);
}
//===----------------------------------------------------------------------===//
// Math operations
//===----------------------------------------------------------------------===//
def SpatConstantOp: SpatOp<"constant", []> {
let description = [{
"Constant value, should be used for weights and biases"
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
SpatTensor: $out
);
}
def SpatWeightedVMMOp: SpatOp<"Wvmm", []> {
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatVAddOp: SpatOp<"vadd", []> {
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
}];
}
def SpatVMulOp: SpatOp<"vmul", []> {
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
}];
}
def SpatVDivOp: SpatOp<"vdiv", []> {
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor:$a,
SpatTensor:$b
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
}];
}
//TODO: remove
def SpatVSDivOp: SpatOp<"vsdiv", []> {
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
let arguments = (ins
SpatTensor:$dividend,
SpatTensor:$divisor
);
let results = (outs
SpatTensor:$output
);
}
def SpatSumOp: SpatOp<"sum", []> {
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
let arguments = (ins
SpatTensor: $input
);
let results = (outs
SpatTensor:$output
);
}
def SpatSigmoidOp: SpatOp<"sigmoid", []> {
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
}
def SpatReluOp: SpatOp<"relu", []> {
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
}
def SpatVMaxOp: SpatOp<"vmax", []> {
let summary = "Element-wise max function";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
let description = [{
Applies a variable number of crossbar weights to a single large image tensor tile,
producing a corresponding output tile. This essentially encapsulates a big for loop
over all pixels in the input tile, where each pixel is multiplied by all the weights
in the operation.
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
SpatTensor: $input
);
let results = (outs SpatTensor);
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type(results)
}];
}
//===----------------------------------------------------------------------===//
// Other operations
//===----------------------------------------------------------------------===//
def SpatImgConcatOp: SpatOp<"img_concat", []> {
let summary = "Concatenate pixel tiles into a single image";
let description = [{
Concatenate pixel tiles into a single image:
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
The input tiles should be provided in a specific order:
start from the top left pixel,
then continue with the pixel on its right,
and once you finish the first row of pixels, go to the next row.
}];
let arguments = (ins
Variadic<SpatTensor>:$inputs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let extraClassDeclaration = [{
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
}];
}
#endif // SPATIAL_DIALECT_H

View File

@@ -0,0 +1,339 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include <cstdint>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
void SpatialDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
>();
}
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
// Verify that the matrix, vector and output shapes have rank 2
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitter->emitError("matrix, vector and output must have rank 2");
// Verify that the matrix shape is (N, M)
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
// Verify that the vector shape is (M, 1)
int64_t vectorM = vectorShape[0];
int64_t vector1 = vectorShape[1];
if (vectorM != M || vector1 != 1)
return emitter->emitError("vector shape must be (M, 1)");
// Verify that the output shape is (N, 1)
int64_t outputN = outputShape[0];
int64_t output1 = outputShape[1];
if (outputN != N || output1 != 1)
return emitter->emitError("output shape must be (N, 1)");
return success();
}
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
// Verify that the matrix, vector and output shapes have rank 4
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
return emitter->emitError("matrix, vector and output must have rank 4");
// Verify that the matrix shape is (N, M, 1, 1)
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
int64_t matrix1First = matrixShape[2];
int64_t matrix1Second = matrixShape[3];
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
// Verify that the vector shape is (1, M, 1, 1)
int64_t vector1First = vectorShape[0];
int64_t vectorM = vectorShape[1];
int64_t vector1Second = vectorShape[2];
int64_t vector1Third = vectorShape[3];
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
// This is ok, it was caused by the simplification of the concat error
}
else {
return emitter->emitError("vector shape must be (1, M, 1, 1)");
}
}
// Verify that the output shape is (1, N, 1, 1)
int64_t output1First = outputShape[0];
int64_t outputN = outputShape[1];
int64_t output1Second = outputShape[2];
int64_t output1Third = outputShape[3];
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
return emitter->emitError("output shape must be (1, N, 1, 1)");
return success();
}
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weigthedOp->getParentOp());
if (wcomputeOp)
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
if (coreOp)
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
return failure();
}
LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Two possible accepted shapes:
1. matrix: (N, M); vector: (M, 1); output: (N, 1)
2. matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1)
*/
if (matrixShape.size() == 2)
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
else if (matrixShape.size() == 4)
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
else
return emitError("matrix rank must be 2 or 4");
}
LogicalResult SpatWeightedVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Accepted shape:
1. vector: (1, N); matrix: (N, M); output: (1, M)
*/
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vector1 = vectorShape[0];
int64_t vectorN = vectorShape[1];
if (vectorN != N || vector1 != 1)
return emitError("vector shape must be (N, 1)");
int64_t output1 = outputShape[0];
int64_t outputM = outputShape[1];
if (outputM != M || output1 != 1)
return emitError("output shape must be (M, 1)");
return success();
}
LogicalResult SpatVAddOp::verify() {
// At least two operands
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatVMaxOp::verify() {
// At least two operands
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize;
auto operands = getOperands();
// Check number of operands
if (img_w * img_h * channelTiles != operands.size())
return emitError("Number of operands does not match output image size");
// For each output pixel, check that the inputTiles have a correct shape
for (size_t x = 0; x < img_w; x++) {
for (size_t y = 0; y < img_h; y++) {
size_t channel_counts = 0;
for (size_t t = 0; t < channelTiles; t++) {
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
if (!inputShape)
return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
// Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that
// - CASE2: common case, the channel count is exactly the crossbarSize
if (t == channelTiles - 1 && channelTileRest != 0) {
if (inputChannels != channelTileRest)
return emitError("Invalid channel count for last tile of pixel");
}
else {
if (inputChannels != crossbarSize)
return emitError("Invalid channel count for some pixel tile");
}
channel_counts += inputChannels;
}
if (channel_counts != img_c)
emitError("Invalid number of channels for some pixel");
}
}
return success();
}
LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result
auto& block = getBody().front();
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("ComputeOp must have a single yield operation");
auto resultTypes = getResultTypes();
auto yieldTypes = yieldOp->getOperandTypes();
if (resultTypes.size() != yieldTypes.size()) {
return emitError("ComputeOp must have same number of results as yieldOp "
"operands");
}
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
auto resultType = std::get<0>(it);
auto yieldType = std::get<1>(it);
// Same type and compatible shape
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
return emitError("ComputeOp output must be of the same type as yieldOp "
"operand");
}
// Same encoding
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
return emitError("ComputeOp output must have the same encoding as "
"yieldOp operand");
}
}
else {
return emitError("ComputeOp output has an encoding while yieldOp "
"operand does not have one");
}
}
else {
// If result does not have an encoding, yield shouldn't either
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
return emitError("ComputeOp output must not have an encoding if "
"yieldOp operand has one");
}
}
}
// Check that each block argument is used
for (auto arg : block.getArguments())
if (arg.use_empty())
return emitError("ComputeOp block argument is not used");
return success();
}
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
assert(tile < channelTiles);
assert(x < img_w);
assert(y < img_h);
return operands[tile + x * channelTiles + y * img_w * channelTiles];
}
} // namespace spatial
} // namespace onnx_mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"

View File

@@ -0,0 +1,20 @@
#pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.hpp.inc"
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"

View File

@@ -0,0 +1,493 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir {
namespace spatial {
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
auto resultShape = cast<ShapedType>(resultType);
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
// Alloc an output memref
return rewriter.create<memref::AllocOp>(loc, memrefResultType);
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
// This function requires the existence of ChannelNewOp and the other
// Receive/Send operation. However, during bufferization, the first of the
// Receive/Send operation that is processed gets removed. As such, we need to
// "precompute" the coreId needed for the other op, and save it as attribute
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
if (precomputedOtherCoreId)
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter);
if (failed(notOpUserOpt))
return failure();
Operation* notOpUser = *notOpUserOpt;
// Save the coreId for this op into the other op as attribute
auto opCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr);
return cast<pim::PimCoreOp>(notOpUser->getParentOp()).getCoreId();
}
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {
// Input tensor to the compute OP are always read into its local memory
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensor to the compute OP are _never_ written into its local memory
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the compute OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
// Bufferize its block
auto& block = op->getRegion(0).front();
return bufferizeBlockSignature(&block, rewriter, options, state);
}
};
/*
* This can be used for operation that have a single argument, which is a
* variadic of tensors, and a single output with the same same shape
* Example: VAdd, VSub, VExp
*/
template <typename InterfaceName, typename OpTy, typename ToTy>
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
// Input tensors to the OP are always read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensors to the OP are _never_ written
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Cast tensor values into memref values
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
// Turn Tensor Operands into Memref Operands
SmallVector<Value> memrefOperands;
memrefOperands.reserve(op->getNumOperands());
for (auto operand : op->getOperands()) {
auto memref = getBuffer(rewriter, operand, options, state);
if (failed(memref))
return failure();
memrefOperands.push_back(*memref);
}
// TODO: Support addiction with more than 2 operands
if (memrefOperands.size() > 2) {
op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
"with 1 or 2 operands, for now.");
return failure();
}
// Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
memrefOperands.push_back(outputTensor);
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
template <typename InterfaceName, typename OpTy, typename ToTy>
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
// Input tensors to the OP are always read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensors to the OP are _never_ written
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Cast tensor value into memref value
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(memrefOperandOpt))
return failure();
auto memrefOperand = *memrefOperandOpt;
// Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
Value newValue =
rewriter
.create<ToTy>(
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor)
.getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelReceiveOpInterface
: BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {
// Input value is the channel (not read/written, its more of an attribute)
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel receive to pim.recv
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
if (failed(srcCoreId))
return failure();
Value newValue = rewriter
.create<pim::PimReceiveOp>(op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOut();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {
// First input is channel (not read/writter) second input is Tensor to send,
// which is read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return opOperand.getOperandNumber() == 2;
}
// See above (both non-written)
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel send to pim.send
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
if (failed(srcTensorOpt))
return failure();
auto srcMemRef = *srcTensorOpt;
auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;
auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
if (failed(dstCoreId))
return failure();
replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
op,
srcMemRef,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(dstCoreId.value()));
return success();
}
};
struct ChannelBroadcastReceiveOpInterface
: BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {
// Input value is the channel (not read/written, its more of an attribute)
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel receive to pim.load using by creating a new global buffer
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
return failure();
}
// The first 'broadcast' operation creates the buffer just after the
// channelNewOp, while the other 'broadcast' operation need to find this
// buffer allocation just after the channelNewOp
Value bufferAllocation;
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
// Buffer already allocated, load from this buffer
bufferAllocation = allocOpAfterChannel;
}
else {
// Buffer was not allocated previously, allocate it after channelNewOp
rewriter.setInsertionPointAfter(channelNewOp);
bufferAllocation = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
}
rewriter.setInsertionPoint(op);
auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(),
outputTensor.getType(),
outputTensor,
bufferAllocation,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(outputSize));
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
return success();
}
};
struct ChannelBroadcastSendOpInterface
: BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
// First input is channel (not read/writter) second input is Tensor to send,
// which is read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return opOperand.getOperandNumber() == 2;
}
// See above (both non-written)
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel send to pim.send
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
if (failed(srcTensorOpt))
return failure();
auto srcMemRef = *srcTensorOpt;
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
return failure();
}
// The first 'broadcast' operation creates the buffer just after the
// channelNewOp, while the other 'broadcast' operation need to find this
// buffer allocation just after the channelNewOp
Value bufferAllocation;
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
// Buffer already allocated, load from this buffer
bufferAllocation = allocOpAfterChannel;
}
else {
// Buffer was not allocated previously, allocate it after channelNewOp
rewriter.setInsertionPointAfter(channelNewOp);
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
}
rewriter.setInsertionPoint(op);
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
return success();
}
};
struct VAddOpInterfaceFromTemplate
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
// One operand ($input) is read from. All other inputs are only written to.
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// Operand 0: $input
// Operand 1: $outBuf
// Operand 2: $accumBuf
return opOperand.getOperandNumber() == 0;
}
// One input ($accumBuf) is written to. All other inputs are only read.
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// Operand 0: $input
// Operand 1: $outBuf
// Operand 2: $accumBuf
return opOperand.getOperandNumber() == 2;
}
// No operands are aliased with any other operands.
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Bufferize the operation.
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
// Get the input tensor buffer.
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(inputBuffer))
return failure();
// Create a new buffer for the output tensor.
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
// Create a new buffer for the accumulation buffer.
// To do this, create a new allocation operation. Size must be axbx1x1,
// where axbxcxd is the size of the output tensor. Since the shape is
// different, we can't immediately use createEmptyFromType, we first need to
// create the shape of the accumulation buffer.
auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
// Set the last two dimensions to 1.
accumShape[accumShape.size() - 1] = 1;
accumShape[accumShape.size() - 2] = 1;
auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
// Bufferize the operation.
auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(),
outputTensor.getType(),
weightIndices,
xKernelPositions,
yKernelPositions,
*inputBuffer,
outputTensor,
accumBuffer);
// Replace the operation with the bufferized value.
replaceOpWithBufferizedValues(rewriter, op, bufferized);
return success();
}
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
});
}
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
});
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,16 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,67 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct CountInstructionPass
: public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass)
StringRef getArgument() const override { return "count-instruction-pass"; }
StringRef getDescription() const override {
return "Count instructions for each core/compute in the module";
}
// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
CountInstructionPass() {}
CountInstructionPass(const CountInstructionPass &pass)
: PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}
void runOnOperation() final {
ModuleOp module = getOperation();
func::FuncOp func = *module.getOps<func::FuncOp>().begin();
unsigned totalInstructionCount = 0;
unsigned computeId = 0;
for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) {
unsigned instructionCount = 0;
instructionCount += computeOp.getBody().front().getOperations().size();
llvm::outs() << "Compute " << computeId << ": " << instructionCount
<< " instructions\n";
totalInstructionCount += instructionCount;
computeId++;
}
unsigned coreId = 0;
for (auto coreOp : func.getOps<pim::PimCoreOp>()) {
unsigned instructionCount = 0;
instructionCount += coreOp.getBody().front().getOperations().size();
llvm::outs() << "Core " << coreId << ": " << instructionCount
<< " instructions\n";
totalInstructionCount += instructionCount;
coreId++;
}
llvm::outs() << "Total instruction count: " << totalInstructionCount
<< "\n";
}
};
} // namespace
std::unique_ptr<Pass> createCountInstructionPass() {
return std::make_unique<CountInstructionPass>();
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,37 @@
#include "mlir/Pass/Pass.h"
#include "src/Compiler/CompilerUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MessagePass : public PassWrapper<MessagePass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MessagePass)
StringRef getArgument() const override { return "message-pass"; }
StringRef getDescription() const override {
return "Lower ONNX ops to Spatial ops.";
}
// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
MessagePass(std::string message) : message(message) {}
MessagePass(const MessagePass &pass)
: PassWrapper<MessagePass, OperationPass<ModuleOp>>() {}
void runOnOperation() final { showCompilePhase(message); }
private:
std::string message;
};
} // namespace
std::unique_ptr<Pass> createMessagePass(std::string message) {
return std::make_unique<MessagePass>(message);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,22 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include <memory>
using namespace mlir;
namespace onnx_mlir {
std::unique_ptr<Pass> createONNXToSpatialPass();
std::unique_ptr<Pass> createSpatialToGraphvizPass();
std::unique_ptr<Pass> createSpatialToPIMPass();
std::unique_ptr<Pass> createBufferizePimPass();
std::unique_ptr<Pass> createMessagePass(std::string message);
std::unique_ptr<Pass> createCountInstructionPass();
} // namespace onnx_mlir

110
src/PIM/PimAccelerator.cpp Normal file
View File

@@ -0,0 +1,110 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/PimAccelerator.hpp"
#include <memory>
#define DEBUG_TYPE "PimAccelerator"
namespace onnx_mlir {
namespace accel {
Accelerator *createPIM() { return PimAccelerator::getInstance(); }
PimAccelerator *PimAccelerator::instance = nullptr;
PimAccelerator *PimAccelerator::getInstance() {
if (instance == nullptr)
instance = new PimAccelerator();
return instance;
}
PimAccelerator::PimAccelerator() : Accelerator(Accelerator::Kind::PIM) {
LLVM_DEBUG(llvm::dbgs() << "Creating a PIM accelerator\n");
acceleratorTargets.push_back(this);
};
PimAccelerator::~PimAccelerator() { delete instance; }
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }
void PimAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
std::string outputNameNoExt) const {
LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
addPassesPim(module, pm, emissionTarget, outputNameNoExt);
}
void PimAccelerator::registerDialects(mlir::DialectRegistry &registry) const {
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");
registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>();
registry.insert<bufferization::BufferizationDialect>();
registry.insert<pim::PimDialect>();
registry.insert<spatial::SpatialDialect>();
tensor::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerBufferizableOpInterfaceExternalModels(registry);
}
void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
// Register here all the passes that could be used
mlir::registerPass(createONNXToSpatialPass);
mlir::registerPass(createSpatialToGraphvizPass);
mlir::registerPass(createSpatialToPIMPass);
mlir::registerPass(createBufferizePimPass);
}
void PimAccelerator::configurePasses() const {
LLVM_DEBUG(llvm::dbgs() << "Configuring passes for PIM accelerator\n");
// TODO: This does nothing for now.
}
mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType(
const mlir::TensorType tensorType) const {
// Do not convert tensor types to memref types.
return nullptr;
}
void PimAccelerator::conversionTargetONNXToKrnl(
mlir::ConversionTarget &target) const {
target.addLegalDialect<pim::PimDialect>();
}
void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const {
// TODO: Add patterns for conversion
}
void PimAccelerator::conversionTargetKrnlToLLVM(
mlir::ConversionTarget &target) const {}
void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns,
mlir::LLVMTypeConverter &typeConverter, mlir::MLIRContext *ctx) const {
// We should not need this, since we offload it all to PIM.
}
} // namespace accel
} // namespace onnx_mlir

View File

@@ -0,0 +1,70 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "src/Accelerators/Accelerator.hpp"
namespace onnx_mlir {
namespace accel {
/// Singleton class to construct PIM accelerator.
class PimAccelerator final : public Accelerator {
private:
static PimAccelerator *instance;
PimAccelerator();
public:
/// Singleton should not be clonable or assignable.
PimAccelerator(PimAccelerator &) = delete;
void operator=(const PimAccelerator &) = delete;
~PimAccelerator();
/// Creates an instance on the first invocation. Subsequent invocations
/// return the existing instance.
static PimAccelerator *getInstance();
/// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc.
static bool classof(const Accelerator *accel) {
return accel->getKind() == Accelerator::Kind::PIM;
}
static bool classof(const PimAccelerator *) { return true; }
uint64_t getVersionNumber() const final;
//===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===//
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
std::string outputNameNoExt) const final;
//===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===//
virtual void registerDialects(mlir::DialectRegistry &registry) const final;
virtual void registerPasses(int optLevel) const final;
//===--------------------------------------------------------------------===//
// Hooks for both onnx-mlir and onnx-mlir-opt drivers
//===--------------------------------------------------------------------===//
virtual void configurePasses() const final;
//===--------------------------------------------------------------------===//
// Hooks for onnx-to-krnl pass
//===--------------------------------------------------------------------===//
virtual mlir::MemRefType convertTensorTypeToMemRefType(
const mlir::TensorType tensorType) const final;
virtual void conversionTargetONNXToKrnl(
mlir::ConversionTarget &target) const final;
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const final;
//===--------------------------------------------------------------------===//
// Hooks for krnl-to-llvm pass
//===--------------------------------------------------------------------===//
virtual void conversionTargetKrnlToLLVM(
mlir::ConversionTarget &target) const final;
virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns,
mlir::LLVMTypeConverter &typeConverter,
mlir::MLIRContext *ctx) const final;
};
} // namespace accel
} // namespace onnx_mlir

View File

@@ -0,0 +1,87 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "Compiler/PimCodeGen.hpp"
#include "PimBufferizationPass.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation();
// Do One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
signalPassFailure();
}
// Remove toTensor operations
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
toTensorOp.erase();
});
// Change main function return types from tensors to memrefs
func::FuncOp funcOp;
for (Operation& op : moduleOp.getBody()->getOperations())
if ((funcOp = dyn_cast<func::FuncOp>(&op)))
break;
auto oldFuncType = funcOp.getFunctionType();
SmallVector<Type> newResults;
bool changed = false;
for (Type type : oldFuncType.getResults())
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
changed = true;
}
else
newResults.push_back(type);
if (changed)
funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults));
annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug
ModuleOp module = getOperation();
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
std::filesystem::create_directory(outputDir);
std::fstream file(outputDir + "/pim_buf.mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *module;
os.flush();
file.close();
}
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
assert("Weights must be constants" && globalMemrefOp.getConstant());
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx));
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx));
}
});
}
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace pim {
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace pim
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<pim::PimBufferizationPass>(); }
} // namespace onnx_mlir

0
test/PIM/CMakeLists.txt Normal file
View File