Compare commits
11 Commits
584ca0b3c2
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
568529ea5f | ||
|
|
ca2e1645bb | ||
|
|
6933804003 | ||
|
|
dbe646ac0d | ||
|
|
bb6dcd38a3 | ||
|
|
916a09414c | ||
|
|
db3f52a647 | ||
|
|
6e1de865bb | ||
|
|
4e50e056e3 | ||
|
|
771b44a2ed | ||
|
|
7ce1d2b34d |
23
.github/workflows/build_mlir_cache.yml
vendored
23
.github/workflows/build_mlir_cache.yml
vendored
@@ -12,29 +12,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Free disk space
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
df -h
|
||||
sudo apt-get remove -y '^dotnet-.*'
|
||||
sudo apt-get remove -y '^llvm-.*'
|
||||
sudo apt-get remove -y 'php.*'
|
||||
sudo apt-get remove -y '^mongodb-.*'
|
||||
sudo apt-get remove -y '^mysql-.*'
|
||||
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
|
||||
sudo apt-get autoremove -y
|
||||
sudo apt-get clean
|
||||
df -h
|
||||
sudo rm -rf /usr/local/lib/android || true
|
||||
sudo rm -rf /usr/share/dotnet || true
|
||||
sudo rm -rf /opt/ghc || true
|
||||
sudo rm -rf /usr/local/.ghcup || true
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
|
||||
sudo docker system prune --all --volumes --force
|
||||
sudo apt-get clean
|
||||
sudo rm -rf /var/lib/apt/lists/*
|
||||
df -h
|
||||
|
||||
- name: Cache MLIR build
|
||||
id: cache-mlir
|
||||
uses: actions/cache@v4
|
||||
|
||||
24
.github/workflows/validate_operations.yml
vendored
24
.github/workflows/validate_operations.yml
vendored
@@ -29,33 +29,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Free disk space
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
df -h
|
||||
sudo apt-get remove -y '^dotnet-.*'
|
||||
sudo apt-get remove -y '^llvm-.*'
|
||||
sudo apt-get remove -y 'php.*'
|
||||
sudo apt-get remove -y '^mongodb-.*'
|
||||
sudo apt-get remove -y '^mysql-.*'
|
||||
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
|
||||
sudo apt-get autoremove -y
|
||||
sudo apt-get clean
|
||||
df -h
|
||||
sudo rm -rf /usr/local/lib/android || true
|
||||
sudo rm -rf /usr/share/dotnet || true
|
||||
sudo rm -rf /opt/ghc || true
|
||||
sudo rm -rf /usr/local/.ghcup || true
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
|
||||
sudo docker system prune --all --volumes --force
|
||||
sudo apt-get clean
|
||||
sudo rm -rf /var/lib/apt/lists/*
|
||||
df -h
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
github-server-url: https://chef.heaplab.deib.polimi.it/git
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,2 +1,4 @@
|
||||
.idea
|
||||
.claude
|
||||
AGENTS.md
|
||||
build
|
||||
|
||||
Submodule onnx-mlir updated: 82018d7ce5...eb54c2afc4
@@ -20,6 +20,8 @@ add_onnx_mlir_library(OMPIMAccel
|
||||
Pass/CountInstructionPass.cpp
|
||||
Pass/EmitPimJsonPass.cpp
|
||||
Pass/MessagePass.cpp
|
||||
Pass/PimConstantFoldingPass.cpp
|
||||
Pass/PimHostVerificationPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
@@ -41,6 +43,7 @@ add_onnx_mlir_library(OMPIMAccel
|
||||
PimOps
|
||||
OMONNXToSpatial
|
||||
OMSpatialToGraphviz
|
||||
OMSpatialToPIM
|
||||
OMPIMCommon
|
||||
)
|
||||
OMSpatialToPim
|
||||
OMPimCommon
|
||||
MLIRTensorInferTypeOpInterfaceImpl
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
add_onnx_mlir_library(OMPIMCommon
|
||||
PIMCommon.cpp
|
||||
add_onnx_mlir_library(OMPimCommon
|
||||
PimCommon.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
|
||||
|
||||
void createDirectory(const std::string& directory) {
|
||||
std::error_code errorCode;
|
||||
std::filesystem::create_directories(directory, errorCode);
|
||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||
}
|
||||
|
||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||
std::string dialectsDir = getOutputDir() + "/dialects";
|
||||
createDirectory(dialectsDir);
|
||||
|
||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||
llvm::raw_os_ostream os(file);
|
||||
os << *moduleOp;
|
||||
os.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||
return failure();
|
||||
}
|
||||
// channelNewOp should have two users: `op` and a
|
||||
// `ChannelSendOp`/`ChannelReceiveOp`
|
||||
auto channelUsers = channelNewOp->getUsers();
|
||||
auto usersIterator = channelUsers.begin();
|
||||
auto firstUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator == channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"only one found.");
|
||||
channelNewOp->dump();
|
||||
op->dump();
|
||||
channelNewOp->getParentOp()->dump();
|
||||
return failure();
|
||||
}
|
||||
auto secondUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator != channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"more than two found.");
|
||||
return failure();
|
||||
}
|
||||
Operation* notOpUser;
|
||||
if (firstUser == op) {
|
||||
notOpUser = secondUser;
|
||||
}
|
||||
else if (secondUser == op) {
|
||||
notOpUser = firstUser;
|
||||
}
|
||||
else {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"and one of them must be me, but"
|
||||
"none of them is actually me.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (opIsReceive) {
|
||||
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelSendOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
else {
|
||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelReceiveOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,24 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||
|
||||
llvm::FailureOr<mlir::Operation*>
|
||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
239
src/PIM/Common/PimCommon.cpp
Normal file
239
src/PIM/Common/PimCommon.cpp
Normal file
@@ -0,0 +1,239 @@
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir() {
|
||||
if (outputBaseName.empty() || outputBaseName == "-")
|
||||
return {};
|
||||
|
||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||
if (lastSlash == std::string::npos)
|
||||
return ".";
|
||||
return outputBaseName.substr(0, lastSlash);
|
||||
}
|
||||
|
||||
void createDirectory(const std::string& directory) {
|
||||
std::error_code errorCode;
|
||||
std::filesystem::create_directories(directory, errorCode);
|
||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||
}
|
||||
|
||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||
std::string outputDir = getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return;
|
||||
|
||||
std::string dialectsDir = outputDir + "/dialects";
|
||||
createDirectory(dialectsDir);
|
||||
|
||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||
llvm::raw_os_ostream os(file);
|
||||
os << *moduleOp;
|
||||
os.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
||||
if (entryPoints.size() > 1) {
|
||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||
return failure();
|
||||
}
|
||||
if (!entryPoints.empty()) {
|
||||
auto entryPointAttr =
|
||||
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||
if (!entryPointAttr) {
|
||||
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||
return failure();
|
||||
}
|
||||
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||
if (!entryFunc) {
|
||||
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||
<< entryPointAttr.getLeafReference().getValue();
|
||||
return failure();
|
||||
}
|
||||
return entryFunc;
|
||||
}
|
||||
|
||||
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
|
||||
return mainGraphFunc;
|
||||
|
||||
SmallVector<func::FuncOp> nonExternalFuncs;
|
||||
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
||||
if (!funcOp.isExternal())
|
||||
nonExternalFuncs.push_back(funcOp);
|
||||
if (nonExternalFuncs.size() == 1)
|
||||
return nonExternalFuncs.front();
|
||||
|
||||
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||
|
||||
void markWeightAlways(Operation* op) {
|
||||
assert(op && "expected valid op");
|
||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||
}
|
||||
|
||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||
if (!moduleOp || !getGlobalOp)
|
||||
return {};
|
||||
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
}
|
||||
|
||||
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||
return failure();
|
||||
}
|
||||
// channelNewOp should have two users: `op` and a
|
||||
// `ChannelSendOp`/`ChannelReceiveOp`
|
||||
auto channelUsers = channelNewOp->getUsers();
|
||||
auto usersIterator = channelUsers.begin();
|
||||
auto firstUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator == channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"only one found.");
|
||||
channelNewOp->dump();
|
||||
op->dump();
|
||||
channelNewOp->getParentOp()->dump();
|
||||
return failure();
|
||||
}
|
||||
auto secondUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator != channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"more than two found.");
|
||||
return failure();
|
||||
}
|
||||
Operation* notOpUser;
|
||||
if (firstUser == op) {
|
||||
notOpUser = secondUser;
|
||||
}
|
||||
else if (secondUser == op) {
|
||||
notOpUser = firstUser;
|
||||
}
|
||||
else {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"and one of them must be me, but"
|
||||
"none of them is actually me.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (opIsReceive) {
|
||||
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelSendOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
else {
|
||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelReceiveOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
|
||||
SmallVector<int64_t> indices(shape.size(), 0);
|
||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||
indices[dim] = linearIndex / stride;
|
||||
linearIndex %= stride;
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
|
||||
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
|
||||
int64_t linearIndex = 0;
|
||||
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||
linearIndex += index * stride;
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
int64_t getNumElements(ArrayRef<int64_t> shape) {
|
||||
int64_t numElements = 1;
|
||||
for (int64_t dim : shape)
|
||||
numElements *= dim;
|
||||
return numElements;
|
||||
}
|
||||
|
||||
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||
ArrayRef<int64_t> offsets,
|
||||
ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> strides) {
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstNonZeroOffset = std::find_if(
|
||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||
return offset != 0;
|
||||
});
|
||||
|
||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||
if (size > dimension - offset)
|
||||
return false;
|
||||
++firstNonZeroOffset;
|
||||
|
||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != sizesAndShape.end()) {
|
||||
++firstDifferentSize;
|
||||
|
||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||
auto [size, _dimension] = sizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
51
src/PIM/Common/PimCommon.hpp
Normal file
51
src/PIM/Common/PimCommon.hpp
Normal file
@@ -0,0 +1,51 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate";
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||
|
||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
void markWeightAlways(mlir::Operation* op);
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
llvm::FailureOr<mlir::Operation*>
|
||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -13,11 +13,11 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Compiler/CompilerPasses.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
@@ -49,8 +49,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (!getGlobalOp->hasAttr("weightAlways")) {
|
||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
if (!hasWeightAlways(getGlobalOp)) {
|
||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
auto iter = globalConstants.find(globalMemrefOp);
|
||||
if (iter == globalConstants.end())
|
||||
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
|
||||
@@ -81,7 +81,7 @@ MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
||||
}
|
||||
|
||||
@@ -112,10 +112,33 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
}
|
||||
value = source;
|
||||
}
|
||||
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
}
|
||||
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
}
|
||||
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
return memEntriesMap.at(value).address + offset;
|
||||
|
||||
auto iter = memEntriesMap.find(value);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
value.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = value.getDefiningOp()) {
|
||||
errs() << "Defining op:\n";
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Missing mem entry");
|
||||
}
|
||||
|
||||
return iter->second.address + offset;
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
@@ -348,6 +371,52 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
|
||||
}
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||
auto srcAddr = memory.getValueAddress(transposeOp.getData());
|
||||
auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
|
||||
|
||||
auto srcType = cast<ShapedType>(transposeOp.getData().getType());
|
||||
auto srcShape = srcType.getShape();
|
||||
size_t rank = srcShape.size();
|
||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||
size_t totalElements = srcType.getNumElements();
|
||||
|
||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||
SmallVector<int64_t> perm =
|
||||
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
||||
|
||||
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
||||
SmallVector<int64_t> dstShape(rank);
|
||||
for (size_t i = 0; i < rank; i++)
|
||||
dstShape[i] = srcShape[perm[i]];
|
||||
|
||||
// Row-major strides for source and destination
|
||||
SmallVector<size_t> srcStrides(rank, 1);
|
||||
SmallVector<size_t> dstStrides(rank, 1);
|
||||
for (int64_t i = rank - 2; i >= 0; i--) {
|
||||
srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
|
||||
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
|
||||
}
|
||||
|
||||
// Emit element-by-element copy with transposed addressing
|
||||
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
|
||||
// Decompose flat source index into multi-dimensional index
|
||||
SmallVector<size_t> srcIdx(rank);
|
||||
size_t remaining = srcFlat;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
srcIdx[d] = remaining / srcStrides[d];
|
||||
remaining %= srcStrides[d];
|
||||
}
|
||||
|
||||
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
|
||||
size_t dstFlat = 0;
|
||||
for (size_t d = 0; d < rank; d++)
|
||||
dstFlat += srcIdx[perm[d]] * dstStrides[d];
|
||||
|
||||
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
|
||||
}
|
||||
}
|
||||
|
||||
size_t getMatrixSize(ShapedType matrixShape) {
|
||||
if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4)
|
||||
assert(false && "Unsupported matrix shape");
|
||||
@@ -378,9 +447,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (getGlobalOp->hasAttr("weightAlways"))
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp)
|
||||
return;
|
||||
auto initialValue = globalOp.getInitialValue();
|
||||
@@ -416,7 +485,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
||||
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
size_t processedOperations = 0;
|
||||
for (auto& op : coreOp.getBody().front()) {
|
||||
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp>(op))
|
||||
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
||||
@@ -435,6 +504,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
|
||||
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
|
||||
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
|
||||
coreCodeGen.codeGenVAddOp(vaddOp);
|
||||
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
|
||||
@@ -475,7 +546,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
|
||||
continue;
|
||||
}
|
||||
|
||||
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr());
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
|
||||
weightIndex++;
|
||||
@@ -589,9 +660,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
}
|
||||
}
|
||||
|
||||
auto funcOps = moduleOp.getOps<func::FuncOp>();
|
||||
assert(!funcOps.empty() && "No function found in the module");
|
||||
auto funcOp = *funcOps.begin();
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc))
|
||||
return CompilerFailure;
|
||||
auto funcOp = *entryFunc;
|
||||
|
||||
PimAcceleratorMemory memory;
|
||||
memory.hostMem.allocateHost(moduleOp, funcOp);
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "Common/ValueMap.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -49,7 +49,7 @@ public:
|
||||
PimAcceleratorMemory()
|
||||
: hostMem(memEntriesMap) {}
|
||||
|
||||
PimMemory getOrCreateDeviceMem(size_t id);
|
||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||
|
||||
size_t getValueAddress(mlir::Value value) const;
|
||||
};
|
||||
@@ -95,6 +95,7 @@ public:
|
||||
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
|
||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||
};
|
||||
|
||||
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
|
||||
@@ -25,7 +25,6 @@ extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||
|
||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||
extern llvm::cl::opt<bool> exportCrossbarWeights;
|
||||
|
||||
extern llvm::cl::opt<size_t> crossbarSize;
|
||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerPasses.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
@@ -34,7 +34,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPim) {
|
||||
pm.addPass(createSpatialToPIMPass());
|
||||
pm.addPass(createSpatialToPimPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||
}
|
||||
@@ -46,6 +46,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim constants folded"));
|
||||
pm.addPass(createPimHostVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim host verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
add_subdirectory(ONNXToSpatial)
|
||||
add_subdirectory(SpatialToGraphviz)
|
||||
add_subdirectory(SpatialToPIM)
|
||||
add_subdirectory(SpatialToPim)
|
||||
@@ -5,17 +5,13 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
add_onnx_mlir_library(OMONNXToSpatial
|
||||
Math/Gemm.cpp
|
||||
Math/Conv.cpp
|
||||
Math/ExperimentalConv.cpp
|
||||
Math/ExperimentalGemm.cpp
|
||||
NN/Pooling.cpp
|
||||
NN/ExperimentalPooling.cpp
|
||||
NN/ReduceMean.cpp
|
||||
Tensor/ONNXConcatToTensorConcat.cpp
|
||||
Tensor/RemoveUnusedHelperOps.cpp
|
||||
Utils/SpatialReducer.cpp
|
||||
Utils/WeightSubdivider.cpp
|
||||
Utils/AnnotateReplication.cpp
|
||||
ONNXToSpatialPass.hpp
|
||||
ONNXToSpatialPass.cpp
|
||||
ONNXToSpatialCommon.cpp
|
||||
|
||||
@@ -27,7 +23,7 @@ add_onnx_mlir_library(OMONNXToSpatial
|
||||
OMPimCompilerOptions
|
||||
OMONNXOps
|
||||
SpatialOps
|
||||
OMPIMCommon
|
||||
OMPimCommon
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
|
||||
@@ -1,583 +1,273 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// NOTE:
|
||||
// This might be useful to re-implement this considering for loops.
|
||||
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
|
||||
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
/**
|
||||
* @brief A momentary representation of a core, to be used within the tiling of
|
||||
* a convolution operation.
|
||||
*/
|
||||
class Core {
|
||||
public:
|
||||
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
|
||||
: coreId(coreId), rewriter(rewriter) {}
|
||||
|
||||
/**
|
||||
* @brief Add a MVM operation to the core.
|
||||
*
|
||||
* @param inputTile The input tile to the MVM operation.
|
||||
* @param xbarIndex The index of the crossbar weight to use.
|
||||
* @param outputTileId The id of the output tile.
|
||||
* @param mvmOutType The result's shape.
|
||||
* @return Value The result of the MVM operation.
|
||||
*/
|
||||
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
|
||||
// Use the inputTile as the reference location for the MVM operation.
|
||||
Location loc = inputTile.getLoc();
|
||||
|
||||
// Move the insertion point to the end of the block.
|
||||
rewriter.setInsertionPointToEnd(block.get());
|
||||
|
||||
// Add the inputTile to the block arguments, and to the operands.
|
||||
Value operand = operandMap.lookupOrNull(inputTile);
|
||||
if (not operand) {
|
||||
operand = block->addArgument(inputTile.getType(), loc);
|
||||
operands.push_back(inputTile);
|
||||
operandMap.map(inputTile, operand);
|
||||
}
|
||||
|
||||
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
|
||||
// is correct.
|
||||
|
||||
// Construct the MVM operation
|
||||
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
|
||||
|
||||
// Since we are within the same core and no computation can happen in
|
||||
// paralllel, we can just apply a linear reduction in case we have multiple
|
||||
// MVM operations for the same outputTile.
|
||||
auto lastMVM = outputTileToMVM.find(outputTileId);
|
||||
|
||||
// If an entry for this outputTile already exists, apply reduction.
|
||||
if (lastMVM != outputTileToMVM.end()) {
|
||||
// MVM results should have the same type for reduction.
|
||||
assert(lastMVM->second.getType() == result.getType());
|
||||
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
|
||||
}
|
||||
|
||||
outputTileToMVM[outputTileId] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Mark a result as remappable, and return a shared pointer to it.
|
||||
*
|
||||
* This function marks a result as remappable, and returns a shared pointer to
|
||||
* it. We need to keep track of these values to generate the YieldOp at a
|
||||
* later stage.
|
||||
*
|
||||
* @param result A result to track, for later remapping.
|
||||
* @return shared_ptr<Value> A shared pointer to the result.
|
||||
*/
|
||||
shared_ptr<Value> makeResultRemappable(Value result) {
|
||||
// Verify that the result is present in the block.
|
||||
assert(result.getDefiningOp()->getBlock() == block.get());
|
||||
|
||||
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
|
||||
|
||||
resultsToRemap.push_back(remappableResult);
|
||||
results.push_back(result);
|
||||
|
||||
return remappableResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Add a remappable operand to the core, to merge partial results
|
||||
* inter-core.
|
||||
*
|
||||
* @param remappableOperand The operand to add.
|
||||
* @return Value The block argument representing the operand.
|
||||
*/
|
||||
Value addRemappableOperand(std::shared_ptr<Value> operand) {
|
||||
// Check that the operand is not already there.
|
||||
assert(not operandMap.contains(*operand));
|
||||
|
||||
Value argument = block->addArgument(operand->getType(), operand->getLoc());
|
||||
remappableOperands.push_back(operand);
|
||||
return argument;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
|
||||
*
|
||||
* @param loc The location of the operation.
|
||||
* @return spatial::SpatWeightedCompute
|
||||
*/
|
||||
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
|
||||
// Get the shape of the results.
|
||||
SmallVector<Type> resultTypes;
|
||||
for (const auto& value : results)
|
||||
resultTypes.push_back(value.getType());
|
||||
|
||||
// Create the WComputeOp, with non-remappable operands only.
|
||||
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
|
||||
|
||||
// Add the body to the WComputeOp.
|
||||
Block* releasedBlock = block.release();
|
||||
wcomputeOp.getBody().push_back(releasedBlock);
|
||||
|
||||
// Add the `yieldOp` at the end, with the results.
|
||||
rewriter.setInsertionPointToEnd(releasedBlock);
|
||||
rewriter.create<spatial::SpatYieldOp>(loc, results);
|
||||
|
||||
return wcomputeOp;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Remap the results to the WComputeOp results.
|
||||
*/
|
||||
void remapResults() {
|
||||
// Remap all the results to the WComputeOp results.
|
||||
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
|
||||
for (size_t i = 0; i < resultsToRemap.size(); i++)
|
||||
*resultsToRemap[i] = wcomputeOp.getResult(i);
|
||||
}
|
||||
|
||||
void addRemappedOperands() {
|
||||
// Insert the remappableOperands (which were remapped in
|
||||
// `addRemappableOperand` of another Core)
|
||||
for (auto remappedValue : remappableOperands)
|
||||
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
|
||||
|
||||
// Update the wcomputeOp operandSegmentSize
|
||||
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
|
||||
}
|
||||
|
||||
size_t addXbarWeight(Value weight) {
|
||||
assert(!isXbarsFull());
|
||||
xbarWeights.push_back(weight);
|
||||
return xbarWeights.size() - 1;
|
||||
}
|
||||
|
||||
bool isXbarsFull() {
|
||||
assert(xbarWeights.size() <= crossbarCountInCore);
|
||||
return xbarWeights.size() == crossbarCountInCore;
|
||||
}
|
||||
|
||||
bool isCoreEmpty() { return block->empty(); }
|
||||
|
||||
void dump() {
|
||||
// Print the coreId
|
||||
llvm::outs() << "Core " << coreId << ":\n";
|
||||
// Print the weights
|
||||
llvm::outs() << "Xbar Weights:\n";
|
||||
for (auto weight : xbarWeights)
|
||||
weight.dump();
|
||||
// Print the operands
|
||||
llvm::outs() << "Operands:\n";
|
||||
for (auto operand : operands)
|
||||
llvm::outs() << operand << "\n";
|
||||
|
||||
// Dump the body block
|
||||
for (auto& op : block->getOperations())
|
||||
op.dump();
|
||||
|
||||
// Print the results
|
||||
llvm::outs() << "Results:\n";
|
||||
for (auto result : results)
|
||||
llvm::outs() << result << "\n";
|
||||
}
|
||||
|
||||
const size_t coreId;
|
||||
|
||||
private:
|
||||
ConversionPatternRewriter& rewriter;
|
||||
|
||||
// Should these be set<Value> instead? But I need to keep the order
|
||||
vector<Value> operands;
|
||||
vector<std::shared_ptr<Value>> remappableOperands;
|
||||
|
||||
vector<Value> results;
|
||||
vector<std::shared_ptr<Value>> resultsToRemap;
|
||||
|
||||
// Maps from input tiles to the block operand
|
||||
IRMapping operandMap;
|
||||
|
||||
// Map from outputTileId to MVM operation producing it
|
||||
unordered_map<size_t, Value> outputTileToMVM;
|
||||
|
||||
vector<Value> xbarWeights;
|
||||
|
||||
unique_ptr<mlir::Block> block = make_unique<Block>();
|
||||
|
||||
spatial::SpatWeightedCompute wcomputeOp;
|
||||
LogicalResult matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
ONNXConvOpTile(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
} // namespace
|
||||
|
||||
struct Producer_t {
|
||||
Value value;
|
||||
shared_ptr<Core> core;
|
||||
};
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
|
||||
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
|
||||
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
|
||||
ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
|
||||
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
|
||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||
|
||||
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
|
||||
// We need to understand what is group
|
||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||
|
||||
// TODO: Pad value at beginning and end of each dimension could be
|
||||
// different. We should handle this case.
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// MapOperations mapOperation = MapOperations::None;
|
||||
//
|
||||
// // If we have just one user, and it is an activation funcion (or more in
|
||||
// // general a mapping operation) just inline it in the computeOps
|
||||
// auto firstUserOp = *conv->getUsers().begin();
|
||||
// if (conv->hasOneUse()) {
|
||||
// mapOperation = mlirOpToMapOperationEnum(firstUserOp);
|
||||
//
|
||||
// if (mapOperation == MapOperations::ONNXSoftmaxOp) {
|
||||
// return rewriter.notifyMatchFailure(
|
||||
// conv, "Softmax not supported as activation for convolutions.");
|
||||
// }
|
||||
// }
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
|
||||
|
||||
size_t input_h = GET_IMAGE_HEIGHT(xShape);
|
||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
||||
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
|
||||
size_t krn_w = GET_KERNEL_WIDTH(wShape);
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
Location loc = conv.getLoc();
|
||||
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
|
||||
|
||||
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
|
||||
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
|
||||
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
// Tile the input tensor
|
||||
// Input tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: inputTiles[channelTile][x][y]
|
||||
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt = resolveImgInputTiles(
|
||||
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
|
||||
|
||||
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
|
||||
// Tile the weight tensor
|
||||
// Weight tiles need to be indexed by:
|
||||
// a. Filter Tile
|
||||
// b. Channel Tile
|
||||
// c. Kernel `x` position
|
||||
// d. Kernel `y` position
|
||||
// For example: weightTiles[filterTile][channelTile][x][y]
|
||||
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
|
||||
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
|
||||
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
||||
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
||||
sizes = {rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
for (size_t i = 0; i < outputTileCount; i++) {
|
||||
if (i == outputTileCount - 1 && outputTileRemainder != 0)
|
||||
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
|
||||
sizes[1] = rewriter.getIndexAttr(crossbarSize);
|
||||
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
|
||||
for (size_t j = 0; j < inputTileCount; j++) {
|
||||
if (j == inputTileCount - 1 && inputTileRemainder != 0)
|
||||
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
|
||||
for (size_t x = 0; x < krn_w; x++) {
|
||||
for (size_t y = 0; y < krn_h; y++) {
|
||||
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
|
||||
offsets[2] = rewriter.getIndexAttr(x);
|
||||
offsets[3] = rewriter.getIndexAttr(y);
|
||||
weightTiles[i][j][x][y] =
|
||||
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Distribute the computation among many compute cores
|
||||
* Try to compute in-core the computation for each output tile, and reduce
|
||||
* over as few cores as possible
|
||||
*/
|
||||
|
||||
// Tile the output tensor
|
||||
// Output tiles need to be indexed by:
|
||||
// a. Filter Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: outputTiles[filterTile][x][y]
|
||||
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
|
||||
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
|
||||
|
||||
size_t replicationFactor;
|
||||
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
|
||||
replicationFactor = 1;
|
||||
else
|
||||
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
|
||||
// producers[outTile][out_x][out_y][producerIndex]
|
||||
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
|
||||
outputTileCount,
|
||||
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
|
||||
|
||||
// Schedule in cores
|
||||
size_t coreId = 0;
|
||||
vector<shared_ptr<Core>> curCores(replicationFactor);
|
||||
for (size_t i = 0; i < replicationFactor; i++)
|
||||
curCores[i] = make_shared<Core>(coreId++, rewriter);
|
||||
|
||||
vector<shared_ptr<Core>> cores;
|
||||
|
||||
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
|
||||
|
||||
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
|
||||
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
|
||||
|
||||
RankedTensorType mvmOutType =
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
|
||||
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
||||
|
||||
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
|
||||
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
|
||||
|
||||
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
|
||||
|
||||
vector<size_t> xbarIndexes(replicationFactor);
|
||||
for (size_t i = 0; i < replicationFactor; i++)
|
||||
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
|
||||
|
||||
size_t out_x = 0;
|
||||
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
|
||||
size_t out_y = 0;
|
||||
|
||||
// I use `replicationFactor` cores. I divide the input_w into
|
||||
// `replicationFactor` slices, and each slice is distributed to a
|
||||
// core. `coreIndex` is the index of the core that will be used
|
||||
// for this slice
|
||||
size_t coreIndex = in_x / replicationSliceSize;
|
||||
assert(coreIndex < replicationFactor);
|
||||
|
||||
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
|
||||
// Adjust the input based on the kernel
|
||||
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
|
||||
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
|
||||
|
||||
// Check if we are within the input image
|
||||
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
|
||||
out_y++;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
|
||||
auto mvm = curCores[coreIndex]->addMVM(
|
||||
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
|
||||
|
||||
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
|
||||
|
||||
out_y++;
|
||||
}
|
||||
out_x++;
|
||||
}
|
||||
|
||||
// Computations for these crossbars are done, check if the cores
|
||||
// crossbars are fully used. If full, swap with new core
|
||||
for (size_t i = 0; i < replicationFactor; i++) {
|
||||
if (curCores[i]->isXbarsFull()) {
|
||||
cores.emplace_back(std::move(curCores[i]));
|
||||
curCores[i] = make_shared<Core>(coreId++, rewriter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& curCore : curCores)
|
||||
if (curCore->isCoreEmpty() == false)
|
||||
cores.emplace_back(std::move(curCore));
|
||||
curCores.clear();
|
||||
// Now, do the reduction of each output pixel tile
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
||||
for (size_t out_x = 0; out_x < output_w; out_x++) {
|
||||
for (size_t out_y = 0; out_y < output_h; out_y++) {
|
||||
// First, check if some producers are within the same core. If this is
|
||||
// true, `Core::addMVM` have already done the reduction within-core.
|
||||
// This means that we only need to consider the last producer for that
|
||||
// core.
|
||||
|
||||
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
|
||||
for (auto producer : producers[outTile][out_x][out_y])
|
||||
withinCoreReducedProducers[producer.core->coreId] = producer;
|
||||
|
||||
// Now, we need to apply inter-core reduction
|
||||
|
||||
// Base case with one producer
|
||||
if (withinCoreReducedProducers.size() == 1) {
|
||||
// TODO: Add the bias and apply mapping (if present)
|
||||
|
||||
auto singleProducer = withinCoreReducedProducers.begin()->second;
|
||||
// Use last producer as the final result
|
||||
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
|
||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: This is a linear reduction, not a tree reduction. We can do
|
||||
// better: a tree reduction would make more computations happen in
|
||||
// parallel.
|
||||
|
||||
Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
|
||||
|
||||
auto it = withinCoreReducedProducers.begin();
|
||||
it++;
|
||||
while (it != withinCoreReducedProducers.end()) {
|
||||
|
||||
Producer_t curProducer = it->second;
|
||||
|
||||
shared_ptr<Core> core1;
|
||||
shared_ptr<Core> core2;
|
||||
Value core1Value;
|
||||
Value core2Value;
|
||||
|
||||
auto lastProducerCoreId = lastProducer.core->coreId;
|
||||
auto curProducerCoreId = curProducer.core->coreId;
|
||||
|
||||
assert(lastProducerCoreId != curProducerCoreId
|
||||
&& "We should have already applied within-core reduction, how "
|
||||
"could we have same cores here?");
|
||||
|
||||
// Sort the cores by coreId
|
||||
if (curProducerCoreId < lastProducerCoreId) {
|
||||
core1 = curProducer.core;
|
||||
core1Value = curProducer.value;
|
||||
core2 = lastProducer.core;
|
||||
core2Value = lastProducer.value;
|
||||
}
|
||||
else {
|
||||
core1 = lastProducer.core;
|
||||
core1Value = lastProducer.value;
|
||||
core2 = curProducer.core;
|
||||
core2Value = curProducer.value;
|
||||
}
|
||||
|
||||
auto newCoreRes = core1->makeResultRemappable(core1Value);
|
||||
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
|
||||
|
||||
rewriter.setInsertionPointAfterValue(core2Value);
|
||||
Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
|
||||
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
|
||||
|
||||
lastProducer = {vaddRes, core2};
|
||||
|
||||
it++;
|
||||
}
|
||||
|
||||
// TODO: Add the bias and apply mapping (if present)
|
||||
|
||||
// Use last producer as the final result
|
||||
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
|
||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
|
||||
rewriter.setInsertionPointAfter(conv);
|
||||
spatial::SpatWeightedCompute lastWComputeOp;
|
||||
for (auto& core : cores) {
|
||||
lastWComputeOp = core->createWComputeOp(loc);
|
||||
core->remapResults();
|
||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
||||
}
|
||||
|
||||
for (auto& core : cores)
|
||||
core->addRemappedOperands();
|
||||
|
||||
// Set the insertion point after the last WComputeOp.
|
||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
||||
SmallVector<Value> tilesToConcat;
|
||||
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
|
||||
for (size_t outX = 0; outX < output_h; outX++)
|
||||
for (size_t outY = 0; outY < output_w; outY++)
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
||||
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
|
||||
|
||||
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
|
||||
|
||||
// Value outputImage =
|
||||
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
|
||||
|
||||
// If no mapping (activation) was applied, just replace ConvOp
|
||||
// if (mapOperation == MapOperations::None) {
|
||||
// rewriter.replaceOp(conv, outputImage);
|
||||
// } else {
|
||||
// // If mapping was applied, erase ConvOp and replace the mapping op
|
||||
// rewriter.eraseOp(conv);
|
||||
// rewriter.replaceOp(firstUserOp, outputImage);
|
||||
// }
|
||||
|
||||
return success();
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64(*padsAttr, 0);
|
||||
padWidthBegin = getI64(*padsAttr, 1);
|
||||
padHeightEnd = getI64(*padsAttr, 2);
|
||||
padWidthEnd = getI64(*padsAttr, 3);
|
||||
}
|
||||
};
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ONNXConvOpTile>(ctx);
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
// Gemm output: [numPatches, cOut]
|
||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||
|
||||
auto elemType = xType.getElementType();
|
||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
w,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
|
||||
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
Value gemmC;
|
||||
if (hasB) {
|
||||
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
|
||||
gemmC = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
biasType,
|
||||
b,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
}
|
||||
else
|
||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
auto im2colComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});
|
||||
|
||||
auto* im2colBlock = new Block();
|
||||
im2colBlock->addArgument(x.getType(), loc);
|
||||
im2colComputeOp.getBody().push_back(im2colBlock);
|
||||
rewriter.setInsertionPointToStart(im2colBlock);
|
||||
|
||||
Value paddedInput = im2colBlock->getArgument(0);
|
||||
|
||||
// Pad input with zeros if needed:
|
||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightBegin),
|
||||
rewriter.getIndexAttr(padWidthBegin)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightEnd),
|
||||
rewriter.getIndexAttr(padWidthEnd)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 4; i++)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
|
||||
// Build im2col [numPatches, patchSize]:
|
||||
// For each batch/output position (n, oh, ow), extract the patch from x
|
||||
SmallVector<Value> im2colRows;
|
||||
im2colRows.reserve(numPatches);
|
||||
for (int64_t n = 0; n < batchSize; n++) {
|
||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(oh * strideHeight),
|
||||
rewriter.getIndexAttr(ow * strideWidth)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
|
||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
im2colRows.push_back(row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Concatenate all rows: [numPatches, patchSize]
|
||||
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colComputeOp);
|
||||
|
||||
// Gemm: A @ B + C = im2col @ W^T + b
|
||||
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutType,
|
||||
im2colComputeOp.getResult(0),
|
||||
wTrans,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
Value gemmOut = gemmOp.getY();
|
||||
|
||||
auto collectComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});
|
||||
|
||||
auto* collectBlock = new Block();
|
||||
collectBlock->addArgument(gemmOut.getType(), loc);
|
||||
collectComputeOp.getBody().push_back(collectBlock);
|
||||
rewriter.setInsertionPointToStart(collectBlock);
|
||||
|
||||
auto gemmOutArg = collectBlock->getArguments().front();
|
||||
|
||||
// Restore to NCHW layout:
|
||||
// [numPatches, numChannelsOut]
|
||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||
// -> [1, numChannelsOut, outHeight, outWidth]
|
||||
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
nhwcType,
|
||||
gemmOutArg,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1, 2},
|
||||
{3}
|
||||
});
|
||||
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||
|
||||
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,400 +0,0 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "Compiler/PimCompilerOptions.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/**
|
||||
* @brief A pattern to tile the convolution operation into a series of compute
|
||||
* units, each one of which applies filters to a subset of the input
|
||||
* tensor. Results are also reduced and concatenated to form the final
|
||||
* output tensor.
|
||||
*/
|
||||
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
ExperimentalONNXConvOpTile(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
|
||||
// --------------------------------- //
|
||||
// --- READ OPERATION PARAMETERS --- //
|
||||
// --------------------------------- //
|
||||
|
||||
// To get each crossbar's weights, we need to slice the weights tensor.
|
||||
// - Along the input tiles.
|
||||
// - Along the output tiles.
|
||||
// - Along the filter x position.
|
||||
// - Along the filter y position.
|
||||
ShapedType inputType = cast<ShapedType>(convAdaptor.getX().getType());
|
||||
ShapedType outputType = cast<ShapedType>(conv.getY().getType());
|
||||
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
|
||||
|
||||
// TODO: Address bigger batches.
|
||||
assert(GET_IMAGE_N(inputType) == 1
|
||||
&& "Batch size must be 1"
|
||||
"for convolution.");
|
||||
|
||||
// TODO: Address replication.
|
||||
assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
|
||||
|
||||
// TODO: Address bias addition.
|
||||
|
||||
ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize);
|
||||
ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize);
|
||||
size_t kernelWidth = GET_KERNEL_WIDTH(weightsType);
|
||||
size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType);
|
||||
|
||||
// Assert that the kernel is square.
|
||||
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
|
||||
|
||||
// -------------------------------- //
|
||||
// --- SLICE THE WEIGHTS TENSOR --- //
|
||||
// -------------------------------- //
|
||||
|
||||
// The core idea of this stage is classifying the weights by input and
|
||||
// output tile. This is because we want the applyFilters operations to be
|
||||
// tile agnostic, to keep the subsequent lowering stages as simple as
|
||||
// possible. This data structure does this weight classification:
|
||||
// - The outer map is indexed by input tile.
|
||||
// - The inner map is indexed by output tile.
|
||||
// - The SmallVector contains the weights for the filter.
|
||||
map<long, map<long, SmallVector<Value>>> weightsGroups;
|
||||
|
||||
// During all slicing operations within this stage, we'll use the same
|
||||
// strides for all dimensions.
|
||||
SmallVector<OpFoldResult> slicingStrides(4, rewriter.getIndexAttr(1));
|
||||
|
||||
ldiv_t itc = inputTileCount;
|
||||
ldiv_t otc = outputTileCount;
|
||||
|
||||
// - Slicing along the input tiles.
|
||||
// - Slicing along the output tiles.
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
|
||||
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
|
||||
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
|
||||
|
||||
// The loop above also sets the crossbar's used width and height,
|
||||
// checking if we're at the last crossbar and if it's incomplete.
|
||||
|
||||
long outputTile = ot;
|
||||
long inputTile = it;
|
||||
|
||||
// Create the slicing sizes.
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
|
||||
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
|
||||
/* 2 */ rewriter.getIndexAttr(1),
|
||||
/* 3 */ rewriter.getIndexAttr(1)};
|
||||
|
||||
// - Slicing along the filter x position.
|
||||
// - Slicing along the filter y position.
|
||||
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
|
||||
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
|
||||
|
||||
// Create the slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
|
||||
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
|
||||
/* 2 */ rewriter.getIndexAttr(filterX),
|
||||
/* 3 */ rewriter.getIndexAttr(filterY)};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
// Add a note to the extractSliceOp, with the filterX and filterY.
|
||||
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Tree reduction for compute reduction should be implemented.
|
||||
|
||||
// -------------------------------- //
|
||||
// --- CREATE ALL COMPUTE UNITS --- //
|
||||
// -------------------------------- //
|
||||
|
||||
// Keep track of input slicing operations to avoid duplication across
|
||||
// all compute units (global slices).
|
||||
map<long, Value> globalSlices;
|
||||
|
||||
// Keep track of all partial compute results.
|
||||
map<long, Value> globalPartialResults;
|
||||
|
||||
// Use a weight subdivider to extract groups of weights for each compute
|
||||
// unit. We'll keep extracting groups until no more weights are left.
|
||||
WeightSubdivider weightSubdivider(weightsGroups);
|
||||
while (!weightSubdivider.isEmpty()) {
|
||||
|
||||
// -------------------------------- //
|
||||
// --- BEGIN A NEW COMPUTE UNIT --- //
|
||||
// -------------------------------- //
|
||||
|
||||
// Get the next group of weights for the compute unit.
|
||||
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
|
||||
|
||||
SmallVector<Value> computeWeights;
|
||||
SmallVector<Value> computeOperands;
|
||||
|
||||
// ------------------------------ //
|
||||
// --- SLICE THE INPUT TENSOR --- //
|
||||
// ------------------------------ //
|
||||
|
||||
// Note each tile's index in the compute unit arguments.
|
||||
map<long, size_t> inputTileIndices;
|
||||
map<long, size_t> outputTileIndices;
|
||||
map<long, size_t> reductionTileIndices; // Incoming partial results.
|
||||
|
||||
// Iterate over all weights groups for this compute unit.
|
||||
map<long, Value> localSlices; // WRT the current compute unit.
|
||||
for (auto group : weightsGroups) {
|
||||
for (Value weight : group.weights)
|
||||
computeWeights.push_back(weight);
|
||||
|
||||
// There might be multiple weight groups for the same input tile, so if
|
||||
// we've already added the input tile, skip it.
|
||||
if (localSlices.find(group.inputTile) != localSlices.end())
|
||||
continue;
|
||||
|
||||
// We might have already sliced the input tensor for some other compute
|
||||
// unit, so if we have, reuse the slicing operation without creating a
|
||||
// new one.
|
||||
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
|
||||
computeOperands.push_back(globalSlices[group.inputTile]);
|
||||
localSlices[group.inputTile] = globalSlices[group.inputTile];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create the input tensor slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
|
||||
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
|
||||
/* 2 */ rewriter.getIndexAttr(0),
|
||||
/* 3 */ rewriter.getIndexAttr(0)};
|
||||
|
||||
// Create the input tensor slicing sizes.
|
||||
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
|
||||
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
computeOperands.push_back(extractSliceOp);
|
||||
|
||||
// Update slicing maps.
|
||||
globalSlices[group.inputTile] = extractSliceOp;
|
||||
localSlices[group.inputTile] = extractSliceOp;
|
||||
|
||||
// Update the input tile index.
|
||||
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
|
||||
}
|
||||
|
||||
// ------------------------------- //
|
||||
// --- PREPARE THE OUTPUT TYPE --- //
|
||||
// ------------------------------- //
|
||||
|
||||
// Fill the compute output's type by looking at the output tiles.
|
||||
SmallVector<Type> computeOutputType;
|
||||
for (TaggedWeights group : weightsGroups) {
|
||||
|
||||
// There might be multiple weight groups for the same output tile, so if
|
||||
// we've already added the output tile, skip it.
|
||||
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
|
||||
continue;
|
||||
|
||||
// Additionally, after adding the input slices as operands, also add any
|
||||
// compatible partial results from previous compute units.
|
||||
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
|
||||
computeOperands.push_back(globalPartialResults[group.outputTile]);
|
||||
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
|
||||
}
|
||||
|
||||
// Define the output shape for this group.
|
||||
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
|
||||
|
||||
// TODO: Address non-same padding.
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ outputTileSize,
|
||||
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
|
||||
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
|
||||
|
||||
auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
|
||||
|
||||
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
||||
|
||||
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
|
||||
}
|
||||
|
||||
// ----------------------------- //
|
||||
// --- FILL THE COMPUTE UNIT --- //
|
||||
// ----------------------------- //
|
||||
|
||||
// Create the compute unit.
|
||||
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
conv.getLoc(), computeOutputType, computeWeights, computeOperands);
|
||||
|
||||
// Create a new block for the compute unit and add the operands.
|
||||
Block* block = rewriter.createBlock(¤tCompute.getRegion());
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
for (Value operand : computeOperands)
|
||||
block->addArgument(operand.getType(), conv->getLoc());
|
||||
|
||||
// Initialize a map of local partial results.
|
||||
map<long, Value> localPartialResults; // WRT the current compute unit.
|
||||
|
||||
// If we have any reduction tiles, add them to the local partial results.
|
||||
for (auto reductionTileIndex : reductionTileIndices)
|
||||
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
|
||||
|
||||
// Add all the applyFilters operations to the block.
|
||||
for (TaggedWeights group : weightsGroups) {
|
||||
|
||||
// Get the outputType for this group.
|
||||
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
|
||||
|
||||
// Create an apply filters operation.
|
||||
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
|
||||
|
||||
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
|
||||
// ... As many weights as the size of group.weights.
|
||||
SmallVector<long> weightIndices;
|
||||
for (size_t i = 0; i < group.weights.size(); ++i)
|
||||
weightIndices.push_back(group.startingCrossbarIndex + i);
|
||||
|
||||
SmallVector<int64_t> xKerPos;
|
||||
SmallVector<int64_t> yKerPos;
|
||||
for (auto weight : group.weights) {
|
||||
// Assert that the weight is an extract_slice operation.
|
||||
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
|
||||
assert(extractSliceOp && "Weight is not an extract_slice operation.");
|
||||
|
||||
// Get the filter x and y positions from the extract_slice operation.
|
||||
auto offsets = extractSliceOp.getStaticOffsets();
|
||||
xKerPos.push_back(offsets[2]);
|
||||
yKerPos.push_back(offsets[3]);
|
||||
}
|
||||
|
||||
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
|
||||
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
|
||||
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
|
||||
|
||||
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
|
||||
conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
|
||||
|
||||
// Perform local reduction if necessary.
|
||||
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
|
||||
|
||||
result = rewriter.create<spatial::SpatVAddOp>(
|
||||
conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
|
||||
}
|
||||
|
||||
// Update the partial results map.
|
||||
localPartialResults[group.outputTile] = result;
|
||||
}
|
||||
|
||||
// Add a yield operation to the block by concatenating the partial
|
||||
// results.
|
||||
SmallVector<Value> applyFiltersResults;
|
||||
for (size_t i = 0; i < computeOutputType.size(); ++i) {
|
||||
long outputTile;
|
||||
|
||||
// Given an output tile index, find the corresponding output tile.
|
||||
for (auto outputTileIndex : outputTileIndices) {
|
||||
if (outputTileIndex.second == i) {
|
||||
outputTile = outputTileIndex.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Get that tile's partial result and add it to the list.
|
||||
applyFiltersResults.push_back(localPartialResults[outputTile]);
|
||||
}
|
||||
|
||||
// Create the yield operation with the given results.
|
||||
rewriter.create<spatial::SpatYieldOp>(conv.getLoc(), applyFiltersResults);
|
||||
|
||||
// Update the global partial results map.
|
||||
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
|
||||
long outputTile;
|
||||
|
||||
// Given an output tile index, find the corresponding output tile.
|
||||
for (auto outputTileIndex : outputTileIndices) {
|
||||
if (outputTileIndex.second == i) {
|
||||
outputTile = outputTileIndex.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
globalPartialResults[outputTile] = currentCompute.getResult(i);
|
||||
}
|
||||
|
||||
// Move the rewrite cursor out of the block.
|
||||
rewriter.setInsertionPointAfter(currentCompute);
|
||||
}
|
||||
|
||||
// ------------------------------ //
|
||||
// --- CONCATENATE THE OUTPUT --- //
|
||||
// ------------------------------ //
|
||||
|
||||
// Turn the values into a SmallVector.
|
||||
SmallVector<Value> outputValues;
|
||||
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
|
||||
outputValues.push_back(globalPartialResults[i]);
|
||||
|
||||
// Assert that the number of output values is correct.
|
||||
assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
|
||||
|
||||
// If the conv's user is a ReLU...
|
||||
if (conv->hasOneUse()) {
|
||||
Operation* user = *conv->getUsers().begin();
|
||||
if (auto relu = dyn_cast<ONNXReluOp>(user)) {
|
||||
// ...then we can just replace the ReLU with the concatenation.
|
||||
rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
|
||||
|
||||
// And erase the convolution.
|
||||
rewriter.eraseOp(conv);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// Return the final output.
|
||||
rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Populate the tiling pattern for a convolution operation.
|
||||
*
|
||||
* @param patterns The pattern set to populate.
|
||||
* @param ctx The MLIR context.
|
||||
*/
|
||||
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ExperimentalONNXConvOpTile>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,365 +0,0 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "Compiler/PimCompilerOptions.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
|
||||
ExperimentalGemmConversionPattern(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
|
||||
// --------------------------------- //
|
||||
// --- READ OPERATION PARAMETERS --- //
|
||||
// --------------------------------- //
|
||||
|
||||
// To get each crossbar's weights, we need to slice the weights tensor.
|
||||
// - Along the input tiles.
|
||||
// - Along the output tiles.
|
||||
// - Along the filter x position.
|
||||
// - Along the filter y position.
|
||||
ShapedType inputType = cast<ShapedType>(adaptor.getA().getType());
|
||||
ShapedType outputType = cast<ShapedType>(gemmOp.getY().getType());
|
||||
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
|
||||
|
||||
// TODO: Address bigger batches.
|
||||
assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
|
||||
|
||||
// TODO: Address replication.
|
||||
assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
|
||||
|
||||
// TODO: Address bias addition.
|
||||
|
||||
assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
|
||||
|
||||
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
|
||||
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
|
||||
size_t kernelWidth = 1;
|
||||
size_t kernelHeight = 1;
|
||||
|
||||
// Assert that the kernel is square.
|
||||
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
|
||||
|
||||
// -------------------------------- //
|
||||
// --- SLICE THE WEIGHTS TENSOR --- //
|
||||
// -------------------------------- //
|
||||
|
||||
// The core idea of this stage is classifying the weights by input and
|
||||
// output tile. This is because we want the applyFilters operations to be
|
||||
// tile agnostic, to keep the subsequent lowering stages as simple as
|
||||
// possible. This data structure does this weight classification:
|
||||
// - The outer map is indexed by input tile.
|
||||
// - The inner map is indexed by output tile.
|
||||
// - The SmallVector contains the weights for the filter.
|
||||
map<long, map<long, SmallVector<Value>>> weightsGroups;
|
||||
|
||||
// During all slicing operations within this stage, we'll use the same
|
||||
// strides for all dimensions.
|
||||
SmallVector<OpFoldResult> slicingStrides(2, rewriter.getIndexAttr(1));
|
||||
|
||||
ldiv_t itc = inputTileCount;
|
||||
ldiv_t otc = outputTileCount;
|
||||
|
||||
// - Slicing along the input tiles.
|
||||
// - Slicing along the output tiles.
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
|
||||
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
|
||||
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
|
||||
|
||||
// The loop above also sets the crossbar's used width and height,
|
||||
// checking if we're at the last crossbar and if it's incomplete.
|
||||
|
||||
long outputTile = ot;
|
||||
long inputTile = it;
|
||||
|
||||
// Create the slicing sizes.
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
|
||||
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
|
||||
/* 2 */ /* rewriter.getIndexAttr(1), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(1) */};
|
||||
|
||||
// - Slicing along the filter x position.
|
||||
// - Slicing along the filter y position.
|
||||
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
|
||||
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
|
||||
|
||||
// Create the slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
|
||||
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(filterX), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
// Add a note to the extractSliceOp, with the filterX and filterY.
|
||||
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Tree reduction for compute reduction should be implemented.
|
||||
|
||||
// -------------------------------- //
|
||||
// --- CREATE ALL COMPUTE UNITS --- //
|
||||
// -------------------------------- //
|
||||
|
||||
// Keep track of input slicing operations to avoid duplication across
|
||||
// all compute units (global slices).
|
||||
map<long, Value> globalSlices;
|
||||
|
||||
// Keep track of all partial compute results.
|
||||
map<long, Value> globalPartialResults;
|
||||
|
||||
// Use a weight subdivider to extract groups of weights for each compute
|
||||
// unit. We'll keep extracting groups until no more weights are left.
|
||||
WeightSubdivider weightSubdivider(weightsGroups);
|
||||
while (!weightSubdivider.isEmpty()) {
|
||||
|
||||
// -------------------------------- //
|
||||
// --- BEGIN A NEW COMPUTE UNIT --- //
|
||||
// -------------------------------- //
|
||||
|
||||
// Get the next group of weights for the compute unit.
|
||||
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
|
||||
|
||||
SmallVector<Value> computeWeights;
|
||||
SmallVector<Value> computeOperands;
|
||||
|
||||
// ------------------------------ //
|
||||
// --- SLICE THE INPUT TENSOR --- //
|
||||
// ------------------------------ //
|
||||
|
||||
// Note each tile's index in the compute unit arguments.
|
||||
map<long, size_t> inputTileIndices;
|
||||
map<long, size_t> outputTileIndices;
|
||||
map<long, size_t> reductionTileIndices; // Incoming partial results.
|
||||
|
||||
// Iterate over all weights groups for this compute unit.
|
||||
map<long, Value> localSlices; // WRT the current compute unit.
|
||||
for (auto group : weightsGroups) {
|
||||
for (Value weight : group.weights)
|
||||
computeWeights.push_back(weight);
|
||||
|
||||
// There might be multiple weight groups for the same input tile, so if
|
||||
// we've already added the input tile, skip it.
|
||||
if (localSlices.find(group.inputTile) != localSlices.end())
|
||||
continue;
|
||||
|
||||
// We might have already sliced the input tensor for some other compute
|
||||
// unit, so if we have, reuse the slicing operation without creating a
|
||||
// new one.
|
||||
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
|
||||
computeOperands.push_back(globalSlices[group.inputTile]);
|
||||
localSlices[group.inputTile] = globalSlices[group.inputTile];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create the input tensor slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
|
||||
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(0), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(0) */};
|
||||
|
||||
// Create the input tensor slicing sizes.
|
||||
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
computeOperands.push_back(extractSliceOp);
|
||||
|
||||
// Update slicing maps.
|
||||
globalSlices[group.inputTile] = extractSliceOp;
|
||||
localSlices[group.inputTile] = extractSliceOp;
|
||||
|
||||
// Update the input tile index.
|
||||
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
|
||||
}
|
||||
|
||||
// ------------------------------- //
|
||||
// --- PREPARE THE OUTPUT TYPE --- //
|
||||
// ------------------------------- //
|
||||
|
||||
// Fill the compute output's type by looking at the output tiles.
|
||||
SmallVector<Type> computeOutputType;
|
||||
for (TaggedWeights group : weightsGroups) {
|
||||
|
||||
// There might be multiple weight groups for the same output tile, so if
|
||||
// we've already added the output tile, skip it.
|
||||
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
|
||||
continue;
|
||||
|
||||
// Additionally, after adding the input slices as operands, also add any
|
||||
// compatible partial results from previous compute units.
|
||||
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
|
||||
computeOperands.push_back(globalPartialResults[group.outputTile]);
|
||||
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
|
||||
}
|
||||
|
||||
// Define the output shape for this group.
|
||||
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
|
||||
|
||||
// TODO: Address non-same padding.
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ outputTileSize,
|
||||
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
|
||||
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
|
||||
|
||||
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
|
||||
|
||||
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
||||
|
||||
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
|
||||
}
|
||||
|
||||
// ----------------------------- //
|
||||
// --- FILL THE COMPUTE UNIT --- //
|
||||
// ----------------------------- //
|
||||
|
||||
// Create the compute unit.
|
||||
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
|
||||
|
||||
// Create a new block for the compute unit and add the operands.
|
||||
Block* block = rewriter.createBlock(¤tCompute.getRegion());
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
for (Value operand : computeOperands)
|
||||
block->addArgument(operand.getType(), gemmOp->getLoc());
|
||||
|
||||
// Initialize a map of local partial results.
|
||||
map<long, Value> localPartialResults; // WRT the current compute unit.
|
||||
|
||||
// If we have any reduction tiles, add them to the local partial results.
|
||||
for (auto reductionTileIndex : reductionTileIndices)
|
||||
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
|
||||
|
||||
// Add all the applyFilters operations to the block.
|
||||
for (TaggedWeights group : weightsGroups) {
|
||||
|
||||
// Get the outputType for this group.
|
||||
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
|
||||
|
||||
// Create an apply filters operation.
|
||||
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
|
||||
|
||||
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
|
||||
// ... As many weights as the size of group.weights.
|
||||
SmallVector<long> weightIndices;
|
||||
for (size_t i = 0; i < group.weights.size(); ++i)
|
||||
weightIndices.push_back(group.startingCrossbarIndex + i);
|
||||
|
||||
SmallVector<int64_t> xKerPos;
|
||||
SmallVector<int64_t> yKerPos;
|
||||
for (auto weight : group.weights) {
|
||||
// Assert that the weight is an extract_slice operation.
|
||||
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
|
||||
assert(extractSliceOp && "Weight is not an extract_slice operation.");
|
||||
|
||||
// Get the filter x and y positions from the extract_slice operation.
|
||||
xKerPos.push_back(0);
|
||||
yKerPos.push_back(0);
|
||||
}
|
||||
|
||||
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
|
||||
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
|
||||
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
|
||||
|
||||
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
|
||||
gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
|
||||
|
||||
// Perform local reduction if necessary.
|
||||
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
|
||||
|
||||
result = rewriter.create<spatial::SpatVAddOp>(
|
||||
gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
|
||||
}
|
||||
|
||||
// Update the partial results map.
|
||||
localPartialResults[group.outputTile] = result;
|
||||
}
|
||||
|
||||
// Add a yield operation to the block by concatenating the partial
|
||||
// results.
|
||||
SmallVector<Value> applyFiltersResults;
|
||||
for (size_t i = 0; i < computeOutputType.size(); ++i) {
|
||||
long outputTile;
|
||||
|
||||
// Given an output tile index, find the corresponding output tile.
|
||||
for (auto outputTileIndex : outputTileIndices) {
|
||||
if (outputTileIndex.second == i) {
|
||||
outputTile = outputTileIndex.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Get that tile's partial result and add it to the list.
|
||||
applyFiltersResults.push_back(localPartialResults[outputTile]);
|
||||
}
|
||||
|
||||
// Create the yield operation with the given results.
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmOp.getLoc(), applyFiltersResults);
|
||||
|
||||
// Update the global partial results map.
|
||||
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
|
||||
long outputTile;
|
||||
|
||||
// Given an output tile index, find the corresponding output tile.
|
||||
for (auto outputTileIndex : outputTileIndices) {
|
||||
if (outputTileIndex.second == i) {
|
||||
outputTile = outputTileIndex.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
globalPartialResults[outputTile] = currentCompute.getResult(i);
|
||||
}
|
||||
|
||||
// Move the rewrite cursor out of the block.
|
||||
rewriter.setInsertionPointAfter(currentCompute);
|
||||
}
|
||||
|
||||
// ------------------------------ //
|
||||
// --- CONCATENATE THE OUTPUT --- //
|
||||
// ------------------------------ //
|
||||
|
||||
// Turn the values into a SmallVector.
|
||||
SmallVector<Value> outputValues;
|
||||
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
|
||||
outputValues.push_back(globalPartialResults[i]);
|
||||
|
||||
// Assert that the number of output values is correct.
|
||||
assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
|
||||
|
||||
// Return the final output.
|
||||
rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ExperimentalGemmConversionPattern>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -2,18 +2,15 @@
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -22,396 +19,403 @@
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
GemmToManyGemv(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 2) {}
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Location loc = gemmOp.getLoc();
|
||||
Value a = adaptor.getA();
|
||||
Value b = adaptor.getB();
|
||||
Value c = adaptor.getC();
|
||||
|
||||
assert("A should have been transposed already" && !adaptor.getTransA());
|
||||
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
||||
|
||||
const int64_t numOutRows = aType.getDimSize(0);
|
||||
|
||||
// Only decompose when there are multiple rows to split
|
||||
if (numOutRows <= 1)
|
||||
return failure();
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool cHasNumOutRows = false;
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||
}
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
else
|
||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
||||
}
|
||||
|
||||
auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
b,
|
||||
cSlice,
|
||||
gemmOp.getAlphaAttr(),
|
||||
gemmOp.getBetaAttr(),
|
||||
gemmOp.getTransAAttr(),
|
||||
gemmOp.getTransBAttr());
|
||||
gemvOps.push_back(gemvOp.getY());
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (auto gemvOp : gemvOps)
|
||||
concatBlock->addArgument(gemvOp.getType(), loc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
|
||||
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
||||
GemvToSpatialCompute(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 1) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Location gemmLoc = gemmOp.getLoc();
|
||||
Value a = adaptor.getA();
|
||||
Value b = adaptor.getB();
|
||||
Value c = adaptor.getC();
|
||||
Value out = gemmOp.getY();
|
||||
|
||||
float alpha = adaptor.getAlpha().convertToFloat();
|
||||
float beta = adaptor.getBeta().convertToFloat();
|
||||
bool transA = adaptor.getTransA();
|
||||
bool transB = adaptor.getTransB();
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto bType = cast<RankedTensorType>(b.getType());
|
||||
auto outType = cast<RankedTensorType>(out.getType());
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
}
|
||||
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||
|
||||
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape()))
|
||||
// Not a gemv
|
||||
return failure();
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
|
||||
if (alpha != 1.0f) {
|
||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
||||
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue);
|
||||
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor);
|
||||
}
|
||||
if (hasC && beta != 1.0f) {
|
||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
||||
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue);
|
||||
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor);
|
||||
}
|
||||
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
auto outLastHSliceSize = cLastHSliceSize;
|
||||
|
||||
const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
|
||||
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
SmallVector<Value> cHSlices;
|
||||
if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
||||
c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
|
||||
if (hasC)
|
||||
cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
RankedTensorType outHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
|
||||
RankedTensorType outLastHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> outHSlices;
|
||||
outHSlices.reserve(outNumHSlices);
|
||||
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
|
||||
RankedTensorType currOutHSliceType = outHSliceType;
|
||||
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
|
||||
currOutHSliceType = outLastHSliceType;
|
||||
|
||||
SmallVector<Value> partialResults;
|
||||
partialResults.reserve(coresPerVSlice);
|
||||
for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
|
||||
SmallVector<Value> weights;
|
||||
weights.reserve(aHSlices[coreId].size());
|
||||
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
|
||||
|
||||
auto* computeBlock = new Block();
|
||||
for (auto aHSlice : aHSlices[coreId])
|
||||
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
|
||||
computeOp.getBody().push_back(computeBlock);
|
||||
rewriter.setInsertionPointToStart(computeBlock);
|
||||
|
||||
auto computeArgs = computeBlock->getArguments();
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(computeArgs.size());
|
||||
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
|
||||
vmmOutputs.push_back(
|
||||
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
|
||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum);
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp.getResult(0));
|
||||
}
|
||||
|
||||
if (hasC) {
|
||||
Value cHSlice = cHSlices[outSliceId];
|
||||
partialResults.push_back(cHSlice);
|
||||
}
|
||||
|
||||
auto reduceComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
|
||||
|
||||
auto* reduceBlock = new Block();
|
||||
for (auto partialResult : partialResults)
|
||||
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
|
||||
reduceComputeOp.getBody().push_back(reduceBlock);
|
||||
rewriter.setInsertionPointToStart(reduceBlock);
|
||||
|
||||
auto blockArgs = reduceBlock->getArguments();
|
||||
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice);
|
||||
rewriter.setInsertionPointAfter(reduceComputeOp);
|
||||
|
||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (auto outHSlice : outHSlices)
|
||||
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* Resolves the ONNXExpOp from the use chain of the given start value.
|
||||
*
|
||||
* This function traverses the use chain of the start value until it finds an
|
||||
* ONNXExpOp. It returns the value of the ONNXExpOp.
|
||||
*
|
||||
* @param startValue The starting value of the use chain.
|
||||
* @return The value of the ONNXExpOp found in the use chain.
|
||||
*/
|
||||
static Value resolveONNXExpOpFromUseChain(Value startValue) {
|
||||
Value walker = startValue;
|
||||
static Value resolveONNXExpOpFromUseChain(Value startValue);
|
||||
|
||||
while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
|
||||
walker = walker.getDefiningOp()->getOperand(0);
|
||||
|
||||
assert(walker && walker.getDefiningOp()
|
||||
&& "Unwinded the whole chain of operations while trying to "
|
||||
"find ONNXExpOp, but did not find it");
|
||||
}
|
||||
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(walker.getDefiningOp())
|
||||
&& "Old output tile (softmax reducer) is not produced by an "
|
||||
"ONNXExpOp");
|
||||
|
||||
return walker;
|
||||
}
|
||||
|
||||
// Softmax is a special case, as it requires another reduction after the
|
||||
// first one. In the cores, `applyReducePattern` already applied
|
||||
// f(x) = exp(x) to each tile. This mean that now we just need to
|
||||
// reduce-sum these tiles, and then divide each tile by the reduced sum,
|
||||
// which is propagated back to the cores via a broadcast channel.
|
||||
LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
|
||||
Value& softmaxChannel,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SpatialReducer& reducer,
|
||||
ONNXGemmOp& gemmOp,
|
||||
Location& loc) const {
|
||||
|
||||
// TODO: Check case with one compute op
|
||||
|
||||
// Cast vector of Value into vector of ComputeOp
|
||||
SmallVector<ComputeAndResNum> softmaxOpsToReduce =
|
||||
llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) {
|
||||
return std::make_pair(cast<spatial::SpatWeightedCompute>(computeAndResNum.first), computeAndResNum.second);
|
||||
}));
|
||||
|
||||
RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr);
|
||||
const TensorType scalarTensorType = tensorTypeBuilder;
|
||||
|
||||
reducer.applyReducePattern(
|
||||
softmaxOpsToReduce,
|
||||
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); },
|
||||
/* preprocess = */
|
||||
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); },
|
||||
[&](Value softmaxDivisor) {
|
||||
// Signal that this is the compute with the softmax divisor
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
|
||||
computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());
|
||||
|
||||
// Broadcast the divisor to all the cores
|
||||
rewriter.setInsertionPointAfterValue(softmaxDivisor);
|
||||
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor);
|
||||
|
||||
/*
|
||||
* softmaxDividend = onnx.exp (...)
|
||||
* sum = spat.SumOp(softmaxDividend)
|
||||
* [following can be repeated N times, thus walk the use chain]
|
||||
* softmaxDivisor = spat.sadd(sum, ...)
|
||||
*/
|
||||
Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));
|
||||
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
|
||||
&& "Dividend of softmax reduction is not an ONNXExpOp");
|
||||
|
||||
// Do not divide here, divide after this
|
||||
return softmaxDivisor;
|
||||
});
|
||||
|
||||
// In all the cores, insert a ChannelRecvOp and divide the output tile by
|
||||
// the reduced denominator.
|
||||
outputOpsAndResNums.clear();
|
||||
outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
|
||||
for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {
|
||||
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeToDivideOpAndResNum.first.getBody().front().getTerminator());
|
||||
|
||||
Value divisor;
|
||||
|
||||
// Check if this compute contains the softmax divisor: if so, find the
|
||||
// ChannelBroadcastSendOp, otherwise receive the value from the channel
|
||||
// using ChannelBroadcastReceiveOp
|
||||
if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {
|
||||
|
||||
bool found = false;
|
||||
for (auto broadcastOp :
|
||||
computeToDivideOpAndResNum.first.getBody().front().getOps<spatial::SpatChannelBroadcastSendOp>()) {
|
||||
assert(found == false
|
||||
&& "More than one ChannelBroadcastSendOp in "
|
||||
"compute? How is this possible?");
|
||||
found = true;
|
||||
|
||||
divisor = broadcastOp.getData();
|
||||
}
|
||||
|
||||
assert(found
|
||||
&& "No ChannelBroadcastSendOp in compute where softmax "
|
||||
"divisor was specified to be?");
|
||||
}
|
||||
else {
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel);
|
||||
}
|
||||
|
||||
// Walk the chain of operations until we find the ONNXExpOp: this is
|
||||
// needed because some some may have a different amount of `VAddOp`s due
|
||||
// to the tree reduction (e.g. some may have no VAddOp, some may have
|
||||
// multiples)
|
||||
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
|
||||
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor);
|
||||
auto yieldOperandNum = yieldOp->getNumOperands();
|
||||
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
|
||||
|
||||
outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
static LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
|
||||
Value& softmaxChannel,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SpatialReducer& reducer,
|
||||
ONNXGemmOp& gemmOp,
|
||||
Location& loc);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = gemmOp.getLoc();
|
||||
Value a = gemmOpAdaptor.getA();
|
||||
Value b = gemmOpAdaptor.getB();
|
||||
Value c = gemmOpAdaptor.getC();
|
||||
|
||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
||||
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
||||
|
||||
const int64_t numOutRows = aType.getDimSize(0);
|
||||
|
||||
// Only decompose when there are multiple rows to split
|
||||
if (numOutRows <= 1)
|
||||
return failure();
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool cHasNumOutRows = false;
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||
}
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
else
|
||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
||||
}
|
||||
|
||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
b,
|
||||
cSlice,
|
||||
gemmOp.getAlphaAttr(),
|
||||
gemmOp.getBetaAttr(),
|
||||
gemmOp.getTransAAttr(),
|
||||
gemmOp.getTransBAttr());
|
||||
gemvOps.push_back(gemvOp.getY());
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (auto gemvOp : gemvOps)
|
||||
concatBlock->addArgument(gemvOp.getType(), loc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location gemmLoc = gemmOp.getLoc();
|
||||
Value a = gemmOpAdaptor.getA();
|
||||
Value b = gemmOpAdaptor.getB();
|
||||
Value c = gemmOpAdaptor.getC();
|
||||
Value out = gemmOp.getY();
|
||||
|
||||
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
||||
float beta = gemmOpAdaptor.getBeta().convertToFloat();
|
||||
bool transA = gemmOpAdaptor.getTransA();
|
||||
bool transB = gemmOpAdaptor.getTransB();
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto bType = cast<RankedTensorType>(b.getType());
|
||||
auto outType = cast<RankedTensorType>(out.getType());
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
}
|
||||
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||
|
||||
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
||||
// Not a gemv
|
||||
return failure();
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
|
||||
if (alpha != 1.0f) {
|
||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
||||
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
|
||||
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
|
||||
}
|
||||
if (hasC && beta != 1.0f) {
|
||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
||||
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
|
||||
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
|
||||
}
|
||||
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
auto outLastHSliceSize = cLastHSliceSize;
|
||||
|
||||
const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
|
||||
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
SmallVector<Value> cHSlices;
|
||||
if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
||||
c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
|
||||
if (hasC)
|
||||
cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
RankedTensorType outHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
|
||||
RankedTensorType outLastHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> outHSlices;
|
||||
outHSlices.reserve(outNumHSlices);
|
||||
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
|
||||
RankedTensorType currOutHSliceType = outHSliceType;
|
||||
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
|
||||
currOutHSliceType = outLastHSliceType;
|
||||
|
||||
SmallVector<Value> partialResults;
|
||||
partialResults.reserve(coresPerVSlice);
|
||||
for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
|
||||
SmallVector<Value> weights;
|
||||
weights.reserve(aHSlices[coreId].size());
|
||||
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
|
||||
|
||||
auto* computeBlock = new Block();
|
||||
for (auto aHSlice : aHSlices[coreId])
|
||||
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
|
||||
computeOp.getBody().push_back(computeBlock);
|
||||
rewriter.setInsertionPointToStart(computeBlock);
|
||||
|
||||
auto computeArgs = computeBlock->getArguments();
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(computeArgs.size());
|
||||
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
|
||||
vmmOutputs.push_back(
|
||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
|
||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp.getResult(0));
|
||||
}
|
||||
|
||||
if (hasC) {
|
||||
Value cHSlice = cHSlices[outSliceId];
|
||||
partialResults.push_back(cHSlice);
|
||||
}
|
||||
|
||||
auto reduceComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
|
||||
|
||||
auto* reduceBlock = new Block();
|
||||
for (auto partialResult : partialResults)
|
||||
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
|
||||
reduceComputeOp.getBody().push_back(reduceBlock);
|
||||
rewriter.setInsertionPointToStart(reduceBlock);
|
||||
|
||||
auto blockArgs = reduceBlock->getArguments();
|
||||
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
|
||||
rewriter.setInsertionPointAfter(reduceComputeOp);
|
||||
|
||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (auto outHSlice : outHSlices)
|
||||
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value GemvToSpatialCompute::resolveONNXExpOpFromUseChain(Value startValue) {
|
||||
Value walker = startValue;
|
||||
|
||||
while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
|
||||
walker = walker.getDefiningOp()->getOperand(0);
|
||||
|
||||
assert(walker && walker.getDefiningOp()
|
||||
&& "Unwinded the whole chain of operations while trying to "
|
||||
"find ONNXExpOp, but did not find it");
|
||||
}
|
||||
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(walker.getDefiningOp())
|
||||
&& "Old output tile (softmax reducer) is not produced by an "
|
||||
"ONNXExpOp");
|
||||
|
||||
return walker;
|
||||
}
|
||||
|
||||
LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
|
||||
Value& softmaxChannel,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SpatialReducer& reducer,
|
||||
ONNXGemmOp& gemmOp,
|
||||
Location& loc) {
|
||||
// TODO: Check case with one compute op
|
||||
|
||||
// Cast vector of Value into vector of ComputeOp
|
||||
SmallVector<ComputeAndResNum> softmaxOpsToReduce =
|
||||
llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) {
|
||||
return std::make_pair(cast<spatial::SpatWeightedCompute>(computeAndResNum.first), computeAndResNum.second);
|
||||
}));
|
||||
|
||||
RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr);
|
||||
const TensorType scalarTensorType = tensorTypeBuilder;
|
||||
|
||||
reducer.applyReducePattern(
|
||||
softmaxOpsToReduce,
|
||||
[&](Value a, Value b) { return spatial::SpatVAddOp::create(rewriter, loc, scalarTensorType, a, b); },
|
||||
/* preprocess = */
|
||||
[&](Value a) { return spatial::SpatSumOp::create(rewriter, loc, scalarTensorType, a); },
|
||||
[&](Value softmaxDivisor) {
|
||||
// Signal that this is the compute with the softmax divisor
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
|
||||
computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());
|
||||
|
||||
// Broadcast the divisor to all the cores
|
||||
rewriter.setInsertionPointAfterValue(softmaxDivisor);
|
||||
spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, softmaxChannel, softmaxDivisor);
|
||||
|
||||
/*
|
||||
* softmaxDividend = onnx.exp (...)
|
||||
* sum = spat.SumOp(softmaxDividend)
|
||||
* [following can be repeated N times, thus walk the use chain]
|
||||
* softmaxDivisor = spat.sadd(sum, ...)
|
||||
*/
|
||||
Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));
|
||||
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
|
||||
&& "Dividend of softmax reduction is not an ONNXExpOp");
|
||||
|
||||
// Do not divide here, divide after this
|
||||
return softmaxDivisor;
|
||||
});
|
||||
|
||||
// In all the cores, insert a ChannelRecvOp and divide the output tile by
|
||||
// the reduced denominator.
|
||||
outputOpsAndResNums.clear();
|
||||
outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
|
||||
for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {
|
||||
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeToDivideOpAndResNum.first.getBody().front().getTerminator());
|
||||
|
||||
Value divisor;
|
||||
|
||||
// Check if this compute contains the softmax divisor: if so, find the
|
||||
// ChannelBroadcastSendOp, otherwise receive the value from the channel
|
||||
// using ChannelBroadcastReceiveOp
|
||||
if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {
|
||||
|
||||
bool found = false;
|
||||
for (auto broadcastOp :
|
||||
computeToDivideOpAndResNum.first.getBody().front().getOps<spatial::SpatChannelBroadcastSendOp>()) {
|
||||
assert(found == false
|
||||
&& "More than one ChannelBroadcastSendOp in "
|
||||
"compute? How is this possible?");
|
||||
found = true;
|
||||
|
||||
divisor = broadcastOp.getData();
|
||||
}
|
||||
|
||||
assert(found
|
||||
&& "No ChannelBroadcastSendOp in compute where softmax "
|
||||
"divisor was specified to be?");
|
||||
}
|
||||
else {
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
divisor = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, loc, scalarTensorType, softmaxChannel);
|
||||
}
|
||||
|
||||
// Walk the chain of operations until we find the ONNXExpOp: this is
|
||||
// needed because some some may have a different amount of `VAddOp`s due
|
||||
// to the tree reduction (e.g. some may have no VAddOp, some may have
|
||||
// multiples)
|
||||
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
|
||||
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
Value newOutputTile = spatial::SpatVSDivOp::create(rewriter, loc, oldOutputTile.getType(), oldOutputTile, divisor);
|
||||
auto yieldOperandNum = yieldOp->getNumOperands();
|
||||
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
|
||||
|
||||
outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<GemmToManyGemv>(ctx);
|
||||
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||
|
||||
@@ -1,300 +0,0 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <typename PoolOp>
|
||||
bool hasPostProcessExperimentalPoolingWindow() {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
|
||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
|
||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
|
||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
||||
// compatible with the current code generation, which assumes constant to be
|
||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename ReductionOp>
|
||||
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
|
||||
if (inputTiles.size() == 1)
|
||||
return inputTiles[0];
|
||||
|
||||
if (inputTiles.size() == 2) {
|
||||
return rewriter.create<spatial::SpatVMaxOp>(
|
||||
inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
|
||||
}
|
||||
|
||||
SmallVector<Value> left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
|
||||
SmallVector<Value> right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
|
||||
|
||||
Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
|
||||
Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
|
||||
|
||||
return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
ExperimentalPoolingBaseConverter(MLIRContext* ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Value X = adaptor.getX();
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
Value Y = poolOp.getResult();
|
||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
||||
|
||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
||||
|
||||
if (adaptor.getAutoPad() != "NOTSET")
|
||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
|
||||
size_t pad_x, pad_y;
|
||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
size_t input_h = GET_IMAGE_HEIGHT(xShape);
|
||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
||||
|
||||
ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize);
|
||||
|
||||
// Assert that the input is a tensor.ConcatOp.
|
||||
auto concat = X.getDefiningOp<tensor::ConcatOp>();
|
||||
if (!concat)
|
||||
return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
|
||||
|
||||
// Create a [channel_tile][x][y] array to store the input tiles.
|
||||
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
|
||||
|
||||
// For each argument of the tensor.ConcatOp, resolve the input tiles.
|
||||
for (size_t y = 0; y < input_h; ++y) {
|
||||
for (size_t x = 0; x < input_w; ++x) {
|
||||
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
|
||||
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
|
||||
|
||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
|
||||
/* 1 */ rewriter.getIndexAttr(0),
|
||||
/* 2 */ rewriter.getIndexAttr(x),
|
||||
/* 3 */ rewriter.getIndexAttr(y)};
|
||||
SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ rewriter.getIndexAttr(1),
|
||||
/* 3 */ rewriter.getIndexAttr(1)};
|
||||
|
||||
// Get the concat's operand that we want to slice.
|
||||
Value concatInput = concat.getOperand(it);
|
||||
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
|
||||
|
||||
inputTiles[it][x][y] = slicedTile;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare the shape of the compute's output.
|
||||
ldiv_t itc = tileCount;
|
||||
SmallVector<Type> outputTileTypes;
|
||||
for (size_t y = 0; y < output_h; ++y) {
|
||||
for (size_t x = 0; x < output_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */
|
||||
cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
|
||||
/* 2 */ 1,
|
||||
/* 3 */ 1};
|
||||
|
||||
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
|
||||
|
||||
outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a plain value list of the input tiles.
|
||||
SmallVector<Value> inputTilesList;
|
||||
for (size_t y = 0; y < input_h; ++y) {
|
||||
for (size_t x = 0; x < input_w; ++x)
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
|
||||
inputTilesList.push_back(inputTiles[it][y][x]);
|
||||
}
|
||||
|
||||
// Create a single compute to calculate the output.
|
||||
auto computeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
|
||||
|
||||
// Create a new block for the compute unit and add the operands.
|
||||
Block* block = rewriter.createBlock(&computeOp.getRegion());
|
||||
|
||||
// Fill the block arguments and keep a reference to them.
|
||||
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
|
||||
for (size_t y = 0; y < input_h; ++y) {
|
||||
for (size_t x = 0; x < input_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
|
||||
inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Begin writing in the block.
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
// Go through all pooling blocks.
|
||||
SmallVector<Value> outputTiles;
|
||||
for (size_t y = 0; y < output_h; ++y) {
|
||||
for (size_t x = 0; x < output_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
size_t start_x = x * stride_x;
|
||||
size_t start_y = y * stride_y;
|
||||
size_t end_x = std::min(start_x + krn_w, input_w);
|
||||
size_t end_y = std::min(start_y + krn_h, input_h);
|
||||
|
||||
SmallVector<Value> inputTilesToReduce;
|
||||
for (size_t ky = start_y; ky < end_y; ++ky)
|
||||
for (size_t kx = start_x; kx < end_x; ++kx)
|
||||
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
|
||||
|
||||
auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
|
||||
|
||||
// If the reduce op is add, we need to divide the result by the
|
||||
// number of elements in the pooling window.
|
||||
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
|
||||
// Add a spat.const before the computeOp.
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue =
|
||||
rewriter.create<spatial::SpatConstantOp>(loc,
|
||||
RankedTensorType::get({1}, rewriter.getF32Type()),
|
||||
rewriter.getI64IntegerAttr(krn_w * krn_h),
|
||||
rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
|
||||
reduceResult =
|
||||
rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
|
||||
}
|
||||
outputTiles.push_back(reduceResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a YieldOp to return the output tiles.
|
||||
rewriter.create<spatial::SpatYieldOp>(loc, outputTiles);
|
||||
|
||||
// Set the rewrite cursor right after the computeOp.
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> computeOutput;
|
||||
for (size_t y = 0; y < output_h; ++y) {
|
||||
for (size_t x = 0; x < output_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
|
||||
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We'll now create spat.img.concat ops to concatenate the output tiles.
|
||||
SmallVector<Value> outputTilesList;
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
SmallVector<Value> imgConcatTiles;
|
||||
for (size_t y = 0; y < output_h; ++y)
|
||||
for (size_t x = 0; x < output_w; ++x)
|
||||
imgConcatTiles.push_back(computeOutput[it][y][x]);
|
||||
|
||||
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
|
||||
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ (long) tilingSize,
|
||||
/* 2 */ (long) output_w,
|
||||
/* 3 */ (long) output_h};
|
||||
|
||||
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
|
||||
|
||||
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
|
||||
loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
|
||||
}
|
||||
|
||||
// Create a new tensor.ConcatOp to concatenate the output tiles.
|
||||
Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
|
||||
|
||||
rewriter.replaceOp(poolOp, outputTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<
|
||||
ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
|
||||
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
|
||||
ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -15,7 +15,7 @@
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
@@ -26,8 +26,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
|
||||
|
||||
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
@@ -113,15 +111,15 @@ Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
||||
|
||||
// 1. Add a channel before the first computeOp
|
||||
rewriter.setInsertionPoint(firstCompute);
|
||||
auto channel = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
|
||||
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||
|
||||
// 2. Add a sendOp after the first value
|
||||
rewriter.setInsertionPointAfterValue(firstValue);
|
||||
rewriter.create<spatial::SpatChannelSendOp>(loc, channel, firstValue);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
|
||||
|
||||
// 3. Add a receiveOp after the second value
|
||||
rewriter.setInsertionPointAfterValue(secondValue);
|
||||
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel);
|
||||
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
|
||||
|
||||
// 4. Apply reduction between second value and received value
|
||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
||||
@@ -190,13 +188,14 @@ Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rew
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
|
||||
loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
@@ -225,12 +224,12 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
size_t input_h = GET_IMAGE_HEIGHT(xShape);
|
||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
||||
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
|
||||
size_t input_h = getImageHeight(xShape);
|
||||
size_t input_w = getImageWidth(xShape);
|
||||
size_t output_h = getImageHeight(yShape);
|
||||
size_t output_w = getImageWidth(yShape);
|
||||
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
|
||||
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
|
||||
|
||||
// 1: Tile the input tensor
|
||||
// Input tiles need to be indexed by:
|
||||
@@ -259,17 +258,18 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Location tileLoc = extractSliceOp.getLoc();
|
||||
|
||||
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc,
|
||||
extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(),
|
||||
extractSliceOp.getResult());
|
||||
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
|
||||
tileLoc,
|
||||
extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(),
|
||||
extractSliceOp.getResult());
|
||||
|
||||
Block* tempComputeOpBlock = new Block();
|
||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
||||
|
||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
||||
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
|
||||
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
|
||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
||||
}
|
||||
@@ -358,7 +358,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
Value reducedWithinCompute = applyReducePatternNew(
|
||||
valuesToPool,
|
||||
rewriter,
|
||||
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); },
|
||||
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
|
||||
nullptr,
|
||||
postProcessFn);
|
||||
|
||||
@@ -371,16 +371,16 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
// Create a new channel before the computeOp
|
||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
||||
auto reduceChannel =
|
||||
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
|
||||
// Send value through the channel
|
||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
||||
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
|
||||
|
||||
// Receive after the computeOp
|
||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
||||
auto receivedValue =
|
||||
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
|
||||
outputTiles[outTile][outX][outY] = receivedValue;
|
||||
}
|
||||
|
||||
@@ -63,16 +63,17 @@ struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV1
|
||||
/*elementType=*/inputTensorType.getElementType());
|
||||
|
||||
// Create the ONNXAveragePoolOp.
|
||||
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
|
||||
resultType,
|
||||
inputTensor,
|
||||
/*auto_pad=*/"NOTSET",
|
||||
/*ceil_mode=*/0,
|
||||
/*count_include_pad=*/1,
|
||||
dilations,
|
||||
/*kernel_shape=*/kernelShape,
|
||||
/*pads=*/pads,
|
||||
/*strides=*/strides);
|
||||
auto averagePool = ONNXAveragePoolOp::create(rewriter,
|
||||
reduceMean.getLoc(),
|
||||
resultType,
|
||||
inputTensor,
|
||||
/*auto_pad=*/"NOTSET",
|
||||
/*ceil_mode=*/0,
|
||||
/*count_include_pad=*/1,
|
||||
dilations,
|
||||
/*kernel_shape=*/kernelShape,
|
||||
/*pads=*/pads,
|
||||
/*strides=*/strides);
|
||||
|
||||
// Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
|
||||
rewriter.replaceOp(reduceMean, averagePool.getResult());
|
||||
|
||||
@@ -13,9 +13,7 @@ def onnxToArithConstantOp : Pat<
|
||||
(Arith_ConstantOp $value)
|
||||
>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXMatMulOp to ONNXGemmOp patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def matMulAddToGemmPattern : Pat<
|
||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||
@@ -31,7 +29,7 @@ def matMulToGemmPattern : Pat<
|
||||
(ONNXMatMulOp:$matmulres $A, $B),
|
||||
(
|
||||
ONNXGemmOp $A, $B,
|
||||
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
@@ -39,9 +37,7 @@ def matMulToGemmPattern : Pat<
|
||||
)
|
||||
>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
|
||||
// ONNXConvOp with a bias.
|
||||
@@ -55,9 +51,7 @@ def convAddToConvWithBiasPatternRight : Pat<
|
||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||
>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation to ignore (i.e. remove)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
|
||||
|
||||
@@ -76,4 +70,4 @@ def removeFlattenSameShapePattern : Pat<
|
||||
[(HaveSameStaticShape $flattenOp, $A)]
|
||||
>; // Add closing parenthesis here
|
||||
|
||||
#endif // ONNX_TO_SPATIAL
|
||||
#endif // ONNX_TO_SPATIAL
|
||||
|
||||
@@ -47,7 +47,7 @@ SmallVector<Value> sliceTensor(
|
||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||
|
||||
Value slice = rewriter.create<tensor::ExtractSliceOp>(loc, tensorToSlice, offsets, sizes, strides);
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
slices.push_back(slice);
|
||||
}
|
||||
|
||||
@@ -100,11 +100,11 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = rewriter.create<tensor::ExtractOp>(loc, scalarToBroadcast, index).getResult();
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
|
||||
return rewriter.create<tensor::SplatOp>(loc, type, elementValue);
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
}
|
||||
|
||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||
@@ -122,7 +122,7 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||
Value a = (*currTensors)[i];
|
||||
Value b = (*currTensors)[i + 1];
|
||||
rewriter.setInsertionPointAfterValue(b);
|
||||
auto addedValue = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
|
||||
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||
nextTensors->push_back(addedValue);
|
||||
}
|
||||
if (currTensors->size() % 2 == 1)
|
||||
@@ -137,10 +137,10 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
|
||||
switch (mapOp) {
|
||||
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
|
||||
case MapOperations::ONNXSoftmaxOp: return rewriter.create<ONNXSoftmaxOp>(input.getLoc(), input.getType(), input);
|
||||
case MapOperations::ONNXReluOp: return rewriter.create<ONNXReluOp>(input.getLoc(), input.getType(), input);
|
||||
case MapOperations::ONNXLeakyReluOp: return rewriter.create<ONNXLeakyReluOp>(input.getLoc(), input.getType(), input);
|
||||
case MapOperations::ONNXExpOp: return rewriter.create<ONNXExpOp>(input.getLoc(), input.getType(), input);
|
||||
case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||
case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||
case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||
case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,10 +180,10 @@ void tileImageTensorByChannel(Value imageTensor,
|
||||
ConversionPatternRewriter& rewriter) {
|
||||
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
|
||||
|
||||
size_t input_h = GET_IMAGE_HEIGHT(imageShape);
|
||||
size_t input_w = GET_IMAGE_WIDTH(imageShape);
|
||||
size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize);
|
||||
size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize;
|
||||
size_t input_h = getImageHeight(imageShape);
|
||||
size_t input_w = getImageWidth(imageShape);
|
||||
size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
|
||||
size_t tileRest = getImageChannel(imageShape) % tileSize;
|
||||
|
||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
|
||||
@@ -201,7 +201,7 @@ void tileImageTensorByChannel(Value imageTensor,
|
||||
offsets[2] = rewriter.getIndexAttr(x);
|
||||
offsets[3] = rewriter.getIndexAttr(y);
|
||||
|
||||
tiles[i][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, imageTensor, offsets, sizes, strides);
|
||||
tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,7 +225,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
||||
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
|
||||
|
||||
return rewriter.create<spatial::SpatImgConcatOp>(loc, outputType, tilesToConcat);
|
||||
return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
@@ -271,7 +271,7 @@ Value createExtractSliceImg(Value valToSlice,
|
||||
offsets[2] = rewriter.getIndexAttr(x);
|
||||
offsets[3] = rewriter.getIndexAttr(y);
|
||||
|
||||
return rewriter.create<tensor::ExtractSliceOp>(valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
|
||||
return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
Value indexImgValue(Value v,
|
||||
@@ -384,7 +384,7 @@ void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
|
||||
offsets[2] = rewriter.getIndexAttr(x);
|
||||
offsets[3] = rewriter.getIndexAttr(y);
|
||||
|
||||
inputTiles[t][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, wholeInputTensor, offsets, sizes, strides);
|
||||
inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -452,7 +452,7 @@ LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
|
||||
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
|
||||
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
|
||||
Value shapeTensor =
|
||||
rewriter.create<arith::ConstantOp>(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
|
||||
arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
|
||||
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
|
||||
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
|
||||
|
||||
|
||||
@@ -9,24 +9,55 @@
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#define DEFINE_MAP_OP(opname) opname,
|
||||
|
||||
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
|
||||
#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3)
|
||||
#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1)
|
||||
#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0)
|
||||
#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2)
|
||||
#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3)
|
||||
#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0)
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
const StringRef REPLICATION_ATTR_NAME = "replication_factor";
|
||||
template <class ShapedType>
|
||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(1);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageN(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
inline constexpr mlir::StringRef REPLICATION_ATTR_NAME = "replication_factor";
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
@@ -58,51 +89,64 @@ constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(const ArrayRef<T> shape) {
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(const ArrayRef<T> shape) {
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(const ArrayRef<T> shape) {
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVVectorShape(const ArrayRef<T> shape) {
|
||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[1] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T getVectorLength(const ArrayRef<T> shape) {
|
||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||
assert(isVectorShape(shape));
|
||||
return shape[0] != 1 ? shape[0] : shape[1];
|
||||
}
|
||||
|
||||
inline auto getTensorShape(const Value tensor) { return cast<RankedTensorType>(tensor.getType()).getShape(); }
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
size_t axis,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
SmallVector<Value>
|
||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
|
||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>>
|
||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc);
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc);
|
||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||
tileMatrix(mlir::Value& matrixToTile,
|
||||
int64_t hSliceSize,
|
||||
int64_t vSliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc);
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter);
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input);
|
||||
mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input);
|
||||
|
||||
/**
|
||||
* Unpacks an optional pair vector into two size_t values.
|
||||
@@ -126,7 +170,8 @@ void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t
|
||||
*
|
||||
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
|
||||
*/
|
||||
std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
|
||||
std::optional<llvm::Twine>
|
||||
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
|
||||
|
||||
/**
|
||||
* Tiles the image tensor by channel.
|
||||
@@ -140,10 +185,10 @@ std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> val
|
||||
* @param tileSize The size of each tile.
|
||||
* @param rewriter The ConversionPatternRewriter used for creating operations.
|
||||
*/
|
||||
void tileImageTensorByChannel(Value imageTensor,
|
||||
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
|
||||
void tileImageTensorByChannel(mlir::Value imageTensor,
|
||||
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& tiles,
|
||||
size_t tileSize,
|
||||
ConversionPatternRewriter& rewriter);
|
||||
mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
/**
|
||||
* Creates an ImgConcatOp based on the given tiles.
|
||||
@@ -159,10 +204,10 @@ void tileImageTensorByChannel(Value imageTensor,
|
||||
*
|
||||
* @return The created ImgConcatOp.
|
||||
*/
|
||||
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location& loc,
|
||||
Type outputType);
|
||||
mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& outputTiles,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc,
|
||||
mlir::Type outputType);
|
||||
|
||||
/**
|
||||
* @brief Verifies if the given input coordinates and padding values are within
|
||||
@@ -177,7 +222,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
|
||||
* @return LogicalResult Returns success if the coordinates and padding are
|
||||
* within bounds, failure otherwise.
|
||||
*/
|
||||
LogicalResult
|
||||
mlir::LogicalResult
|
||||
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
|
||||
|
||||
/**
|
||||
@@ -207,13 +252,14 @@ verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY,
|
||||
* @return std::optional<llvm::Twine> An error message if the input tensor could
|
||||
* not be resolved into tiles.
|
||||
*/
|
||||
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
|
||||
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
|
||||
size_t channelTileCount,
|
||||
size_t channelTileRest,
|
||||
size_t input_w,
|
||||
size_t input_h,
|
||||
mlir::ConversionPatternRewriter& rewriter);
|
||||
std::optional<llvm::Twine>
|
||||
resolveImgInputTiles(mlir::Value wholeInputTensor,
|
||||
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& inputTiles,
|
||||
size_t channelTileCount,
|
||||
size_t channelTileRest,
|
||||
size_t input_w,
|
||||
size_t input_h,
|
||||
mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
/**
|
||||
* Computes the boundaries of an image kernel application.
|
||||
@@ -258,6 +304,6 @@ void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcom
|
||||
* @return The index of the result of the operation that produces the specified
|
||||
* value.
|
||||
*/
|
||||
int getResultIndex(Operation* op, Value v);
|
||||
int getResultIndex(mlir::Operation* op, mlir::Value v);
|
||||
|
||||
}; // namespace onnx_mlir
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -8,20 +9,41 @@
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||
#include "ONNXToSpatialPass.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace spatial {
|
||||
bool haveSameStaticShape(Value lhs, Value rhs);
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
||||
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
||||
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
|
||||
|
||||
ONNXToSpatialPass() = default;
|
||||
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -39,15 +61,19 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||
|
||||
IRRewriter rewriter(moduleOp);
|
||||
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
|
||||
if (annotateReplication(funcOp, rewriter).failed()) {
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (annotateReplication(*entryFunc, rewriter).failed()) {
|
||||
llvm::dbgs() << "Failed during annotation for replication analysis\n";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<ONNXDialect, SpatialDialect, tensor::TensorDialect, arith::ArithDialect, tosa::TosaDialect>();
|
||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
target.addIllegalOp<ONNXConvOp>();
|
||||
@@ -61,16 +87,9 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRNPattern>(ctx);
|
||||
|
||||
if (useExperimentalConvImpl) {
|
||||
populateExperimentalTilingConvOpPattern(patterns, ctx);
|
||||
populateExperimentalPoolingTilingPattern(patterns, ctx);
|
||||
populateGemmToConvConversionPattern(patterns, ctx);
|
||||
}
|
||||
else {
|
||||
populateTilingConvOpPattern(patterns, ctx);
|
||||
populatePoolingTilingPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
}
|
||||
populateConvOpPatterns(patterns, ctx);
|
||||
populatePoolingTilingPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
|
||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||
populateReduceMeanConversionPattern(patterns, ctx);
|
||||
@@ -83,8 +102,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
// Count the number of compute ops and check they do not exceed the core count
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (auto& op : funcOp.getFunctionBody().front().getOperations())
|
||||
if (isa<SpatWeightedCompute>(op))
|
||||
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatWeightedCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
@@ -101,22 +120,21 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
|
||||
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
|
||||
|
||||
annotateWeightsConstants(funcOp);
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial");
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
MLIRContext* ctx = funcOp.getContext();
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
bool isAlwaysWeight =
|
||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<SpatWeightedCompute>(user); });
|
||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
|
||||
if (isAlwaysWeight)
|
||||
constantOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using namespace mlir;
|
||||
extern bool haveSameStaticShape(Value lhs, Value rhs);
|
||||
|
||||
namespace spatial {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
||||
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
||||
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
|
||||
|
||||
ONNXToSpatialPass() = default;
|
||||
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<spatial::ONNXToSpatialPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,28 +1,20 @@
|
||||
#pragma once
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
// Experimental patterns.
|
||||
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -10,7 +10,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <typename OpTy, typename OpAdaptorTy>
|
||||
struct RemoveUnusedHelperOps : public OpRewritePattern<OpTy> {
|
||||
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
|
||||
RemoveUnusedHelperOps(MLIRContext* ctx)
|
||||
: OpRewritePattern<OpTy>(ctx) {}
|
||||
|
||||
|
||||
@@ -49,11 +49,11 @@ LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& r
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
ShapedType wShape = mlir::cast<ShapedType>(W.getType());
|
||||
|
||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
||||
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
|
||||
size_t krn_w = GET_KERNEL_WIDTH(wShape);
|
||||
size_t input_w = getImageWidth(xShape);
|
||||
size_t krn_h = getKernelHeight(wShape);
|
||||
size_t krn_w = getKernelWidth(wShape);
|
||||
|
||||
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t inputTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
|
||||
size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());
|
||||
|
||||
auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
|
||||
|
||||
@@ -15,21 +15,21 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
|
||||
llvm::SmallPtrSet<mlir::Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
|
||||
|
||||
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
|
||||
std::function<Value(const Value&)> processFun,
|
||||
ConversionPatternRewriter& rewriter) {
|
||||
std::function<mlir::Value(const mlir::Value&)> processFun,
|
||||
mlir::ConversionPatternRewriter& rewriter) {
|
||||
assert(processFun);
|
||||
|
||||
auto computeOp = GET_COMP(computeOpAndResNum);
|
||||
auto resultNum = GET_RES_NUM(computeOpAndResNum);
|
||||
|
||||
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
spatial::SpatYieldOp yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
|
||||
Value result = yieldOp->getOperand(resultNum);
|
||||
mlir::Value result = yieldOp->getOperand(resultNum);
|
||||
rewriter.setInsertionPointAfterValue(result);
|
||||
Value processedResult = processFun(result);
|
||||
mlir::Value processedResult = processFun(result);
|
||||
if (processedResult == result) {
|
||||
// Sometimes we want processedResult to return the same value but do
|
||||
// something else with it (e.g. in softmax we want to broadcast the value
|
||||
@@ -42,10 +42,11 @@ ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum
|
||||
return yieldOp.getNumOperands() - 1;
|
||||
}
|
||||
|
||||
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess) {
|
||||
OpAndResNum
|
||||
SpatialReducer::applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
||||
std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
|
||||
std::function<mlir::Value(const mlir::Value&)> preprocess,
|
||||
std::function<mlir::Value(const mlir::Value&)> postprocess) {
|
||||
|
||||
if (preprocess)
|
||||
for (auto& computeOpAndResNum : computeOpsAndResNum)
|
||||
@@ -55,18 +56,18 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
|
||||
// computeOp. In this case, we need to apply the reduction within-computef
|
||||
|
||||
// Keep a map between a computeOp and the last Value for this reduction
|
||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
||||
std::unordered_map<mlir::Operation*, mlir::Value> lastValueForCompute;
|
||||
for (auto& computeOpAndResNum : computeOpsAndResNum) {
|
||||
auto computeOp = GET_COMP(computeOpAndResNum);
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
|
||||
auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
mlir::Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
|
||||
|
||||
auto it = lastValueForCompute.find(computeOp.getOperation());
|
||||
|
||||
if (it != lastValueForCompute.end()) {
|
||||
// If we have already seen this computeOp, apply the reduction
|
||||
// within-compute
|
||||
Value lastWithinComputeValue = it->second;
|
||||
mlir::Value lastWithinComputeValue = it->second;
|
||||
|
||||
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
|
||||
|
||||
@@ -85,12 +86,12 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
|
||||
computeOpsAndResNum.clear();
|
||||
computeOpsAndResNum.reserve(lastValueForCompute.size());
|
||||
for (auto& entry : lastValueForCompute) {
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
|
||||
auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(entry.first);
|
||||
auto valueWithinCompute = entry.second;
|
||||
|
||||
// We check if `valueWithinCompute` is already used by the yieldOp, in that
|
||||
// case no need to add it
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
bool yieldOpUseFound = false;
|
||||
for (auto& use : valueWithinCompute.getUses()) {
|
||||
if (use.getOwner() == yieldOp.getOperation()) {
|
||||
@@ -110,7 +111,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
|
||||
computeOpsAndResNum.push_back({computeOp, resultNum});
|
||||
}
|
||||
|
||||
Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
|
||||
mlir::Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
|
||||
|
||||
// Recursive algorithm to reduce the inputs to a single one:
|
||||
// - Take two inputs at a time, and reduce them into a single one, updating
|
||||
@@ -118,7 +119,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
|
||||
// - Repeat until there is only one input left.
|
||||
llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
|
||||
while (computeOpsRef.size() > 1) {
|
||||
SmallVector<ComputeAndResNum> nextComputeOps;
|
||||
llvm::SmallVector<ComputeAndResNum> nextComputeOps;
|
||||
nextComputeOps.reserve(computeOpsRef.size() / 2);
|
||||
for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
|
||||
auto [firstCompute, firstResultNum] = computeOpsRef[i];
|
||||
@@ -135,23 +136,23 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
|
||||
// the number of results)
|
||||
// See below `reducerChanges.push_back` and `finalizeReduceUpdates`
|
||||
|
||||
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
|
||||
auto yieldOpFirstCompute = mlir::cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
|
||||
|
||||
// Add a new operand to the block of the second computeOp
|
||||
Block& secondBlock = secondCompute.getBody().front();
|
||||
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
|
||||
mlir::Block& secondBlock = secondCompute.getBody().front();
|
||||
mlir::Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
|
||||
|
||||
auto secondComputeWeightsNum =
|
||||
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
|
||||
secondCompute->getAttrOfType<mlir::DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
|
||||
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
|
||||
|
||||
// Take the "former-result" from the second computeOp
|
||||
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
|
||||
Value formerRes2 = secondYield.getOperand(secondResultNum);
|
||||
spatial::SpatYieldOp secondYield = mlir::cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
|
||||
mlir::Value formerRes2 = secondYield.getOperand(secondResultNum);
|
||||
|
||||
// Apply reduction operation
|
||||
rewriter.setInsertionPoint(secondYield);
|
||||
Value reduced = reduce(formerRes2, formerRes1);
|
||||
mlir::Value reduced = reduce(formerRes2, formerRes1);
|
||||
|
||||
// Unfortunately, it is not possible to update the result in place,
|
||||
// because we may have already referenced it by <computeOp, resultNum>
|
||||
@@ -219,7 +220,7 @@ void SpatialReducer::finalizeReduceUpdates() {
|
||||
// `opToReplacedCompute`
|
||||
auto toComputeOp = opToReplacedCompute[toOp];
|
||||
if (!toComputeOp)
|
||||
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
|
||||
toComputeOp = mlir::cast<spatial::SpatWeightedCompute>(toOp);
|
||||
|
||||
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
|
||||
|
||||
@@ -234,31 +235,31 @@ void SpatialReducer::finalizeReduceUpdates() {
|
||||
}
|
||||
}
|
||||
|
||||
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
|
||||
mlir::Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
|
||||
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
|
||||
|
||||
Operation* opToCast;
|
||||
mlir::Operation* opToCast;
|
||||
auto it = opToReplacedCompute.find(opAndResNum.first);
|
||||
if (it != opToReplacedCompute.end())
|
||||
opToCast = it->second;
|
||||
else
|
||||
opToCast = opAndResNum.first;
|
||||
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
|
||||
auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(opToCast);
|
||||
|
||||
return computeOp.getResult(opAndResNum.second);
|
||||
}
|
||||
|
||||
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
|
||||
void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) {
|
||||
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
|
||||
// If we have already replaced the fromOp, we do not need to do it again
|
||||
return;
|
||||
}
|
||||
auto oldComputeOp = cast<spatial::SpatWeightedCompute>(computeOp);
|
||||
auto oldComputeOp = mlir::cast<spatial::SpatWeightedCompute>(computeOp);
|
||||
|
||||
auto oldComputeOpNum = oldComputeOp->getNumOperands();
|
||||
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
|
||||
auto yieldOp = mlir::cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
|
||||
|
||||
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
|
||||
// No result was added, just add itself to the map
|
||||
@@ -271,8 +272,8 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
|
||||
|
||||
// Create a new ComputeOp with the new result type, but same operands
|
||||
rewriter.setInsertionPoint(oldComputeOp);
|
||||
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
|
||||
auto newComputeOp = spatial::SpatWeightedCompute::create(
|
||||
rewriter, oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
|
||||
|
||||
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
|
||||
|
||||
@@ -283,8 +284,8 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
|
||||
// Since we replaced the old ComputeOp with a new one, we need to replace
|
||||
// all its results' uses
|
||||
for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
|
||||
Value oldResult = oldComputeOp.getResult(i);
|
||||
Value newResult = newComputeOp.getResult(i);
|
||||
mlir::Value oldResult = oldComputeOp.getResult(i);
|
||||
mlir::Value newResult = newComputeOp.getResult(i);
|
||||
|
||||
// Replace the uses, except the uses of the compute ops which got deleted
|
||||
// previously
|
||||
@@ -298,9 +299,10 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
|
||||
rewriter.eraseOp(oldComputeOp);
|
||||
}
|
||||
|
||||
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles,
|
||||
Location& loc,
|
||||
Type outputType) {
|
||||
mlir::Value
|
||||
SpatialReducer::createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
|
||||
mlir::Location& loc,
|
||||
mlir::Type outputType) {
|
||||
|
||||
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
|
||||
|
||||
@@ -309,8 +311,8 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
|
||||
auto width = outputTiles[0].size();
|
||||
auto height = outputTiles[0][0].size();
|
||||
|
||||
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(
|
||||
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
|
||||
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>> remappedOutputTiles(
|
||||
tilesCount, llvm::SmallVector<llvm::SmallVector<mlir::Value>>(width, llvm::SmallVector<mlir::Value>(height)));
|
||||
|
||||
for (size_t t = 0; t < tilesCount; t++)
|
||||
for (size_t x = 0; x < width; x++)
|
||||
@@ -320,25 +322,25 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
|
||||
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
|
||||
}
|
||||
|
||||
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Value biasTile,
|
||||
OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Value biasTile,
|
||||
MapOperations mapOp) {
|
||||
|
||||
std::function<Value(const Value&)> postprocessing = nullptr;
|
||||
std::function<mlir::Value(const mlir::Value&)> postprocessing = nullptr;
|
||||
|
||||
if (mapOp != MapOperations::None) {
|
||||
postprocessing = [&](const Value a) {
|
||||
Value mapOperand = a;
|
||||
postprocessing = [&](const mlir::Value a) {
|
||||
mlir::Value mapOperand = a;
|
||||
if (biasTile)
|
||||
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
|
||||
mapOperand = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, biasTile);
|
||||
return createMapOperation(rewriter, mapOp, mapOperand);
|
||||
};
|
||||
}
|
||||
|
||||
return this->applyReducePattern(
|
||||
computeOps,
|
||||
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
|
||||
[&](mlir::Value a, mlir::Value b) { return spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b); },
|
||||
/* preprocess = */ nullptr,
|
||||
postprocessing);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -13,28 +17,28 @@ using ResNum = unsigned int;
|
||||
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
|
||||
|
||||
struct SpatialReducerChange {
|
||||
Operation* fromOp;
|
||||
mlir::Operation* fromOp;
|
||||
unsigned int fromOpResNum;
|
||||
Operation* toOp;
|
||||
mlir::Operation* toOp;
|
||||
unsigned int toOpOperandNum;
|
||||
};
|
||||
|
||||
using OpAndResNum = std::pair<Operation*, ResNum>;
|
||||
using OpAndResNum = std::pair<mlir::Operation*, ResNum>;
|
||||
|
||||
class SpatialReducer {
|
||||
|
||||
public:
|
||||
SpatialReducer(ConversionPatternRewriter& rewriter)
|
||||
SpatialReducer(mlir::ConversionPatternRewriter& rewriter)
|
||||
: rewriter(rewriter) {}
|
||||
|
||||
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess);
|
||||
OpAndResNum applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
||||
std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
|
||||
std::function<mlir::Value(const mlir::Value&)> preprocess,
|
||||
std::function<mlir::Value(const mlir::Value&)> postprocess);
|
||||
|
||||
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Value biasTile,
|
||||
OpAndResNum applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Value biasTile,
|
||||
MapOperations mapOp);
|
||||
|
||||
void finalizeReduceUpdates();
|
||||
@@ -44,17 +48,17 @@ public:
|
||||
finalizeReduceUpdates();
|
||||
}
|
||||
|
||||
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
|
||||
Location& loc,
|
||||
Type outputType);
|
||||
mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
|
||||
mlir::Location& loc,
|
||||
mlir::Type outputType);
|
||||
|
||||
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
|
||||
mlir::Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
|
||||
|
||||
private:
|
||||
[[nodiscard("computeOp result number gets updated")]] ResNum
|
||||
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
|
||||
std::function<Value(const Value&)> processFun,
|
||||
ConversionPatternRewriter& rewriter);
|
||||
std::function<mlir::Value(const mlir::Value&)> processFun,
|
||||
mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
/**
|
||||
* @brief Update the results of a ComputeOp.
|
||||
@@ -66,19 +70,19 @@ private:
|
||||
*
|
||||
* @param computeOp The ComputeOp to update the results of.
|
||||
*/
|
||||
void updateResultsOfCompute(Operation* computeOp);
|
||||
void updateResultsOfCompute(mlir::Operation* computeOp);
|
||||
|
||||
ConversionPatternRewriter& rewriter;
|
||||
mlir::ConversionPatternRewriter& rewriter;
|
||||
bool reducesFinalized = false;
|
||||
|
||||
// List of changes to be applied after the reduction is finalized
|
||||
SmallVector<SpatialReducerChange, 4> reducerChanges;
|
||||
llvm::SmallVector<SpatialReducerChange, 4> reducerChanges;
|
||||
// List of computeOps that need to be replaced with new results
|
||||
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
|
||||
llvm::SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
|
||||
|
||||
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
|
||||
std::unordered_map<mlir::Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
|
||||
|
||||
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
|
||||
static llvm::SmallPtrSet<mlir::Operation*, 16> oldComputeOpsReplaced;
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
|
||||
WeightSubdivider::WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights)
|
||||
: weights(std::move(weights)) {}
|
||||
|
||||
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
|
||||
@@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
|
||||
assert(!weights.empty() && "No weights to extract.");
|
||||
|
||||
auto it = weights.begin();
|
||||
SmallVector<Value>& values = it->second.begin()->second;
|
||||
llvm::SmallVector<mlir::Value>& values = it->second.begin()->second;
|
||||
|
||||
long inputTile = it->first;
|
||||
long outputTile = it->second.begin()->first;
|
||||
@@ -21,7 +21,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
|
||||
size_t n = std::min(amount, values.size());
|
||||
crossbarsUsed += n;
|
||||
|
||||
SmallVector<Value> result;
|
||||
llvm::SmallVector<mlir::Value> result;
|
||||
result.assign(values.begin(), values.begin() + n);
|
||||
|
||||
if (n < values.size()) {
|
||||
@@ -36,9 +36,9 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
|
||||
return {inputTile, outputTile, crossbarsUsed - n, result};
|
||||
}
|
||||
|
||||
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
|
||||
llvm::SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
|
||||
crossbarsUsed = 0;
|
||||
SmallVector<TaggedWeights> result;
|
||||
llvm::SmallVector<TaggedWeights> result;
|
||||
size_t remaining = n;
|
||||
|
||||
while (remaining > 0 && !weights.empty()) {
|
||||
|
||||
@@ -4,11 +4,9 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/**
|
||||
@@ -19,7 +17,7 @@ struct TaggedWeights {
|
||||
long inputTile;
|
||||
long outputTile;
|
||||
size_t startingCrossbarIndex;
|
||||
SmallVector<Value> weights;
|
||||
llvm::SmallVector<mlir::Value> weights;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -33,16 +31,16 @@ struct TaggedWeights {
|
||||
*/
|
||||
class WeightSubdivider {
|
||||
private:
|
||||
map<long, map<long, SmallVector<Value>>> weights;
|
||||
std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights;
|
||||
size_t crossbarsUsed = 0;
|
||||
|
||||
TaggedWeights popGroup(size_t amount);
|
||||
|
||||
public:
|
||||
WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights);
|
||||
WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights);
|
||||
|
||||
bool isEmpty() const;
|
||||
SmallVector<TaggedWeights> popGroups(size_t n);
|
||||
llvm::SmallVector<TaggedWeights> popGroups(size_t n);
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -5,7 +5,7 @@ add_onnx_mlir_library(OMSpatialToGraphviz
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMCompilerOptions
|
||||
OMPIMCommon
|
||||
OMPimCommon
|
||||
OMONNXOps
|
||||
SpatialOps
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -199,12 +200,12 @@ private:
|
||||
void SpatialToGraphvizPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
// Get the first OP, must be a FuncOp
|
||||
func::FuncOp func = *module.getOps<func::FuncOp>().begin();
|
||||
if (!func) {
|
||||
module->emitError("No FuncOp found in the begin of module");
|
||||
auto entryFunc = getPimEntryFunc(module);
|
||||
if (failed(entryFunc)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
func::FuncOp func = *entryFunc;
|
||||
|
||||
os << "digraph G {\n"
|
||||
<< "\tnode [style=filled,color=white];\n";
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
|
||||
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(SpatialToPIMIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMSpatialToPIM
|
||||
SpatialToPIMPass.hpp
|
||||
SpatialToPIMPass.cpp
|
||||
SpatialToPIMCommon.cpp
|
||||
|
||||
DEPENDS
|
||||
SpatialToPIMIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMCompilerOptions
|
||||
OMPIMCommon
|
||||
SpatialOps
|
||||
PimOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
)
|
||||
@@ -1,108 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
*
|
||||
* The static offsets represent the starting position of the slice in each
|
||||
* dimension, while the static tensor input gives its dimension size.
|
||||
*
|
||||
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
|
||||
* calculated.
|
||||
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
|
||||
* \return The actual offset of the ExtractSliceOp.
|
||||
*/
|
||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the earliest operation that uses the given value within the value's
|
||||
* block.
|
||||
*
|
||||
* @param value The value for which to find the earliest user operation.
|
||||
* @return The earliest user operation that uses the given value within the
|
||||
* current block.
|
||||
*/
|
||||
Operation* getEarliestUserWithinBlock(Value value);
|
||||
|
||||
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation);
|
||||
|
||||
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation);
|
||||
|
||||
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
|
||||
const ArrayRef<int64_t> offsets,
|
||||
const ArrayRef<int64_t> sizes,
|
||||
const ArrayRef<int64_t> strides) {
|
||||
// Check that all strides are 1
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
// Check offsets from right to left:
|
||||
// The first offset_n at position n different from 0:
|
||||
// - limits all sizes to the left to 1
|
||||
// - limits size_n to dimension_n - offset_n
|
||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstNonZeroOffset = std::find_if(
|
||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||
return offset != 0;
|
||||
});
|
||||
|
||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||
if (size > dimension - offset)
|
||||
return false;
|
||||
++firstNonZeroOffset;
|
||||
|
||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check sizes from right to left:
|
||||
// The first size_n at position n different from shape_n limits all sizes to the left to 1
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != sizesAndShape.end()) {
|
||||
++firstDifferentSize;
|
||||
|
||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||
auto [size, _] = sizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) {
|
||||
return rewriter.create<tensor::EmptyOp>(loc, shapedType.getShape(), shapedType.getElementType());
|
||||
}
|
||||
|
||||
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,60 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace spat_to_pim {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
|
||||
|
||||
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
|
||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPIMPass() = default;
|
||||
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
SmallVector<Value> outputTensors;
|
||||
size_t coreId = 0;
|
||||
SmallVector<Operation*> operationsToRemove;
|
||||
|
||||
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||
|
||||
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||
void addReceiveOps(Value& channelSourceOp,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& channelTensorType,
|
||||
bool& useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||
unsigned int argIndex,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& tensorType,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
|
||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace spat_to_pim
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<spat_to_pim::SpatialToPIMPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
20
src/PIM/Conversion/SpatialToPim/CMakeLists.txt
Normal file
20
src/PIM/Conversion/SpatialToPim/CMakeLists.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
|
||||
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(SpatialToPimIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMSpatialToPim
|
||||
SpatialToPimPass.cpp
|
||||
SpatialToPimCommon.cpp
|
||||
|
||||
DEPENDS
|
||||
SpatialToPimIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
SpatialOps
|
||||
PimOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
)
|
||||
@@ -3,10 +3,18 @@
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "mlir/Dialect/Tensor/IR/TensorOps.td"
|
||||
include "src/Dialect/ONNX/ONNX.td"
|
||||
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def onnxToPimTransposeOp : Pat<
|
||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||
(PimTransposeOp $data, $perms,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVMMOp : Pat<
|
||||
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimVMMOp $weightIndex, $vector,
|
||||
@@ -25,4 +33,4 @@ def spatToPimVAddOp : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
#endif // SPATIAL_TO_PIM
|
||||
@@ -5,9 +5,10 @@
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
#include "SpatialToPIMCommon.hpp"
|
||||
#include "SpatialToPimCommon.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -53,7 +54,7 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(Value value) {
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
auto users = value.getUsers();
|
||||
|
||||
assert(!users.empty());
|
||||
@@ -66,23 +67,24 @@ Operation* getEarliestUserWithinBlock(Value value) {
|
||||
return earliestUser;
|
||||
}
|
||||
|
||||
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) {
|
||||
auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> {
|
||||
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
|
||||
auto operandsAndUses =
|
||||
map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
|
||||
return {operand, std::distance(operand.use_begin(), operand.use_end())};
|
||||
});
|
||||
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
|
||||
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
||||
}
|
||||
|
||||
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||
Value result = operation->getResult(0);
|
||||
mlir::Value result = operation->getResult(0);
|
||||
auto resultType = result.getType();
|
||||
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
||||
|
||||
SmallVector<Value> operands = getOpOperandsSortedByUses(operation);
|
||||
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
||||
auto validOperands =
|
||||
make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; });
|
||||
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
|
||||
auto bestOperand = validOperands.begin();
|
||||
|
||||
if (bestOperand != validOperands.end())
|
||||
@@ -90,8 +92,8 @@ Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Opera
|
||||
|
||||
auto resultShapedType = cast<ShapedType>(resultType);
|
||||
rewriter.setInsertionPoint(operation);
|
||||
return rewriter.create<tensor::EmptyOp>(
|
||||
operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
||||
return tensor::EmptyOp::create(
|
||||
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
52
src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp
Normal file
52
src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
*
|
||||
* The static offsets represent the starting position of the slice in each
|
||||
* dimension, while the static tensor input gives its dimension size.
|
||||
*
|
||||
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
|
||||
* calculated.
|
||||
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
|
||||
* \return The actual offset of the ExtractSliceOp.
|
||||
*/
|
||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the earliest operation that uses the given value within the value's
|
||||
* block.
|
||||
*
|
||||
* @param value The value for which to find the earliest user operation.
|
||||
* @return The earliest user operation that uses the given value within the
|
||||
* current block.
|
||||
*/
|
||||
mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
|
||||
|
||||
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
|
||||
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
|
||||
|
||||
inline mlir::tensor::EmptyOp
|
||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||
}
|
||||
|
||||
inline bool isAConcatOp(mlir::Operation* op) {
|
||||
return llvm::isa<mlir::tensor::ConcatOp>(op) || llvm::isa<spatial::SpatImgConcatOp>(op);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,8 +1,10 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -16,20 +18,101 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "SpatialToPIMPass.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
using namespace spat_to_pim;
|
||||
|
||||
void SpatialToPIMPass::runOnOperation() {
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||
|
||||
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
SmallVector<Value> outputTensors;
|
||||
size_t coreId = 0;
|
||||
SmallVector<Operation*> operationsToRemove;
|
||||
|
||||
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||
|
||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||
void addReceiveOps(Value channelSourceOp,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||
unsigned int argIndex,
|
||||
Value channelSourceOp,
|
||||
Value consumerValue,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void markOpToRemove(Operation* op);
|
||||
|
||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static bool isChannelUseChainOp(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
|
||||
op);
|
||||
}
|
||||
|
||||
static size_t countComputeLeafUsers(Value value) {
|
||||
size_t leafUserCount = 0;
|
||||
|
||||
auto walkUses = [&](Value currentValue, auto& self) -> void {
|
||||
for (OpOperand& use : currentValue.getUses()) {
|
||||
Operation* owner = use.getOwner();
|
||||
if (isa<spatial::SpatWeightedCompute>(owner)) {
|
||||
leafUserCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isChannelUseChainOp(owner))
|
||||
llvm_unreachable("Channel use chain contains unsupported op");
|
||||
|
||||
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||
self(owner->getResult(0), self);
|
||||
}
|
||||
};
|
||||
|
||||
walkUses(value, walkUses);
|
||||
return leafUserCount;
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
coreId = 1;
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
@@ -39,15 +122,21 @@ void SpatialToPIMPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
|
||||
if (!funcOp)
|
||||
llvm_unreachable("No FuncOp found in the begin of module");
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
addResultBuffer(returnOp, rewriter);
|
||||
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
||||
operationsToRemove.push_back(receiveOp);
|
||||
@@ -73,10 +162,10 @@ void SpatialToPIMPass::runOnOperation() {
|
||||
}
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "pim");
|
||||
dumpModule(moduleOp, "pim0");
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
auto& block = computeOp.getRegion().front();
|
||||
@@ -124,13 +213,14 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
// Store to global memory
|
||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
rewriter.create<PimMemCopyDevToHostOp>(loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
rewriter.getI32IntegerAttr(offset),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize));
|
||||
PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
rewriter.getI32IntegerAttr(offset),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -155,14 +245,14 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
// Store to global memory
|
||||
Value outputTensor = outputTensors[concatIndexInReturn];
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
rewriter.create<PimMemCopyDevToHostOp>(
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
rewriter.getI32IntegerAttr(offset),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize));
|
||||
PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
rewriter.getI32IntegerAttr(offset),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -174,23 +264,20 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
// 1. Create a new ChannelOp
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
||||
auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
|
||||
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||
|
||||
// 2. Receive value through the channel
|
||||
// If this result is used by more than one user, then use a "Broadcast"
|
||||
// channel operation. However, there is a special case: we have a single
|
||||
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
|
||||
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
|
||||
// will detect this case and update `useBroadcastOp` accordingly.
|
||||
bool useBroadcastOp = (numResultUses > 1);
|
||||
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
|
||||
// 2. Receive value through the channel. Broadcast is needed whenever the
|
||||
// value eventually reaches more than one compute consumer, even through a
|
||||
// chain of view-like ops.
|
||||
bool useBroadcastOp = countComputeLeafUsers(result) > 1;
|
||||
addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
|
||||
|
||||
// 3. Send the value through the channel
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
if (useBroadcastOp)
|
||||
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue);
|
||||
spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
|
||||
else
|
||||
rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
|
||||
}
|
||||
|
||||
// Use `HaltOp` instead of `YieldOp`
|
||||
@@ -199,17 +286,17 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
|
||||
// Replace `spat.compute` with `pim.core`
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
|
||||
auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
block.eraseArguments(0, block.getNumArguments());
|
||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||
Block* tempComputeBlock = new Block();
|
||||
computeOp.getBody().push_back(tempComputeBlock);
|
||||
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
||||
rewriter.create<PimHaltOp>(computeOp.getLoc());
|
||||
PimHaltOp::create(rewriter, computeOp.getLoc());
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
@@ -246,20 +333,20 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
||||
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
||||
rewriter.setInsertionPointAfter(vmmOp);
|
||||
auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
||||
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
||||
for (auto returnValue : returnOp->getOperands()) {
|
||||
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
|
||||
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||
assert(!returnValueDefiningOp->hasAttr("weightAlways"));
|
||||
assert(!hasWeightAlways(returnValueDefiningOp));
|
||||
outputTensors.push_back(returnValue);
|
||||
}
|
||||
else {
|
||||
@@ -270,7 +357,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
||||
}
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
|
||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
||||
@@ -279,9 +366,10 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
|
||||
|
||||
auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType);
|
||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||
|
||||
auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>(
|
||||
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
deviceTensor,
|
||||
@@ -301,16 +389,19 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
||||
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
||||
|
||||
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
|
||||
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
|
||||
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
|
||||
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
||||
|
||||
Block& block = funcOp.getBody().front();
|
||||
rewriter.setInsertionPoint(&block.front());
|
||||
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
||||
auto toTensorOp =
|
||||
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
||||
inputTensors.push_back(toTensorOp);
|
||||
|
||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
||||
funcOp.eraseArgument(i);
|
||||
if (failed(funcOp.eraseArgument(i)))
|
||||
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
|
||||
}
|
||||
|
||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
||||
@@ -324,6 +415,9 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
||||
|
||||
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
|
||||
continue;
|
||||
|
||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
||||
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
||||
@@ -357,12 +451,15 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
||||
for (auto sliceOp : sliceOpsToRemove)
|
||||
if (sliceOp->getUses().empty())
|
||||
rewriter.eraseOp(sliceOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||
unsigned int argIndex,
|
||||
Value channelSourceOp,
|
||||
Value consumerValue,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& tensorType,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter) {
|
||||
auto& computeBlock = computeOp.getRegion().front();
|
||||
@@ -375,71 +472,71 @@ void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
Value receivedValue;
|
||||
if (useBroadcastOp)
|
||||
receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel);
|
||||
receivedValue =
|
||||
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||
else
|
||||
receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel);
|
||||
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||
|
||||
blockArg.replaceAllUsesWith(receivedValue);
|
||||
Value replacementValue = receivedValue;
|
||||
if (consumerValue != channelSourceOp) {
|
||||
SmallVector<Operation*> clonedChain;
|
||||
Value currentValue = consumerValue;
|
||||
while (currentValue != channelSourceOp) {
|
||||
Operation* definingOp = currentValue.getDefiningOp();
|
||||
if (!definingOp || !isChannelUseChainOp(definingOp))
|
||||
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
|
||||
|
||||
clonedChain.push_back(definingOp);
|
||||
currentValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
IRMapping mapping;
|
||||
mapping.map(channelSourceOp, receivedValue);
|
||||
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
markOpToRemove(op);
|
||||
}
|
||||
|
||||
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
||||
}
|
||||
|
||||
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
|
||||
blockArg.replaceAllUsesWith(replacementValue);
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
|
||||
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
Type& channelTensorType,
|
||||
bool& useBroadcastOp,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter) {
|
||||
auto sourceOpUses = channelSourceOp.getUses();
|
||||
|
||||
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
|
||||
if (useBroadcastOp == false) {
|
||||
// if useBroadcastOp is false, then sourceOp must have only one user
|
||||
assert(rangeLength(sourceOpUses) == 1);
|
||||
|
||||
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
|
||||
auto reshapeOpUses = reshapeOp.getOutput().getUses();
|
||||
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
|
||||
if (reshapeOpUsesCount > 1)
|
||||
useBroadcastOp = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& resultUse : sourceOpUses) {
|
||||
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
|
||||
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
|
||||
|
||||
if (computeUser) {
|
||||
replaceBlockArgumentWithRecvOp(
|
||||
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!computeUser) {
|
||||
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
|
||||
if (!reshapeOp) {
|
||||
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
|
||||
resultUse.getOwner()->dump();
|
||||
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
|
||||
}
|
||||
|
||||
// The tensorType now becomes the one of the reshapeOp
|
||||
channelTensorType = reshapeOp.getResult().getType();
|
||||
|
||||
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
|
||||
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
|
||||
|
||||
if (!computeUser)
|
||||
llvm_unreachable("ReshapeOp users must be ComputeOps");
|
||||
|
||||
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
|
||||
for (OpOperand& use : currentValue.getUses()) {
|
||||
Operation* owner = use.getOwner();
|
||||
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
|
||||
replaceBlockArgumentWithRecvOp(
|
||||
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
||||
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Remove the reshapeOp, so that the sourceOp has no users
|
||||
operationsToRemove.push_back(reshapeOp);
|
||||
if (!isChannelUseChainOp(owner))
|
||||
llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
|
||||
|
||||
markOpToRemove(owner);
|
||||
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||
self(owner->getResult(0), self);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
if (!llvm::is_contained(operationsToRemove, op))
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
|
||||
@@ -458,7 +555,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
}
|
||||
}
|
||||
|
||||
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
|
||||
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
|
||||
|
||||
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
|
||||
|
||||
@@ -468,15 +565,10 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
||||
|
||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||
|
||||
auto tensorType = receiveOp.getType();
|
||||
Value receiveRes = receiveOp.getResult();
|
||||
|
||||
// Check if the receiveOp value has more than one user
|
||||
auto receiveUses = receiveRes.getUses();
|
||||
auto receiveUsesCount = rangeLength(receiveUses);
|
||||
assert(receiveUsesCount > 0);
|
||||
bool useBroadcastOp = receiveUsesCount > 1;
|
||||
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
|
||||
bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
|
||||
addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
|
||||
|
||||
if (useBroadcastOp) {
|
||||
// When receiving, we actually noticed that the value has more than one
|
||||
@@ -486,3 +578,7 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
||||
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,2 +1,2 @@
|
||||
add_subdirectory(PIM)
|
||||
add_subdirectory(Spatial)
|
||||
add_subdirectory(Pim)
|
||||
add_subdirectory(Spatial)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,34 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Dialect/PIM/PimOps.hpp"
|
||||
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace pim {
|
||||
|
||||
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
|
||||
|
||||
PimBufferizationPass() = default;
|
||||
PimBufferizationPass(const PimBufferizationPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace pim
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<pim::PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -175,7 +175,33 @@ def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Computation
|
||||
// Algebra
|
||||
|
||||
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Matrix transpose
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $data,
|
||||
I64ArrayAttr: $perms,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
@@ -197,6 +223,10 @@ def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace pim {
|
||||
void PimDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||
|
||||
>();
|
||||
}
|
||||
@@ -45,5 +45,5 @@ POPULATE_DEPENDENCIES(PimVExpOp)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||
@@ -12,7 +12,7 @@
|
||||
#include <string>
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"
|
||||
@@ -3,7 +3,6 @@ mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(PimBufferizationIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMPimBufferization
|
||||
PimBufferizationPass.hpp
|
||||
PimBufferizationPass.cpp
|
||||
OpBufferizationInterfaces.hpp
|
||||
OpBufferizationInterfaces.cpp
|
||||
@@ -14,7 +13,7 @@ add_onnx_mlir_library(OMPimBufferization
|
||||
PimBufferizationIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPIMCommon
|
||||
OMPimCommon
|
||||
PimOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref);
|
||||
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
@@ -76,6 +76,32 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct TransposeOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto transposeOp = cast<PimTransposeOp>(op);
|
||||
|
||||
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
|
||||
if (failed(dataOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -176,6 +202,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -4,7 +4,7 @@
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
||||
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def memrefCopyToPimMemCopyOp : Pat<
|
||||
@@ -16,4 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
#endif // PIM_BUFFERIZATION
|
||||
@@ -5,14 +5,39 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
#include "PimBufferizationPass.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
|
||||
|
||||
PimBufferizationPass() = default;
|
||||
PimBufferizationPass(const PimBufferizationPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimBufferizationPass::runOnOperation() {
|
||||
auto moduleOp = getOperation();
|
||||
|
||||
@@ -64,19 +89,22 @@ void PimBufferizationPass::runOnOperation() {
|
||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "pim_buf");
|
||||
dumpModule(moduleOp, "pim1_buff");
|
||||
}
|
||||
|
||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||
MLIRContext* ctx = funcOp.getContext();
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
||||
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
||||
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
||||
markWeightAlways(getGlobalOp);
|
||||
markWeightAlways(globalMemrefOp);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -7,9 +7,10 @@ add_onnx_mlir_library(SpatialOps
|
||||
Transforms/SpatialBufferizableOpInterface.cpp
|
||||
|
||||
DEPENDS
|
||||
OMONNXIncGen
|
||||
OMSpatialIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
OMMlirDialects
|
||||
)
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() {
|
||||
|
||||
LogicalResult SpatImgConcatOp::verify() {
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = GET_IMAGE_WIDTH(imgShape);
|
||||
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
|
||||
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
size_t channelTileRest = img_c % crossbarSize;
|
||||
@@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() {
|
||||
return emitError("Invalid input type, must be ShapedType");
|
||||
|
||||
// N == W == H == 1
|
||||
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
|
||||
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
|
||||
return emitError("Invalid input shape: N,W,H must all be 1");
|
||||
|
||||
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
|
||||
size_t inputChannels = getImageChannel(inputShape);
|
||||
|
||||
// Check the number of channels in this tile are correct:
|
||||
// - CASE1: last tile of pixel, if there is some rest it must match that
|
||||
@@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
|
||||
auto operands = getOperands();
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = GET_IMAGE_WIDTH(imgShape);
|
||||
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
|
||||
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -34,7 +34,7 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
|
||||
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
|
||||
|
||||
// Alloc an output memref
|
||||
return rewriter.create<memref::AllocOp>(loc, memrefResultType);
|
||||
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
||||
}
|
||||
|
||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
@@ -134,7 +134,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
||||
|
||||
memrefOperands.push_back(outputTensor);
|
||||
|
||||
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -169,11 +169,13 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
|
||||
// Alloc an output memref
|
||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
Value newValue =
|
||||
rewriter
|
||||
.create<ToTy>(
|
||||
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor)
|
||||
.getOutRes();
|
||||
Value newValue = ToTy::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
cast<OpTy>(op).getWeightIndexAttr(),
|
||||
memrefOperand,
|
||||
outputTensor)
|
||||
.getOutRes();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -213,12 +215,12 @@ struct ChannelReceiveOpInterface
|
||||
if (failed(srcCoreId))
|
||||
return failure();
|
||||
|
||||
Value newValue = rewriter
|
||||
.create<pim::PimReceiveOp>(op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOut();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
@@ -300,7 +302,8 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
@@ -323,13 +326,14 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
bufferAllocation,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(outputSize));
|
||||
auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
bufferAllocation,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(outputSize));
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
|
||||
|
||||
@@ -356,7 +360,8 @@ struct ChannelBroadcastSendOpInterface
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel send to pim.send
|
||||
* Turn the channel send into a device-to-host copy into the shared
|
||||
* broadcast buffer that receive ops load from later.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
@@ -389,8 +394,19 @@ struct ChannelBroadcastSendOpInterface
|
||||
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
||||
}
|
||||
|
||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
bufferAllocation.getType(),
|
||||
bufferAllocation,
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -469,14 +485,15 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
|
||||
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
|
||||
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
|
||||
|
||||
Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
weightIndices,
|
||||
xKernelPositions,
|
||||
yKernelPositions,
|
||||
*inputBuffer,
|
||||
outputTensor,
|
||||
accumBuffer);
|
||||
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
weightIndices,
|
||||
xKernelPositions,
|
||||
yKernelPositions,
|
||||
*inputBuffer,
|
||||
outputTensor,
|
||||
accumBuffer);
|
||||
|
||||
// Replace the operation with the bufferized value.
|
||||
replaceOpWithBufferizedValues(rewriter, op, bufferized);
|
||||
|
||||
@@ -4,14 +4,12 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
||||
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
618
src/PIM/Pass/PimConstantFoldingPass.cpp
Normal file
618
src/PIM/Pass/PimConstantFoldingPass.cpp
Normal file
@@ -0,0 +1,618 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
DenseElementsAttr denseAttr,
|
||||
StringRef nameStem,
|
||||
IntegerAttr alignment = {}) {
|
||||
auto globalName = nameStem.str();
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||
|
||||
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(moduleBuilder,
|
||||
loc,
|
||||
globalName,
|
||||
visibility,
|
||||
globalType,
|
||||
denseAttr,
|
||||
/*constant=*/true,
|
||||
alignment);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||
value = stripMemRefCasts(value);
|
||||
|
||||
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
return denseAttr;
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
|
||||
SmallVector<int64_t> originalStrides(rank, 1);
|
||||
SmallVector<int64_t> transposedStrides(rank, 1);
|
||||
for (int64_t dim = rank - 2; dim >= 0; --dim) {
|
||||
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
|
||||
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
|
||||
}
|
||||
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
SmallVector<int64_t> transposedIndices(rank);
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedIndices[dim] = originalIndices[perms[dim]];
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
struct ConstantSubviewCopy {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
Operation* copyOp = nullptr;
|
||||
};
|
||||
|
||||
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||
if (!mapOp.getInputs().empty())
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return failure();
|
||||
|
||||
Attribute attr;
|
||||
if (!matchPattern(yieldOp.getValues().front(), m_Constant(&attr)))
|
||||
return failure();
|
||||
return attr;
|
||||
}
|
||||
|
||||
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
|
||||
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
|
||||
if (!coreOp)
|
||||
return failure();
|
||||
|
||||
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
|
||||
if (!initType || !initType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto fillValue = getConstantMapYield(mapOp);
|
||||
if (failed(fillValue))
|
||||
return failure();
|
||||
|
||||
auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
|
||||
DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
|
||||
|
||||
auto moduleOp = mapOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(coreOp);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
|
||||
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth == 0)
|
||||
return failure();
|
||||
size_t totalBytes = initType.getNumElements() * elementByteWidth;
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
mapOp.getLoc(),
|
||||
initType,
|
||||
mapOp.getInit(),
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
Value source;
|
||||
SmallVector<int64_t> sourceShape;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> sizes;
|
||||
SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
info.offsets.push_back(*staticOffset);
|
||||
}
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
static int64_t
|
||||
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||
sourceIndices.push_back(info.offsets.back());
|
||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||
}
|
||||
|
||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
|
||||
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (copyOp.getSize() != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = copyOp.getSrcOffset()
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = copyOp.getDstOffset()
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : copyOp.getDst(),
|
||||
splitSrc ? srcSubview->source : copyOp.getSrc(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
const int64_t numElements = resultTensorType.getNumElements();
|
||||
if (numElements < 0)
|
||||
return failure();
|
||||
|
||||
Attribute fillValue;
|
||||
SmallVector<ConstantSubviewCopy> copies;
|
||||
llvm::SmallPtrSet<Operation*, 8> visitedAliases;
|
||||
SmallVector<Value> pendingAliases;
|
||||
pendingAliases.push_back(allocOp.getResult());
|
||||
|
||||
while (!pendingAliases.empty()) {
|
||||
Value alias = pendingAliases.pop_back_val();
|
||||
for (Operation* user : alias.getUsers()) {
|
||||
if (!visitedAliases.insert(user).second)
|
||||
continue;
|
||||
|
||||
if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
|
||||
if (mapOp.getInit() != alias)
|
||||
return failure();
|
||||
auto maybeFillValue = getConstantMapYield(mapOp);
|
||||
if (failed(maybeFillValue))
|
||||
return failure();
|
||||
if (fillValue && fillValue != *maybeFillValue)
|
||||
return failure();
|
||||
fillValue = *maybeFillValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
strides.reserve(subviewOp.getMixedStrides().size());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
offsets.push_back(*staticOffset);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
strides.push_back(*staticStride);
|
||||
}
|
||||
|
||||
for (Operation* subviewUser : subviewOp->getUsers()) {
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
|
||||
if (copyOp.getTarget() != subviewOp.getResult())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
copies.push_back({*denseAttr, offsets, strides, copyOp});
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
|
||||
continue;
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
|
||||
pendingAliases.push_back(castOp.getResult());
|
||||
continue;
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
if (!fillValue)
|
||||
return failure();
|
||||
|
||||
SmallVector<Attribute> resultValues(numElements, fillValue);
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
|
||||
llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
|
||||
return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
|
||||
});
|
||||
|
||||
for (const ConstantSubviewCopy& copy : copies) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
|
||||
|| sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
|
||||
for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
|
||||
SmallVector<int64_t> sourceIndices =
|
||||
delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
|
||||
SmallVector<int64_t> resultIndices;
|
||||
resultIndices.reserve(sourceIndices.size());
|
||||
for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
|
||||
resultIndices.push_back(offset + sourceIndex * stride);
|
||||
|
||||
int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
|
||||
resultValues[resultLinearIndex] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
}
|
||||
|
||||
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
||||
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!sourceGetGlobal)
|
||||
return failure();
|
||||
|
||||
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
|
||||
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> perms;
|
||||
perms.reserve(transposeOp.getPerms().size());
|
||||
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
|
||||
perms.push_back(attr.getInt());
|
||||
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
|
||||
if (failed(transposedAttr))
|
||||
return failure();
|
||||
|
||||
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
|
||||
if (!llvm::equal(transposedShape, resultType.getShape()))
|
||||
return failure();
|
||||
|
||||
MemRefType globalType = resultType;
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp,
|
||||
transposeOp.getLoc(),
|
||||
globalType,
|
||||
*transposedAttr,
|
||||
sourceGlobal.getName().str() + "__folded_transpose",
|
||||
sourceGlobal.getAlignmentAttr());
|
||||
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
|
||||
|
||||
bool isAlwaysWeight =
|
||||
!transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
|
||||
auto moduleOp = allocOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
auto allocType = cast<MemRefType>(allocOp.getType());
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");
|
||||
|
||||
rewriter.setInsertionPoint(allocOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||
|
||||
SmallVector<Operation*> opsToErase;
|
||||
SmallVector<memref::CastOp> castsToReplace;
|
||||
bool allLiveUsersAreCoreOps = true;
|
||||
for (Operation* user : llvm::make_early_inc_range(allocOp->getUsers())) {
|
||||
if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
|
||||
opsToErase.push_back(user);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
|
||||
castsToReplace.push_back(castOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<pim::PimCoreOp>(user))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||
})) {
|
||||
allLiveUsersAreCoreOps = false;
|
||||
}
|
||||
|
||||
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
||||
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
|
||||
})) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (allLiveUsersAreCoreOps) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
}
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
|
||||
for (memref::CastOp castOp : castsToReplace)
|
||||
preservedUsers.insert(castOp);
|
||||
rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);
|
||||
|
||||
for (memref::CastOp castOp : castsToReplace) {
|
||||
rewriter.setInsertionPoint(castOp);
|
||||
Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
|
||||
rewriter.replaceOp(castOp, replacementCast);
|
||||
if (allLiveUsersAreCoreOps)
|
||||
markWeightAlways(replacementCast.getDefiningOp());
|
||||
}
|
||||
|
||||
for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
|
||||
rewriter.eraseOp(subviewUser);
|
||||
if (op->use_empty())
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
|
||||
if (allocOp.use_empty())
|
||||
rewriter.eraseOp(allocOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto* dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
owningPatterns
|
||||
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
|
||||
context);
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim2_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
175
src/PIM/Pass/PimHostVerificationPass.cpp
Normal file
175
src/PIM/Pass/PimHostVerificationPass.cpp
Normal file
@@ -0,0 +1,175 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isAddressOnlyHostOp(Operation* op) {
|
||||
return isa<arith::ConstantOp,
|
||||
memref::AllocOp,
|
||||
memref::GetGlobalOp,
|
||||
memref::SubViewOp,
|
||||
memref::CastOp,
|
||||
memref::CollapseShapeOp,
|
||||
memref::ExpandShapeOp,
|
||||
spatial::SpatChannelNewOp>(op);
|
||||
}
|
||||
|
||||
static bool isHostAddressableValue(Value value) {
|
||||
while (true) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-host-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Verify that no runtime host-side code remains in bufferized PIM IR";
|
||||
}
|
||||
|
||||
PimHostVerificationPass() {}
|
||||
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||
if (failed(verifyReturnOp(returnOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isAddressOnlyHostOp(&op)) {
|
||||
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(verifyAddressOnlyHostOp(&op)))
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure)
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as memref.get_global before JSON codegen";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
if (!isHostAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlySource(op, subviewOp.getSource());
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource());
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
if (isHostAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that still requires host-side execution");
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,23 +3,26 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass();
|
||||
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToGraphvizPass();
|
||||
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPIMPass();
|
||||
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass();
|
||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||
|
||||
std::unique_ptr<Pass> createEmitPimJsonPass();
|
||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<Pass> createMessagePass(std::string message);
|
||||
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
|
||||
|
||||
std::unique_ptr<Pass> createCountInstructionPass();
|
||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
|
||||
|
||||
std::unique_ptr<mlir::Pass> createCountInstructionPass();
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
@@ -12,8 +13,8 @@
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
@@ -40,28 +41,27 @@ PimAccelerator::PimAccelerator()
|
||||
acceleratorTargets.push_back(this);
|
||||
}
|
||||
|
||||
PimAccelerator::~PimAccelerator() { delete instance; }
|
||||
|
||||
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }
|
||||
|
||||
void PimAccelerator::addPasses(OwningOpRef<ModuleOp>& module,
|
||||
PassManager& pm,
|
||||
void PimAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
|
||||
mlir::PassManager& pm,
|
||||
EmissionTargetType& emissionTarget,
|
||||
std::string outputNameNoExt) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
|
||||
addPassesPim(module, pm, emissionTarget, outputNameNoExt);
|
||||
}
|
||||
|
||||
void PimAccelerator::registerDialects(DialectRegistry& registry) const {
|
||||
void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<tosa::TosaDialect>();
|
||||
registry.insert<bufferization::BufferizationDialect>();
|
||||
registry.insert<mlir::tensor::TensorDialect>();
|
||||
registry.insert<mlir::tosa::TosaDialect>();
|
||||
registry.insert<mlir::bufferization::BufferizationDialect>();
|
||||
registry.insert<pim::PimDialect>();
|
||||
registry.insert<spatial::SpatialDialect>();
|
||||
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
|
||||
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
|
||||
pim::registerOpBufferizationInterfaces(registry);
|
||||
@@ -71,8 +71,10 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
||||
registerPass(createONNXToSpatialPass);
|
||||
registerPass(createSpatialToGraphvizPass);
|
||||
registerPass(createSpatialToPIMPass);
|
||||
registerPass(createSpatialToPimPass);
|
||||
registerPass(createBufferizePimPass);
|
||||
registerPass(createPimConstantFoldingPass);
|
||||
registerPass(createPimHostVerificationPass);
|
||||
registerPass(createEmitPimJsonPass);
|
||||
}
|
||||
|
||||
@@ -81,26 +83,26 @@ void PimAccelerator::configurePasses() const {
|
||||
// TODO: This does nothing for now.
|
||||
}
|
||||
|
||||
MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const {
|
||||
mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const {
|
||||
// Do not convert tensor types to memref types.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const {
|
||||
void PimAccelerator::conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const {
|
||||
target.addLegalDialect<pim::PimDialect>();
|
||||
}
|
||||
|
||||
void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns,
|
||||
TypeConverter& typeConverter,
|
||||
MLIRContext* ctx) const {
|
||||
void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
|
||||
mlir::TypeConverter& typeConverter,
|
||||
mlir::MLIRContext* ctx) const {
|
||||
// TODO: Add patterns for conversion
|
||||
}
|
||||
|
||||
void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {}
|
||||
void PimAccelerator::conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const {}
|
||||
|
||||
void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns,
|
||||
LLVMTypeConverter& typeConverter,
|
||||
MLIRContext* ctx) const {
|
||||
void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
|
||||
mlir::LLVMTypeConverter& typeConverter,
|
||||
mlir::MLIRContext* ctx) const {
|
||||
// We should not need this, since we offload it all to PIM.
|
||||
}
|
||||
|
||||
|
||||
@@ -18,8 +18,6 @@ public:
|
||||
PimAccelerator(PimAccelerator&) = delete;
|
||||
void operator=(const PimAccelerator&) = delete;
|
||||
|
||||
~PimAccelerator();
|
||||
|
||||
/// Creates an instance on the first invocation. Subsequent invocations
|
||||
/// return the existing instance.
|
||||
static PimAccelerator* getInstance();
|
||||
|
||||
@@ -102,46 +102,13 @@ def gen_c(inputs, outputs, entry, so_name):
|
||||
if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}}
|
||||
"""))
|
||||
|
||||
# Output printing + optional per-output CSV dump
|
||||
out_blocks=[]
|
||||
# Optional per-output CSV dump
|
||||
csv_write_blocks=[]
|
||||
for oi,name,et,shape in outputs:
|
||||
if et not in DTYPES:
|
||||
raise ValueError(f"Unsupported dtype for output '{name}': {et}")
|
||||
cty, pfmt, _ = DTYPES[et]
|
||||
safe = esc(name)
|
||||
out_blocks.append(textwrap.dedent(f"""
|
||||
// ---- Output {oi}: "{safe}" ----
|
||||
{{
|
||||
OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi});
|
||||
int64_t rank = omTensorGetRank(t);
|
||||
int64_t const *shape = omTensorGetShape(t);
|
||||
long long numel = 1; for (int64_t k=0;k<rank;k++) numel *= shape[k];
|
||||
{cty} *p = ({cty}*)omTensorGetDataPtr(t);
|
||||
|
||||
printf("Output {oi} ('{safe}'): shape=[");
|
||||
for (int64_t k=0;k<rank;k++) printf("%ld%s",(long)shape[k], (k+1<rank)?",":"");
|
||||
printf("]\\n");
|
||||
|
||||
if (rank == 2) {{
|
||||
int64_t R = shape[0], C = shape[1];
|
||||
for (int64_t r=0; r<R; ++r) {{
|
||||
for (int64_t c=0; c<C; ++c) {{
|
||||
long long idx = r*C + c;
|
||||
printf("{pfmt}%s", p[idx], (c+1<C)?", ":"");
|
||||
}}
|
||||
printf("\\n");
|
||||
}}
|
||||
}} else {{
|
||||
// Flattened vector with indices
|
||||
for (long long i=0;i<numel;i++) {{
|
||||
printf("[%lld]={pfmt}%s", i, p[i], (i+1<numel)?", ":"\\n");
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""))
|
||||
|
||||
# Per-output CSV writer into --save-csv-dir
|
||||
csv_write_blocks.append(textwrap.dedent(f"""
|
||||
if (save_csv_dir) {{
|
||||
// Build "DIR/output{oi}_<sanitized name>.csv"
|
||||
@@ -227,9 +194,6 @@ int main(int argc, char **argv) {{
|
||||
OMTensorList *out_list = {entry}(in_list);
|
||||
if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}}
|
||||
|
||||
// ---- Print full outputs ----
|
||||
{"".join(out_blocks)}
|
||||
|
||||
// ---- Optional per-output CSV dump ----
|
||||
{"".join(csv_write_blocks)}
|
||||
|
||||
@@ -240,7 +204,7 @@ int main(int argc, char **argv) {{
|
||||
}}
|
||||
"""
|
||||
|
||||
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None):
|
||||
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None, verbose=True):
|
||||
ins, outs = onnx_io(network_onnx)
|
||||
out_c = out or "runner.c"
|
||||
so_abs = os.path.abspath(network_so)
|
||||
@@ -260,8 +224,9 @@ set_target_properties(model_so PROPERTIES IMPORTED_LOCATION {esc(so_abs)})
|
||||
target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so)
|
||||
"""
|
||||
pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake)
|
||||
print(f"[OK] Wrote {out_c}")
|
||||
print("[OK] Wrote CMakeLists.txt")
|
||||
if verbose:
|
||||
print(f"[OK] Wrote {out_c}")
|
||||
print("[OK] Wrote CMakeLists.txt")
|
||||
|
||||
if __name__=="__main__":
|
||||
ap=argparse.ArgumentParser()
|
||||
|
||||
BIN
validation/operations/conv/batch_64/conv_batch_64.onnx
Normal file
BIN
validation/operations/conv/batch_64/conv_batch_64.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/simple/conv.onnx
Normal file
BIN
validation/operations/conv/simple/conv.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/with_constant/conv_with_constant.onnx
Normal file
BIN
validation/operations/conv/with_constant/conv_with_constant.onnx
Normal file
Binary file not shown.
@@ -1,12 +1,16 @@
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from colorama import Fore, Style
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count):
|
||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
||||
# Define the arguments, with the possibility to set crossbar size and count
|
||||
args = [
|
||||
network_path,
|
||||
"-o",
|
||||
output_base,
|
||||
"--maccel=PIM",
|
||||
"--EmitPimCodegen",
|
||||
# "--use-experimental-conv-impl=true",
|
||||
@@ -14,16 +18,15 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, cro
|
||||
f"--crossbar-count={crossbar_count}",
|
||||
]
|
||||
|
||||
# Run the executable with the arguments
|
||||
try:
|
||||
result = subprocess.run(
|
||||
run_command_with_reporter(
|
||||
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=cwd,
|
||||
reporter=reporter,
|
||||
)
|
||||
print(result.stdout + Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(Fore.RED + "Error executing ONNX-MLIR:")
|
||||
print(e.stderr + Style.RESET_ALL)
|
||||
if reporter is None:
|
||||
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
||||
except subprocess.CalledProcessError:
|
||||
if reporter is None:
|
||||
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
|
||||
raise
|
||||
|
||||
76
validation/subprocess_utils.py
Normal file
76
validation/subprocess_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import errno
|
||||
import os
|
||||
import pty
|
||||
import selectors
|
||||
import subprocess
|
||||
|
||||
MAX_ERROR_OUTPUT_BYTES = 8192
|
||||
|
||||
|
||||
def _read_chunk(fd, treat_eio_as_eof=False):
|
||||
try:
|
||||
return os.read(fd, 4096)
|
||||
except OSError as exc:
|
||||
if treat_eio_as_eof and exc.errno == errno.EIO:
|
||||
return b""
|
||||
raise
|
||||
|
||||
|
||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
selector = selectors.DefaultSelector()
|
||||
recent_output = bytearray()
|
||||
|
||||
try:
|
||||
selector.register(fd, selectors.EVENT_READ)
|
||||
|
||||
while selector.get_map():
|
||||
for key, _ in selector.select():
|
||||
data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof)
|
||||
if not data:
|
||||
selector.unregister(key.fileobj)
|
||||
os.close(key.fileobj)
|
||||
continue
|
||||
|
||||
reporter._clear()
|
||||
os.write(1, data)
|
||||
reporter._render()
|
||||
recent_output.extend(data)
|
||||
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||
finally:
|
||||
selector.close()
|
||||
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
||||
|
||||
|
||||
def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
||||
if reporter is None:
|
||||
subprocess.run(cmd, cwd=cwd, check=True)
|
||||
return
|
||||
|
||||
try:
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
except OSError:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
_stream_output(process.stdout.fileno(), process, reporter)
|
||||
return
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
stdout=slave_fd,
|
||||
stderr=slave_fd,
|
||||
)
|
||||
finally:
|
||||
os.close(slave_fd)
|
||||
|
||||
_stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
||||
@@ -1,22 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from colorama import Style, Fore
|
||||
from validate_one import validate_network
|
||||
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
||||
|
||||
|
||||
def format_command(cmd):
|
||||
if isinstance(cmd, (list, tuple)):
|
||||
return shlex.join(str(arg) for arg in cmd)
|
||||
return str(cmd)
|
||||
|
||||
|
||||
def format_return_status(returncode):
|
||||
if returncode < 0:
|
||||
signal_num = -returncode
|
||||
try:
|
||||
signal_name = signal.Signals(signal_num).name
|
||||
except ValueError:
|
||||
signal_name = "UNKNOWN"
|
||||
return f"Program terminated by signal {signal_name} ({signal_num})."
|
||||
return f"Program exited with code {returncode}."
|
||||
|
||||
|
||||
def print_validation_error(reporter, rel, exc):
|
||||
reporter.suspend()
|
||||
print(Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL,
|
||||
file=sys.stderr, flush=True)
|
||||
if isinstance(exc, subprocess.CalledProcessError):
|
||||
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
||||
print("Retry command:", file=sys.stderr, flush=True)
|
||||
print(format_command(exc.cmd), file=sys.stderr, flush=True)
|
||||
else:
|
||||
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
||||
print("=" * 72, file=sys.stderr, flush=True)
|
||||
reporter.resume()
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
||||
ap.add_argument("--raptor-path", required=True, help="Path to the Raptor compiler binary.")
|
||||
ap.add_argument("--onnx-include-dir", required=True, help="Path to OnnxMlirRuntime include directory.")
|
||||
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
||||
ap.add_argument("--onnx-include-dir", help="Path to OnnxMlirRuntime include directory.")
|
||||
ap.add_argument("--operations-dir", default=None, help="Root of the operations tree (default: operations).")
|
||||
ap.add_argument("--simulator-dir", default=None,
|
||||
help="Path to pim-simulator crate root (default: auto-detected relative to script).")
|
||||
ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
|
||||
ap.add_argument("--crossbar-size", type=int, default=64)
|
||||
ap.add_argument("--crossbar-count", type=int, default=8)
|
||||
ap.add_argument("--clean", action="store_true",
|
||||
help="Remove generated validation artifacts under each model workspace and exit.")
|
||||
a = ap.parse_args()
|
||||
|
||||
script_dir = Path(__file__).parent.resolve()
|
||||
@@ -34,32 +70,67 @@ def main():
|
||||
print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
|
||||
sys.exit(1)
|
||||
|
||||
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate.\n" + Style.RESET_ALL)
|
||||
if a.clean:
|
||||
removed_count = 0
|
||||
for onnx_path in onnx_files:
|
||||
removed_count += len(clean_workspace_artifacts(onnx_path.parent, onnx_path.stem))
|
||||
print(Style.BRIGHT + f"Removed {removed_count} generated artifact path(s)." + Style.RESET_ALL)
|
||||
sys.exit(0)
|
||||
|
||||
missing_args = []
|
||||
if not a.raptor_path:
|
||||
missing_args.append("--raptor-path")
|
||||
if not a.onnx_include_dir:
|
||||
missing_args.append("--onnx-include-dir")
|
||||
if missing_args:
|
||||
ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))
|
||||
|
||||
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
|
||||
print(f"Operations root: {operations_dir}")
|
||||
print("=" * 72)
|
||||
|
||||
results = {} # relative_path -> passed
|
||||
for onnx_path in onnx_files:
|
||||
reporter = ProgressReporter(len(onnx_files))
|
||||
for index, onnx_path in enumerate(onnx_files, start=1):
|
||||
rel = onnx_path.relative_to(operations_dir)
|
||||
header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}"
|
||||
print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL)
|
||||
try:
|
||||
passed = validate_network(
|
||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||
threshold=a.threshold,
|
||||
reporter=reporter,
|
||||
model_index=index,
|
||||
model_total=len(onnx_files),
|
||||
)
|
||||
results[str(rel)] = passed
|
||||
except subprocess.CalledProcessError as exc:
|
||||
results[str(rel)] = False
|
||||
print_validation_error(reporter, rel, exc)
|
||||
except Exception as exc:
|
||||
results[str(rel)] = False
|
||||
print_validation_error(reporter, rel, exc)
|
||||
|
||||
passed = validate_network(
|
||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||
threshold=a.threshold,
|
||||
)
|
||||
|
||||
results[str(rel)] = passed
|
||||
reporter.finish()
|
||||
|
||||
# Summary
|
||||
n_passed = sum(results.values())
|
||||
n_passed = sum(1 for passed in results.values() if passed)
|
||||
n_total = len(results)
|
||||
print("\n" + Style.BRIGHT + "=" * 60)
|
||||
print(" Summary")
|
||||
print("=" * 60 + Style.RESET_ALL)
|
||||
status_width = len("Result")
|
||||
path_width = max(len("Operation"), *(len(rel) for rel in results))
|
||||
separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
|
||||
|
||||
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
|
||||
print(separator)
|
||||
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
|
||||
print(separator)
|
||||
for rel, passed in results.items():
|
||||
status = Fore.GREEN + "PASS" if passed else Fore.RED + "FAIL"
|
||||
print(f" {rel}: {status}" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"\n {n_passed}/{n_total} passed." + Style.RESET_ALL)
|
||||
plain_status = "PASS" if passed else "FAIL"
|
||||
status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
|
||||
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
|
||||
print(f"| {rel.ljust(path_width)} | {status} |")
|
||||
print(separator)
|
||||
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
||||
|
||||
sys.exit(0 if n_passed == n_total else 1)
|
||||
|
||||
|
||||
@@ -2,32 +2,148 @@ import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from colorama import Style, Fore
|
||||
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
|
||||
from raptor import compile_with_raptor
|
||||
from gen_network_runner import gen_network_runner
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir):
|
||||
subprocess.run([raptor_path, network_onnx_path, "--EmitONNXIR"], check=True)
|
||||
subprocess.run([raptor_path, network_onnx_path], check=True)
|
||||
parent = network_onnx_path.parent
|
||||
STAGE_COUNT = 6
|
||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||
|
||||
|
||||
class ProgressReporter:
|
||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
||||
self.total_models = total_models
|
||||
self.stages_per_model = stages_per_model
|
||||
self.total_steps = max(1, total_models * stages_per_model)
|
||||
self.completed_steps = 0
|
||||
self.current_label = ""
|
||||
self.enabled = True
|
||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||
self.suspended = False
|
||||
|
||||
def _clear(self):
|
||||
if self.enabled:
|
||||
sys.stdout.write("\033[2K\r")
|
||||
|
||||
def _render(self):
|
||||
if not self.enabled or self.suspended:
|
||||
return
|
||||
bar_width = 24
|
||||
filled = int(bar_width * self.completed_steps / self.total_steps)
|
||||
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
|
||||
if len(prefix_text) > self.columns:
|
||||
prefix_text = f"{self.completed_steps}/{self.total_steps}"
|
||||
|
||||
label = f" {self.current_label}" if self.current_label else ""
|
||||
available_label_width = max(0, self.columns - len(prefix_text))
|
||||
label = label[:available_label_width]
|
||||
|
||||
if prefix_text.startswith("["):
|
||||
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
|
||||
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
|
||||
else:
|
||||
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
||||
|
||||
sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
|
||||
sys.stdout.flush()
|
||||
|
||||
def log(self, message="", color=None):
|
||||
if self.enabled:
|
||||
self._clear()
|
||||
if color:
|
||||
print(color + message + Style.RESET_ALL)
|
||||
else:
|
||||
print(message)
|
||||
self._render()
|
||||
|
||||
def set_stage(self, model_index, model_total, model_name, stage_name):
|
||||
self.current_label = f"[{model_index}/{model_total}] {model_name} · {stage_name}"
|
||||
self._render()
|
||||
|
||||
def advance(self):
|
||||
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
|
||||
self._render()
|
||||
|
||||
def suspend(self):
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
|
||||
def resume(self):
|
||||
self.suspended = False
|
||||
|
||||
def finish(self):
|
||||
if self.enabled:
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def run_command(cmd, cwd=None, reporter=None):
|
||||
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter)
|
||||
|
||||
|
||||
def clean_workspace_artifacts(workspace_dir, model_stem):
|
||||
workspace_dir = Path(workspace_dir)
|
||||
removed_paths = []
|
||||
|
||||
def remove_path(path):
|
||||
if path.is_symlink() or path.is_file():
|
||||
path.unlink(missing_ok=True)
|
||||
removed_paths.append(path)
|
||||
elif path.is_dir():
|
||||
shutil.rmtree(path)
|
||||
removed_paths.append(path)
|
||||
|
||||
for name in GENERATED_DIR_NAMES:
|
||||
remove_path(workspace_dir / name)
|
||||
|
||||
for suffix in (".onnx.mlir", ".so", ".tmp"):
|
||||
remove_path(workspace_dir / f"{model_stem}{suffix}")
|
||||
|
||||
return removed_paths
|
||||
|
||||
|
||||
def print_stage(reporter, model_index, model_total, model_name, title):
|
||||
stage_colors = {
|
||||
"Compile ONNX": Fore.BLUE,
|
||||
"Build Runner": Fore.MAGENTA,
|
||||
"Generate Inputs": Fore.YELLOW,
|
||||
"Run Reference": Fore.GREEN,
|
||||
"Compile PIM": Fore.CYAN,
|
||||
"Run Simulator": Fore.MAGENTA,
|
||||
"Compare Outputs": Fore.YELLOW,
|
||||
}
|
||||
color = stage_colors.get(title, Fore.WHITE)
|
||||
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
|
||||
reporter.set_stage(model_index, model_total, model_name, title)
|
||||
|
||||
|
||||
def print_info(reporter, message):
|
||||
reporter.log(f" {message}")
|
||||
|
||||
|
||||
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
|
||||
stem = network_onnx_path.stem
|
||||
so_path = parent / f"{stem}.so"
|
||||
mlir_path = parent / f"{stem}.onnx.mlir"
|
||||
tmp_path = parent / f"{stem}.tmp"
|
||||
moved_so = runner_dir / so_path.name
|
||||
moved_mlir = raptor_dir / mlir_path.name
|
||||
so_path.rename(moved_so)
|
||||
mlir_path.rename(moved_mlir)
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
return moved_so, moved_mlir
|
||||
onnx_ir_base = raptor_dir / stem
|
||||
runner_base = runner_dir / stem
|
||||
run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"], reporter=reporter)
|
||||
run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter)
|
||||
network_so_path = runner_base.with_suffix(".so")
|
||||
network_mlir_path = onnx_ir_base.with_suffix(".onnx.mlir")
|
||||
onnx_ir_base.with_suffix(".tmp").unlink(missing_ok=True)
|
||||
return network_so_path, network_mlir_path
|
||||
|
||||
|
||||
def build_onnx_runner(source_dir, build_dir):
|
||||
subprocess.run(["cmake", source_dir], cwd=build_dir, check=True)
|
||||
subprocess.run(["cmake", "--build", ".", "-j"], cwd=build_dir, check=True)
|
||||
def build_onnx_runner(source_dir, build_dir, reporter=None):
|
||||
run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter)
|
||||
run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter)
|
||||
return build_dir / "runner"
|
||||
|
||||
|
||||
@@ -41,11 +157,12 @@ def build_dump_ranges(config_path, outputs_descriptor):
|
||||
return ",".join(ranges)
|
||||
|
||||
|
||||
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges):
|
||||
subprocess.run(
|
||||
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
|
||||
run_command(
|
||||
["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
|
||||
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
||||
cwd=simulator_dir, check=True
|
||||
cwd=simulator_dir,
|
||||
reporter=reporter,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,66 +181,122 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor):
|
||||
|
||||
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3):
|
||||
all_passed = True
|
||||
rows = []
|
||||
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
|
||||
csv_name = f"output{oi}_{name}.csv"
|
||||
runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
|
||||
max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
|
||||
passed = max_diff <= threshold
|
||||
status = Fore.GREEN + "[PASS]" if passed else Fore.RED + "[FAIL]"
|
||||
print(f" {name}: max diff = {max_diff:.6e} {status}" + Style.RESET_ALL)
|
||||
rows.append((name, f"{max_diff:.6e}", passed))
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
name_width = max(len("Output"), *(len(name) for name, _, _ in rows))
|
||||
diff_width = max(len("Max diff"), *(len(diff) for _, diff, _ in rows))
|
||||
result_width = len("Result")
|
||||
separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+"
|
||||
|
||||
print(separator)
|
||||
print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |")
|
||||
print(separator)
|
||||
for name, diff_text, passed in rows:
|
||||
status_text = ("PASS" if passed else "FAIL").ljust(result_width)
|
||||
status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL
|
||||
print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |")
|
||||
print(separator)
|
||||
return all_passed
|
||||
|
||||
|
||||
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3):
|
||||
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3,
|
||||
reporter=None, model_index=1, model_total=1):
|
||||
network_onnx_path = Path(network_onnx_path).resolve()
|
||||
raptor_path = Path(raptor_path).resolve()
|
||||
onnx_include_dir = Path(onnx_include_dir).resolve()
|
||||
simulator_dir = Path(simulator_dir).resolve()
|
||||
owns_reporter = reporter is None
|
||||
reporter = reporter or ProgressReporter(model_total)
|
||||
|
||||
workspace_dir = network_onnx_path.parent
|
||||
clean_workspace_artifacts(workspace_dir, network_onnx_path.stem)
|
||||
raptor_dir = workspace_dir / "raptor"
|
||||
runner_dir = workspace_dir / "runner"
|
||||
runner_build_dir = runner_dir / "build"
|
||||
Path.mkdir(raptor_dir, exist_ok=True)
|
||||
Path.mkdir(runner_build_dir, parents=True, exist_ok=True)
|
||||
|
||||
print(Style.BRIGHT + "\nCompiling the onnx network:" + Style.RESET_ALL)
|
||||
network_so_path, network_mlir_path = compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir)
|
||||
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
||||
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
||||
failed_with_exception = False
|
||||
|
||||
print(Style.BRIGHT + "\nGenerating and building the runner:" + Style.RESET_ALL)
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c")
|
||||
runner_path = build_onnx_runner(runner_dir, runner_build_dir)
|
||||
try:
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||
network_so_path, network_mlir_path = compile_onnx_network(
|
||||
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter)
|
||||
print_info(reporter, f"MLIR saved to {network_mlir_path}")
|
||||
print_info(reporter, f"Shared library saved to {network_so_path}")
|
||||
reporter.advance()
|
||||
|
||||
print(Style.BRIGHT + "\nGenerating random inputs:" + Style.RESET_ALL)
|
||||
inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path)
|
||||
inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor)
|
||||
flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs")
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
|
||||
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
|
||||
print_info(reporter, f"Runner built at {runner_path}")
|
||||
reporter.advance()
|
||||
|
||||
print(Style.BRIGHT + "\nRunning inference with the runner:" + Style.RESET_ALL)
|
||||
out_dir = workspace_dir / "outputs"
|
||||
Path.mkdir(out_dir, exist_ok=True)
|
||||
run_cmd = [runner_path, *flags]
|
||||
run_cmd += ["--save-csv-dir", f"{out_dir}"]
|
||||
subprocess.run(run_cmd, cwd=runner_build_dir, check=True)
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Generate Inputs")
|
||||
inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path)
|
||||
inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor)
|
||||
flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs")
|
||||
print_info(reporter, f"Saved {len(inputs_list)} input file(s) to {workspace_dir / 'inputs'}")
|
||||
reporter.advance()
|
||||
|
||||
print(Style.BRIGHT + "\nCompiling for PIM with Raptor:" + Style.RESET_ALL)
|
||||
compile_with_raptor(network_mlir_path, raptor_path, crossbar_size, crossbar_count)
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Reference")
|
||||
out_dir = workspace_dir / "outputs"
|
||||
Path.mkdir(out_dir, exist_ok=True)
|
||||
run_cmd = [runner_path, *flags]
|
||||
run_cmd += ["--save-csv-dir", f"{out_dir}"]
|
||||
run_command(run_cmd, cwd=runner_build_dir, reporter=reporter)
|
||||
print_info(reporter, f"Reference outputs saved to {out_dir}")
|
||||
reporter.advance()
|
||||
|
||||
print(Style.BRIGHT + "\nRunning PIM simulation:" + Style.RESET_ALL)
|
||||
pim_dir = raptor_dir / "pim"
|
||||
write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list)
|
||||
simulation_dir = workspace_dir / "simulation"
|
||||
Path.mkdir(simulation_dir, exist_ok=True)
|
||||
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
|
||||
output_bin_path = simulation_dir / "out.bin"
|
||||
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges)
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
compile_with_raptor(
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||
reporter.advance()
|
||||
|
||||
print(Style.BRIGHT + "\nValidating the results:" + Style.RESET_ALL)
|
||||
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor)
|
||||
return validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Simulator")
|
||||
pim_dir = raptor_dir / "pim"
|
||||
write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list)
|
||||
simulation_dir = workspace_dir / "simulation"
|
||||
Path.mkdir(simulation_dir, exist_ok=True)
|
||||
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
|
||||
output_bin_path = simulation_dir / "out.bin"
|
||||
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter)
|
||||
print_info(reporter, f"Simulator output saved to {output_bin_path}")
|
||||
reporter.advance()
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs")
|
||||
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor)
|
||||
reporter.suspend()
|
||||
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
|
||||
reporter.resume()
|
||||
reporter.advance()
|
||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||
return passed
|
||||
except Exception:
|
||||
failed_with_exception = True
|
||||
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
||||
reporter.suspend()
|
||||
raise
|
||||
finally:
|
||||
if not failed_with_exception:
|
||||
reporter.log("=" * 72)
|
||||
if owns_reporter:
|
||||
reporter.finish()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user