Compare commits

...

11 Commits

Author SHA1 Message Date
NiccoloN
568529ea5f fix batched conv
Some checks failed
Validate Operations / config (push) Successful in 1m3s
Validate Operations / build-mlir-cache (push) Successful in 3m40s
Validate Operations / validate (push) Failing after 2m37s
2026-03-20 22:00:46 +01:00
NiccoloN
ca2e1645bb simple convolutions now work :) 2026-03-20 21:17:02 +01:00
NiccoloN
6933804003 constant fold linalg.map (generated from tensor.pad for padding)
refactor pim helpers in PimCommon
2026-03-20 20:51:20 +01:00
NiccoloN
dbe646ac0d fix gemm segfault
Some checks failed
Validate Operations / config (push) Successful in 1m27s
Validate Operations / build-mlir-cache (push) Successful in 2h0m37s
Validate Operations / validate (push) Failing after 3m41s
print exit signals on validation failure
2026-03-20 14:00:16 +01:00
NiccoloN
bb6dcd38a3 replace deprecated "rewriter.create()" pattern
refactor PIM to Pim everywhere except for the accelerator name
2026-03-20 13:30:53 +01:00
NiccoloN
916a09414c add validation artifacts cleanup 2026-03-20 13:15:08 +01:00
NiccoloN
db3f52a647 conv now lowers correctly down to bufferized pim 2026-03-20 12:55:09 +01:00
NiccoloN
6e1de865bb add constant folding and verification pass for pim host operations
better validation scripts output
big refactors
2026-03-20 12:08:12 +01:00
NiccoloN
4e50e056e3 replace old convolution support in spatial (WIP)
Some checks failed
Validate Operations / config (push) Successful in 1m46s
Validate Operations / build-mlir-cache (push) Successful in 3m25s
Validate Operations / validate (push) Failing after 2m42s
2026-03-13 17:46:10 +01:00
NiccoloN
771b44a2ed fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 53s
Validate Operations / validate (push) Has been cancelled
Validate Operations / build-mlir-cache (push) Has been cancelled
2026-03-11 15:41:26 +01:00
NiccoloN
7ce1d2b34d fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 52s
Validate Operations / build-mlir-cache (push) Successful in 2h0m25s
Validate Operations / validate (push) Failing after 2m26s
2026-03-10 16:12:26 +01:00
79 changed files with 3042 additions and 3051 deletions

View File

@@ -12,29 +12,6 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Free disk space
if: runner.os == 'Linux'
run: |
df -h
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y '^llvm-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y '^mongodb-.*'
sudo apt-get remove -y '^mysql-.*'
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
sudo apt-get autoremove -y
sudo apt-get clean
df -h
sudo rm -rf /usr/local/lib/android || true
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
sudo docker system prune --all --volumes --force
sudo apt-get clean
sudo rm -rf /var/lib/apt/lists/*
df -h
- name: Cache MLIR build - name: Cache MLIR build
id: cache-mlir id: cache-mlir
uses: actions/cache@v4 uses: actions/cache@v4

View File

@@ -29,33 +29,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Free disk space
if: runner.os == 'Linux'
run: |
df -h
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y '^llvm-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y '^mongodb-.*'
sudo apt-get remove -y '^mysql-.*'
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
sudo apt-get autoremove -y
sudo apt-get clean
df -h
sudo rm -rf /usr/local/lib/android || true
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
sudo docker system prune --all --volumes --force
sudo apt-get clean
sudo rm -rf /var/lib/apt/lists/*
df -h
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
submodules: recursive submodules: recursive
github-server-url: https://chef.heaplab.deib.polimi.it/git
- name: Install system dependencies - name: Install system dependencies
run: | run: |

2
.gitignore vendored
View File

@@ -1,2 +1,4 @@
.idea .idea
.claude
AGENTS.md
build build

View File

@@ -20,6 +20,8 @@ add_onnx_mlir_library(OMPIMAccel
Pass/CountInstructionPass.cpp Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp Pass/MessagePass.cpp
Pass/PimConstantFoldingPass.cpp
Pass/PimHostVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
@@ -41,6 +43,7 @@ add_onnx_mlir_library(OMPIMAccel
PimOps PimOps
OMONNXToSpatial OMONNXToSpatial
OMSpatialToGraphviz OMSpatialToGraphviz
OMSpatialToPIM OMSpatialToPim
OMPIMCommon OMPimCommon
MLIRTensorInferTypeOpInterfaceImpl
) )

View File

@@ -1,5 +1,5 @@
add_onnx_mlir_library(OMPIMCommon add_onnx_mlir_library(OMPimCommon
PIMCommon.cpp PimCommon.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -1,93 +0,0 @@
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
  /// Directory portion of the -o output base name.
  /// Returns "" when compiling to stdout (empty or "-" base name) and "."
  /// for a bare file name. The previous one-liner called
  /// substr(0, find_last_of('/')): with no '/' present, find_last_of yields
  /// npos and substr(0, npos) returned the WHOLE base name as if it were a
  /// directory — callers then created a directory named after the output file.
  if (outputBaseName.empty() || outputBaseName == "-")
    return {};
  size_t lastSlash = outputBaseName.find_last_of('/');
  if (lastSlash == std::string::npos)
    return ".";
  return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
  // Recursively create `directory`; an already-existing directory is not an
  // error (create_directories is idempotent).
  std::error_code ec;
  std::filesystem::create_directories(directory, ec);
  // NOTE(review): this assert disappears in NDEBUG builds, so a failure is
  // silently ignored in release; the subsequent file open will fail instead.
  assert(!ec && ("Failed to create directory: " + ec.message()).data());
  (void)ec; // silence unused-variable warnings when asserts are compiled out
}
// Serialize `moduleOp` to <outputDir>/dialects/<name>.mlir for debugging the
// lowering pipeline. The dialects directory is created on demand.
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string dialectsDir = getOutputDir() + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
// Bridge the std::fstream into an LLVM stream so the module printer can use it.
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  /// Given one endpoint of a spatial channel (`op`), return the op at the
  /// other end. The channel value (operand 0) must be produced by a
  /// SpatChannelNewOp with exactly two users: `op` itself and its peer —
  /// a SpatChannelSendOp when `opIsReceive` is true, a SpatChannelReceiveOp
  /// otherwise. Emits an error and returns failure() on any violation.
  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channelNewOp) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }
  // channelNewOp should have two users: `op` and a
  // `ChannelSendOp`/`ChannelReceiveOp`
  auto channelUsers = channelNewOp->getUsers();
  auto usersIterator = channelUsers.begin();
  auto firstUser = *usersIterator;
  usersIterator++;
  if (usersIterator == channelUsers.end()) {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "only one found.");
    channelNewOp->dump();
    op->dump();
    channelNewOp->getParentOp()->dump();
    return failure();
  }
  auto secondUser = *usersIterator;
  usersIterator++;
  if (usersIterator != channelUsers.end()) {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "more than two found.");
    return failure();
  }
  // Pick whichever user is not `op` itself.
  Operation* notOpUser;
  if (firstUser == op) {
    notOpUser = secondUser;
  }
  else if (secondUser == op) {
    notOpUser = firstUser;
  }
  else {
    // Message fix: the previous literal concatenation produced "butnone"
    // (missing space between "but" and "none").
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "and one of them must be me, but "
                  "none of them is actually me.");
    return failure();
  }
  // The peer must be the complementary endpoint kind.
  if (opIsReceive) {
    if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
      op->emitError("Operand generated by ChannelNewOp has two users, one is "
                    "me, the other is not a ChannelSendOp.");
      return failure();
    }
    return notOpUser;
  }
  if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, one is "
                  "me, the other is not a ChannelReceiveOp.");
    return failure();
  }
  return notOpUser;
}
} // namespace onnx_mlir

View File

@@ -1,24 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/StringRef.h"
#include "src/Compiler/CompilerOptions.hpp"
// Attribute marking constants that still require an explicit host allocation.
// `inline constexpr` gives one program-wide constant instead of the previous
// header-scope `const`, which materialized a separate copy in every TU.
inline constexpr llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
namespace onnx_mlir {
/// Directory portion of the -o output base name.
std::string getOutputDir();
/// Recursively create `directory`; asserts on failure in debug builds.
void createDirectory(const std::string& directory);
/// Serialize `moduleOp` to <outputDir>/dialects/<name>.mlir for debugging.
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
/// Return the op at the other end of the spatial channel used by `op`
/// (the ChannelNewOp feeding operand 0 must have exactly two users), or
/// failure on a malformed channel.
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
} // namespace onnx_mlir

View File

@@ -0,0 +1,239 @@
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
  /// Directory portion of the -o output base name: "" when compiling to
  /// stdout (empty or "-" base name), "." for a bare file name, otherwise
  /// everything up to (excluding) the last '/'.
  if (outputBaseName.empty() || outputBaseName == "-")
    return {};
  const size_t slashPos = outputBaseName.rfind('/');
  return slashPos == std::string::npos ? std::string(".")
                                       : outputBaseName.substr(0, slashPos);
}
void createDirectory(const std::string& directory) {
  // Recursively create `directory`; an already-existing directory is not an
  // error (create_directories is idempotent).
  std::error_code ec;
  std::filesystem::create_directories(directory, ec);
  // NOTE(review): this assert disappears in NDEBUG builds, so a failure is
  // silently ignored in release; the subsequent file open will fail instead.
  assert(!ec && ("Failed to create directory: " + ec.message()).data());
  (void)ec; // silence unused-variable warnings when asserts are compiled out
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
  /// Serialize `moduleOp` to <outputDir>/dialects/<name>.mlir; silently does
  /// nothing when no output directory can be derived (stdout compilation).
  const std::string baseDir = getOutputDir();
  if (baseDir.empty())
    return;
  const std::string dialectsDir = baseDir + "/dialects";
  createDirectory(dialectsDir);
  std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
  // Bridge the std::fstream into an LLVM stream for the module printer.
  llvm::raw_os_ostream os(file);
  os << *moduleOp;
  os.flush();
  file.close();
}
// Resolve the unique entry function the PIM pipeline should compile.
// Resolution order:
//   1. A single onnx.EntryPoint op, followed through its function symbol.
//   2. A function literally named "main_graph".
//   3. The only non-external func in the module.
// Emits an error and returns failure() when none of these yields a unique func.
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
// More than one entry point is ambiguous — reject outright.
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
// The entry point references its function via a symbol attribute.
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
// Fallback 2: onnx-mlir's conventional entry function name.
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
// Fallback 3: accept a module with exactly one function that has a body.
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) {
  // Null-safe query for the presence of the "weightAlways" marker attribute.
  return op != nullptr && op->hasAttr(PimWeightAlwaysAttrName);
}
void markWeightAlways(Operation* op) {
  // The marker is a UnitAttr: its mere presence is the whole payload.
  assert(op && "expected valid op");
  MLIRContext* context = op->getContext();
  op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(context));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
  // Resolve the memref.global symbol referenced by a memref.get_global.
  // Tolerates null inputs by returning a null op.
  if (moduleOp && getGlobalOp)
    return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
  return {};
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  /// Given one endpoint of a spatial channel (`op`), return the op at the
  /// other end. The channel value (operand 0) must be produced by a
  /// SpatChannelNewOp with exactly two users: `op` itself and its peer —
  /// a SpatChannelSendOp when `opIsReceive` is true, a SpatChannelReceiveOp
  /// otherwise. Emits an error and returns failure() on any violation.
  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channelNewOp) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }
  // channelNewOp should have two users: `op` and a
  // `ChannelSendOp`/`ChannelReceiveOp`
  auto channelUsers = channelNewOp->getUsers();
  auto usersIterator = channelUsers.begin();
  auto firstUser = *usersIterator;
  usersIterator++;
  if (usersIterator == channelUsers.end()) {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "only one found.");
    channelNewOp->dump();
    op->dump();
    channelNewOp->getParentOp()->dump();
    return failure();
  }
  auto secondUser = *usersIterator;
  usersIterator++;
  if (usersIterator != channelUsers.end()) {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "more than two found.");
    return failure();
  }
  // Pick whichever user is not `op` itself.
  Operation* notOpUser;
  if (firstUser == op) {
    notOpUser = secondUser;
  }
  else if (secondUser == op) {
    notOpUser = firstUser;
  }
  else {
    // Message fix: the previous literal concatenation produced "butnone"
    // (missing space between "but" and "none").
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "and one of them must be me, but "
                  "none of them is actually me.");
    return failure();
  }
  // The peer must be the complementary endpoint kind.
  if (opIsReceive) {
    if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
      op->emitError("Operand generated by ChannelNewOp has two users, one is "
                    "me, the other is not a ChannelSendOp.");
      return failure();
    }
    return notOpUser;
  }
  if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, one is "
                  "me, the other is not a ChannelReceiveOp.");
    return failure();
  }
  return notOpUser;
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  // Row-major layout: the innermost dimension has stride 1 and each outer
  // stride is the running product of all inner dimension extents.
  SmallVector<int64_t> strides(shape.size(), 1);
  int64_t runningProduct = 1;
  for (size_t i = shape.size(); i-- > 1;) {
    runningProduct *= shape[i];
    strides[i - 1] = runningProduct;
  }
  return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
  // Peel one coordinate per dimension, largest stride first; the remainder
  // carries the not-yet-decoded lower dimensions.
  SmallVector<int64_t> indices(shape.size(), 0);
  int64_t remainder = linearIndex;
  for (size_t dim = 0; dim < strides.size(); ++dim) {
    indices[dim] = remainder / strides[dim];
    remainder = remainder % strides[dim];
  }
  return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
  // Dot product of coordinates with their strides. The assert mirrors the
  // size check zip_equal performed in the previous implementation.
  assert(indices.size() == strides.size() && "indices/strides rank mismatch");
  int64_t result = 0;
  for (size_t i = 0, e = indices.size(); i < e; ++i)
    result += indices[i] * strides[i];
  return result;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
  // Product of all extents; an empty (rank-0) shape yields 1, matching the
  // scalar-element convention.
  int64_t count = 1;
  for (size_t i = 0, e = shape.size(); i < e; ++i)
    count *= shape[i];
  return count;
}
// Decide whether the sub-block selected by (offsets, sizes, strides) from a
// row-major buffer of shape `srcShape` occupies one contiguous memory range.
// All dimension triples are walked innermost-first (reversed ranges below).
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
// Any non-unit stride immediately breaks contiguity.
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
// (offset, size, extent) triples, innermost dimension first.
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
// Locate the innermost dimension with a non-zero offset, if any.
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
// The offset slice must still fit inside that dimension's extent.
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
// Every dimension outside the offset one must be degenerate (size 1);
// otherwise the selection splits into multiple strided pieces.
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
// Second scan: after the innermost dimension that is only partially covered
// (size != extent), all outer dimensions must have size 1 for the selected
// elements to remain one contiguous run.
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,51 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Compiler/CompilerOptions.hpp"
// Marks constants that still require an explicit host-side allocation.
// Consistency fix: made `inline constexpr` like PimWeightAlwaysAttrName below —
// the previous header-scope `const` materialized a separate copy in every TU.
inline constexpr llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate";
// Marks memref.get_global ops whose weights are always resident on the device.
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
/// Directory portion of the -o output base name ("" for stdout, "." for a bare name).
std::string getOutputDir();
/// Recursively create `directory`; asserts on failure in debug builds.
void createDirectory(const std::string& directory);
/// Serialize `moduleOp` to <outputDir>/dialects/<name>.mlir (no-op without an output dir).
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
/// Resolve the unique entry function: the ONNX entry point if present, else
/// "main_graph", else the single non-external func; failure otherwise.
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
/// True when `op` carries the "weightAlways" marker attribute (null-safe).
bool hasWeightAlways(mlir::Operation* op);
/// Attach the "weightAlways" unit attribute to `op`.
void markWeightAlways(mlir::Operation* op);
/// Look up the memref.global referenced by `getGlobalOp` (null-tolerant).
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Return the op at the other end of the spatial channel used by `op`.
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
/// Row-major strides for `shape` (innermost stride 1).
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
/// Split a flat index into per-dimension coordinates using `strides`.
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
/// Flatten per-dimension coordinates back into a linear index.
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
/// Product of all extents in `shape` (1 for an empty shape).
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
/// Whether the subview described by offsets/sizes/strides of a row-major
/// buffer with `srcShape` covers a single contiguous memory range.
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
                        llvm::ArrayRef<int64_t> offsets,
                        llvm::ArrayRef<int64_t> sizes,
                        llvm::ArrayRef<int64_t> strides);
} // namespace onnx_mlir

View File

@@ -13,11 +13,11 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
@@ -49,8 +49,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others // Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants; SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!getGlobalOp->hasAttr("weightAlways")) { if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
auto iter = globalConstants.find(globalMemrefOp); auto iter = globalConstants.find(globalMemrefOp);
if (iter == globalConstants.end()) if (iter == globalConstants.end())
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp); globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
@@ -81,7 +81,7 @@ MemEntry PimMemory::getMemEntry(mlir::Value value) const {
return iter->second; return iter->second;
} }
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second; return deviceMem.try_emplace(id, memEntriesMap).first->second;
} }
@@ -112,10 +112,33 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
} }
value = source; value = source;
} }
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
}
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
}
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
}
else else
break; break;
} }
return memEntriesMap.at(value).address + offset;
auto iter = memEntriesMap.find(value);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
value.print(errs());
errs() << "\n";
if (auto* definingOp = value.getDefiningOp()) {
errs() << "Defining op:\n";
definingOp->print(errs());
errs() << "\n";
}
llvm_unreachable("Missing mem entry");
}
return iter->second.address + offset;
} }
json::Object PimCodeGen::createEmptyOffset() { json::Object PimCodeGen::createEmptyOffset() {
@@ -348,6 +371,52 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
} }
} }
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
auto srcAddr = memory.getValueAddress(transposeOp.getData());
auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
auto srcType = cast<ShapedType>(transposeOp.getData().getType());
auto srcShape = srcType.getShape();
size_t rank = srcShape.size();
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
size_t totalElements = srcType.getNumElements();
// Read permutation. Destination dim i corresponds to source dim perm[i].
SmallVector<int64_t> perm =
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
// Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank);
for (size_t i = 0; i < rank; i++)
dstShape[i] = srcShape[perm[i]];
// Row-major strides for source and destination
SmallVector<size_t> srcStrides(rank, 1);
SmallVector<size_t> dstStrides(rank, 1);
for (int64_t i = rank - 2; i >= 0; i--) {
srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
}
// Emit element-by-element copy with transposed addressing
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
// Decompose flat source index into multi-dimensional index
SmallVector<size_t> srcIdx(rank);
size_t remaining = srcFlat;
for (size_t d = 0; d < rank; d++) {
srcIdx[d] = remaining / srcStrides[d];
remaining %= srcStrides[d];
}
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
size_t dstFlat = 0;
for (size_t d = 0; d < rank; d++)
dstFlat += srcIdx[perm[d]] * dstStrides[d];
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
}
}
size_t getMatrixSize(ShapedType matrixShape) { size_t getMatrixSize(ShapedType matrixShape) {
if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4) if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4)
assert(false && "Unsupported matrix shape"); assert(false && "Unsupported matrix shape");
@@ -378,9 +447,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (getGlobalOp->hasAttr("weightAlways")) if (hasWeightAlways(getGlobalOp))
return; return;
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) if (!globalOp)
return; return;
auto initialValue = globalOp.getInitialValue(); auto initialValue = globalOp.getInitialValue();
@@ -416,7 +485,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
size_t processedOperations = 0; size_t processedOperations = 0;
for (auto& op : coreOp.getBody().front()) { for (auto& op : coreOp.getBody().front()) {
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp>(op)) if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
continue; continue;
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
@@ -435,6 +504,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false); coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op)) else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op)) else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
coreCodeGen.codeGenVAddOp(vaddOp); coreCodeGen.codeGenVAddOp(vaddOp);
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op)) else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
@@ -475,7 +546,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
continue; continue;
} }
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr()); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) { if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex)); coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++; weightIndex++;
@@ -589,9 +660,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
} }
} }
auto funcOps = moduleOp.getOps<func::FuncOp>(); auto entryFunc = getPimEntryFunc(moduleOp);
assert(!funcOps.empty() && "No function found in the module"); if (failed(entryFunc))
auto funcOp = *funcOps.begin(); return CompilerFailure;
auto funcOp = *entryFunc;
PimAcceleratorMemory memory; PimAcceleratorMemory memory;
memory.hostMem.allocateHost(moduleOp, funcOp); memory.hostMem.allocateHost(moduleOp, funcOp);

View File

@@ -5,7 +5,7 @@
#include "Common/ValueMap.hpp" #include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -49,7 +49,7 @@ public:
PimAcceleratorMemory() PimAcceleratorMemory()
: hostMem(memEntriesMap) {} : hostMem(memEntriesMap) {}
PimMemory getOrCreateDeviceMem(size_t id); PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value) const; size_t getValueAddress(mlir::Value value) const;
}; };
@@ -95,6 +95,7 @@ public:
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const; void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
}; };
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);

View File

@@ -25,7 +25,6 @@ extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen; extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl; extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> exportCrossbarWeights;
extern llvm::cl::opt<size_t> crossbarSize; extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore; extern llvm::cl::opt<size_t> crossbarCountInCore;

View File

@@ -2,7 +2,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
@@ -34,7 +34,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPim) { if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPIMPass()); pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim")); pm.addPass(createMessagePass("Spatial lowered to Pim"));
} }
@@ -46,6 +46,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimHostVerificationPass());
pm.addPass(createMessagePass("Pim host verified"));
pm.addPass(createEmitPimJsonPass()); pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted")); pm.addPass(createMessagePass("Pim json code emitted"));

View File

@@ -1,3 +1,3 @@
add_subdirectory(ONNXToSpatial) add_subdirectory(ONNXToSpatial)
add_subdirectory(SpatialToGraphviz) add_subdirectory(SpatialToGraphviz)
add_subdirectory(SpatialToPIM) add_subdirectory(SpatialToPim)

View File

@@ -5,17 +5,13 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial add_onnx_mlir_library(OMONNXToSpatial
Math/Gemm.cpp Math/Gemm.cpp
Math/Conv.cpp Math/Conv.cpp
Math/ExperimentalConv.cpp
Math/ExperimentalGemm.cpp
NN/Pooling.cpp NN/Pooling.cpp
NN/ExperimentalPooling.cpp
NN/ReduceMean.cpp NN/ReduceMean.cpp
Tensor/ONNXConcatToTensorConcat.cpp Tensor/ONNXConcatToTensorConcat.cpp
Tensor/RemoveUnusedHelperOps.cpp Tensor/RemoveUnusedHelperOps.cpp
Utils/SpatialReducer.cpp Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp Utils/AnnotateReplication.cpp
ONNXToSpatialPass.hpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
ONNXToSpatialCommon.cpp ONNXToSpatialCommon.cpp
@@ -27,7 +23,7 @@ add_onnx_mlir_library(OMONNXToSpatial
OMPimCompilerOptions OMPimCompilerOptions
OMONNXOps OMONNXOps
SpatialOps SpatialOps
OMPIMCommon OMPimCommon
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH} ${PIM_INCLUDE_PATH}

View File

@@ -1,583 +1,273 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <cstddef> #include <cassert>
#include <memory>
#include <unordered_map>
#include <vector>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
using namespace std;
namespace onnx_mlir { namespace onnx_mlir {
namespace {
// NOTE: struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
// This might be useful to re-implement this considering for loops. using OpConversionPattern::OpConversionPattern;
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
/** LogicalResult matchAndRewrite(ONNXConvOp convOp,
* @brief A momentary representation of a core, to be used within the tiling of ONNXConvOpAdaptor convOpAdaptor,
* a convolution operation. ConversionPatternRewriter& rewriter) const override;
*/
class Core {
public:
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
: coreId(coreId), rewriter(rewriter) {}
/**
* @brief Add a MVM operation to the core.
*
* @param inputTile The input tile to the MVM operation.
* @param xbarIndex The index of the crossbar weight to use.
* @param outputTileId The id of the output tile.
* @param mvmOutType The result's shape.
* @return Value The result of the MVM operation.
*/
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
// Use the inputTile as the reference location for the MVM operation.
Location loc = inputTile.getLoc();
// Move the insertion point to the end of the block.
rewriter.setInsertionPointToEnd(block.get());
// Add the inputTile to the block arguments, and to the operands.
Value operand = operandMap.lookupOrNull(inputTile);
if (not operand) {
operand = block->addArgument(inputTile.getType(), loc);
operands.push_back(inputTile);
operandMap.map(inputTile, operand);
}
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
// is correct.
// Construct the MVM operation
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
// Since we are within the same core and no computation can happen in
// paralllel, we can just apply a linear reduction in case we have multiple
// MVM operations for the same outputTile.
auto lastMVM = outputTileToMVM.find(outputTileId);
// If an entry for this outputTile already exists, apply reduction.
if (lastMVM != outputTileToMVM.end()) {
// MVM results should have the same type for reduction.
assert(lastMVM->second.getType() == result.getType());
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
}
outputTileToMVM[outputTileId] = result;
return result;
}
/**
* @brief Mark a result as remappable, and return a shared pointer to it.
*
* This function marks a result as remappable, and returns a shared pointer to
* it. We need to keep track of these values to generate the YieldOp at a
* later stage.
*
* @param result A result to track, for later remapping.
* @return shared_ptr<Value> A shared pointer to the result.
*/
shared_ptr<Value> makeResultRemappable(Value result) {
// Verify that the result is present in the block.
assert(result.getDefiningOp()->getBlock() == block.get());
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
resultsToRemap.push_back(remappableResult);
results.push_back(result);
return remappableResult;
}
/**
* @brief Add a remappable operand to the core, to merge partial results
* inter-core.
*
* @param remappableOperand The operand to add.
* @return Value The block argument representing the operand.
*/
Value addRemappableOperand(std::shared_ptr<Value> operand) {
// Check that the operand is not already there.
assert(not operandMap.contains(*operand));
Value argument = block->addArgument(operand->getType(), operand->getLoc());
remappableOperands.push_back(operand);
return argument;
}
/**
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
*
* @param loc The location of the operation.
* @return spatial::SpatWeightedCompute
*/
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
// Get the shape of the results.
SmallVector<Type> resultTypes;
for (const auto& value : results)
resultTypes.push_back(value.getType());
// Create the WComputeOp, with non-remappable operands only.
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
// Add the body to the WComputeOp.
Block* releasedBlock = block.release();
wcomputeOp.getBody().push_back(releasedBlock);
// Add the `yieldOp` at the end, with the results.
rewriter.setInsertionPointToEnd(releasedBlock);
rewriter.create<spatial::SpatYieldOp>(loc, results);
return wcomputeOp;
}
/**
* @brief Remap the results to the WComputeOp results.
*/
void remapResults() {
// Remap all the results to the WComputeOp results.
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
for (size_t i = 0; i < resultsToRemap.size(); i++)
*resultsToRemap[i] = wcomputeOp.getResult(i);
}
void addRemappedOperands() {
// Insert the remappableOperands (which were remapped in
// `addRemappableOperand` of another Core)
for (auto remappedValue : remappableOperands)
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
// Update the wcomputeOp operandSegmentSize
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
}
size_t addXbarWeight(Value weight) {
assert(!isXbarsFull());
xbarWeights.push_back(weight);
return xbarWeights.size() - 1;
}
bool isXbarsFull() {
assert(xbarWeights.size() <= crossbarCountInCore);
return xbarWeights.size() == crossbarCountInCore;
}
bool isCoreEmpty() { return block->empty(); }
void dump() {
// Print the coreId
llvm::outs() << "Core " << coreId << ":\n";
// Print the weights
llvm::outs() << "Xbar Weights:\n";
for (auto weight : xbarWeights)
weight.dump();
// Print the operands
llvm::outs() << "Operands:\n";
for (auto operand : operands)
llvm::outs() << operand << "\n";
// Dump the body block
for (auto& op : block->getOperations())
op.dump();
// Print the results
llvm::outs() << "Results:\n";
for (auto result : results)
llvm::outs() << result << "\n";
}
const size_t coreId;
private:
ConversionPatternRewriter& rewriter;
// Should these be set<Value> instead? But I need to keep the order
vector<Value> operands;
vector<std::shared_ptr<Value>> remappableOperands;
vector<Value> results;
vector<std::shared_ptr<Value>> resultsToRemap;
// Maps from input tiles to the block operand
IRMapping operandMap;
// Map from outputTileId to MVM operation producing it
unordered_map<size_t, Value> outputTileToMVM;
vector<Value> xbarWeights;
unique_ptr<mlir::Block> block = make_unique<Block>();
spatial::SpatWeightedCompute wcomputeOp;
}; };
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> { } // namespace
ONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
struct Producer_t { LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
Value value; ONNXConvOpAdaptor convOpAdaptor,
shared_ptr<Core> core; ConversionPatternRewriter& rewriter) const {
}; Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
LogicalResult auto xType = cast<RankedTensorType>(x.getType());
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final { auto wType = cast<RankedTensorType>(w.getType());
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType()); auto outType = cast<RankedTensorType>(convOp.getY().getType());
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());
size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y; assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y); assert("Only support 2D convolution" && xType.getRank() == 4);
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); // We need to understand what is group
if (padUnpackError.has_value()) assert("Only support group=1" && convOp.getGroup() == 1);
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
// TODO: Pad value at beginning and end of each dimension could be const int64_t batchSize = xType.getDimSize(0);
// different. We should handle this case. const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
// MapOperations mapOperation = MapOperations::None; // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
// auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
// // If we have just one user, and it is an activation funcion (or more in
// // general a mapping operation) just inline it in the computeOps
// auto firstUserOp = *conv->getUsers().begin();
// if (conv->hasOneUse()) {
// mapOperation = mlirOpToMapOperationEnum(firstUserOp);
//
// if (mapOperation == MapOperations::ONNXSoftmaxOp) {
// return rewriter.notifyMatchFailure(
// conv, "Softmax not supported as activation for convolutions.");
// }
// }
size_t input_h = GET_IMAGE_HEIGHT(xShape); const auto stridesAttr = convOp.getStrides();
size_t input_w = GET_IMAGE_WIDTH(xShape); const auto dilationsAttr = convOp.getDilations();
size_t output_h = GET_IMAGE_HEIGHT(yShape); const auto padsAttr = convOp.getPads();
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape);
Location loc = conv.getLoc(); const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); int64_t padHeightBegin = 0;
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize; int64_t padHeightEnd = 0;
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue()); int64_t padWidthBegin = 0;
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize; int64_t padWidthEnd = 0;
// Tile the input tensor if (padsAttr) {
// Input tiles need to be indexed by: padHeightBegin = getI64(*padsAttr, 0);
// a. Channel Tile padWidthBegin = getI64(*padsAttr, 1);
// b. Pixel `x` position padHeightEnd = getI64(*padsAttr, 2);
// c. Pixel `y` position padWidthEnd = getI64(*padsAttr, 3);
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
// Tile the weight tensor
// Weight tiles need to be indexed by:
// a. Filter Tile
// b. Channel Tile
// c. Kernel `x` position
// d. Kernel `y` position
// For example: weightTiles[filterTile][channelTile][x][y]
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
outputTileCount,
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
sizes = {rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
for (size_t i = 0; i < outputTileCount; i++) {
if (i == outputTileCount - 1 && outputTileRemainder != 0)
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
sizes[1] = rewriter.getIndexAttr(crossbarSize);
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
for (size_t j = 0; j < inputTileCount; j++) {
if (j == inputTileCount - 1 && inputTileRemainder != 0)
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
for (size_t x = 0; x < krn_w; x++) {
for (size_t y = 0; y < krn_h; y++) {
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
weightTiles[i][j][x][y] =
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
}
}
}
}
/* Distribute the computation among many compute cores
* Try to compute in-core the computation for each output tile, and reduce
* over as few cores as possible
*/
// Tile the output tensor
// Output tiles need to be indexed by:
// a. Filter Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[filterTile][x][y]
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
outputTileCount,
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
size_t replicationFactor;
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
replicationFactor = 1;
else
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
// producers[outTile][out_x][out_y][producerIndex]
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
outputTileCount,
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
// Schedule in cores
size_t coreId = 0;
vector<shared_ptr<Core>> curCores(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++)
curCores[i] = make_shared<Core>(coreId++, rewriter);
vector<shared_ptr<Core>> cores;
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
RankedTensorType mvmOutType =
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
vector<size_t> xbarIndexes(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++)
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
size_t out_x = 0;
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
size_t out_y = 0;
// I use `replicationFactor` cores. I divide the input_w into
// `replicationFactor` slices, and each slice is distributed to a
// core. `coreIndex` is the index of the core that will be used
// for this slice
size_t coreIndex = in_x / replicationSliceSize;
assert(coreIndex < replicationFactor);
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
// Adjust the input based on the kernel
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
// Check if we are within the input image
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
out_y++;
continue;
}
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
auto mvm = curCores[coreIndex]->addMVM(
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
out_y++;
}
out_x++;
}
// Computations for these crossbars are done, check if the cores
// crossbars are fully used. If full, swap with new core
for (size_t i = 0; i < replicationFactor; i++) {
if (curCores[i]->isXbarsFull()) {
cores.emplace_back(std::move(curCores[i]));
curCores[i] = make_shared<Core>(coreId++, rewriter);
}
}
}
}
}
}
for (auto& curCore : curCores)
if (curCore->isCoreEmpty() == false)
cores.emplace_back(std::move(curCore));
curCores.clear();
// Now, do the reduction of each output pixel tile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
for (size_t out_x = 0; out_x < output_w; out_x++) {
for (size_t out_y = 0; out_y < output_h; out_y++) {
// First, check if some producers are within the same core. If this is
// true, `Core::addMVM` have already done the reduction within-core.
// This means that we only need to consider the last producer for that
// core.
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
for (auto producer : producers[outTile][out_x][out_y])
withinCoreReducedProducers[producer.core->coreId] = producer;
// Now, we need to apply inter-core reduction
// Base case with one producer
if (withinCoreReducedProducers.size() == 1) {
// TODO: Add the bias and apply mapping (if present)
auto singleProducer = withinCoreReducedProducers.begin()->second;
// Use last producer as the final result
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
continue;
}
// TODO: This is a linear reduction, not a tree reduction. We can do
// better: a tree reduction would make more computations happen in
// parallel.
Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
auto it = withinCoreReducedProducers.begin();
it++;
while (it != withinCoreReducedProducers.end()) {
Producer_t curProducer = it->second;
shared_ptr<Core> core1;
shared_ptr<Core> core2;
Value core1Value;
Value core2Value;
auto lastProducerCoreId = lastProducer.core->coreId;
auto curProducerCoreId = curProducer.core->coreId;
assert(lastProducerCoreId != curProducerCoreId
&& "We should have already applied within-core reduction, how "
"could we have same cores here?");
// Sort the cores by coreId
if (curProducerCoreId < lastProducerCoreId) {
core1 = curProducer.core;
core1Value = curProducer.value;
core2 = lastProducer.core;
core2Value = lastProducer.value;
} }
else { else {
core1 = lastProducer.core; // Compute padding from auto_pad attribute
core1Value = lastProducer.value; const auto autoPad = convOp.getAutoPad();
core2 = curProducer.core; if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
core2Value = curProducer.value; const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
// "NOTSET" or "VALID" -> all pads stay 0
} }
auto newCoreRes = core1->makeResultRemappable(core1Value); // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes); // A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
rewriter.setInsertionPointAfterValue(core2Value); auto elemType = xType.getElementType();
Value vaddRes = rewriter.create<spatial::SpatVAddOp>( auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg); auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
lastProducer = {vaddRes, core2}; // Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
it++; // Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC;
if (hasB) {
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
gemmC = tensor::ExpandShapeOp::create(rewriter,
loc,
biasType,
b,
SmallVector<ReassociationIndices> {
{0, 1}
});
}
else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
auto im2colComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});
auto* im2colBlock = new Block();
im2colBlock->addArgument(x.getType(), loc);
im2colComputeOp.getBody().push_back(im2colBlock);
rewriter.setInsertionPointToStart(im2colBlock);
Value paddedInput = im2colBlock->getArgument(0);
// Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightBegin),
rewriter.getIndexAttr(padWidthBegin)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightEnd),
rewriter.getIndexAttr(padWidthEnd)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 4; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult();
} }
// TODO: Add the bias and apply mapping (if present) // Build im2col [numPatches, patchSize]:
// For each batch/output position (n, oh, ow), extract the patch from x
SmallVector<Value> im2colRows;
im2colRows.reserve(numPatches);
for (int64_t n = 0; n < batchSize; n++) {
for (int64_t oh = 0; oh < outHeight; oh++) {
for (int64_t ow = 0; ow < outWidth; ow++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(oh * strideHeight),
rewriter.getIndexAttr(ow * strideWidth)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
// Use last producer as the final result // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value); Value row = tensor::CollapseShapeOp::create(rewriter,
outputTiles[outTile][out_x][out_y] = reducedValue; loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
im2colRows.push_back(row);
} }
} }
} }
// Now, we need to turn the cores into a spatial::SpatWeightedCompute. // Concatenate all rows: [numPatches, patchSize]
rewriter.setInsertionPointAfter(conv); Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
spatial::SpatWeightedCompute lastWComputeOp; spatial::SpatYieldOp::create(rewriter, loc, im2col);
for (auto& core : cores) {
lastWComputeOp = core->createWComputeOp(loc);
core->remapResults();
rewriter.setInsertionPointAfter(lastWComputeOp);
}
for (auto& core : cores) rewriter.setInsertionPointAfter(im2colComputeOp);
core->addRemappedOperands();
// Set the insertion point after the last WComputeOp. // Gemm: A @ B + C = im2col @ W^T + b
rewriter.setInsertionPointAfter(lastWComputeOp); // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
SmallVector<Value> tilesToConcat; auto gemmOp = ONNXGemmOp::create(rewriter,
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize); loc,
for (size_t outX = 0; outX < output_h; outX++) gemmOutType,
for (size_t outY = 0; outY < output_w; outY++) im2colComputeOp.getResult(0),
for (size_t outTile = 0; outTile < outputTileCount; outTile++) wTrans,
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]); gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
Value gemmOut = gemmOp.getY();
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat); auto collectComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});
// Value outputImage = auto* collectBlock = new Block();
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType()); collectBlock->addArgument(gemmOut.getType(), loc);
collectComputeOp.getBody().push_back(collectBlock);
rewriter.setInsertionPointToStart(collectBlock);
// If no mapping (activation) was applied, just replace ConvOp auto gemmOutArg = collectBlock->getArguments().front();
// if (mapOperation == MapOperations::None) {
// rewriter.replaceOp(conv, outputImage);
// } else {
// // If mapping was applied, erase ConvOp and replace the mapping op
// rewriter.eraseOp(conv);
// rewriter.replaceOp(firstUserOp, outputImage);
// }
// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth]
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOutArg,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
});
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
return success(); return success();
} }
};
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
patterns.insert<ONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,400 +0,0 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cstddef>
#include <unistd.h>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
* @brief A pattern to tile the convolution operation into a series of compute
* units, each one of which applies filters to a subset of the input
* tensor. Results are also reduced and concatenated to form the final
* output tensor.
*/
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ExperimentalONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
// --------------------------------- //
// To get each crossbar's weights, we need to slice the weights tensor.
// - Along the input tiles.
// - Along the output tiles.
// - Along the filter x position.
// - Along the filter y position.
ShapedType inputType = cast<ShapedType>(convAdaptor.getX().getType());
ShapedType outputType = cast<ShapedType>(conv.getY().getType());
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
// TODO: Address bigger batches.
assert(GET_IMAGE_N(inputType) == 1
&& "Batch size must be 1"
"for convolution.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
// TODO: Address bias addition.
ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize);
ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize);
size_t kernelWidth = GET_KERNEL_WIDTH(weightsType);
size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType);
// Assert that the kernel is square.
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
// -------------------------------- //
// --- SLICE THE WEIGHTS TENSOR --- //
// -------------------------------- //
// The core idea of this stage is classifying the weights by input and
// output tile. This is because we want the applyFilters operations to be
// tile agnostic, to keep the subsequent lowering stages as simple as
// possible. This data structure does this weight classification:
// - The outer map is indexed by input tile.
// - The inner map is indexed by output tile.
// - The SmallVector contains the weights for the filter.
map<long, map<long, SmallVector<Value>>> weightsGroups;
// During all slicing operations within this stage, we'll use the same
// strides for all dimensions.
SmallVector<OpFoldResult> slicingStrides(4, rewriter.getIndexAttr(1));
ldiv_t itc = inputTileCount;
ldiv_t otc = outputTileCount;
// - Slicing along the input tiles.
// - Slicing along the output tiles.
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
// The loop above also sets the crossbar's used width and height,
// checking if we're at the last crossbar and if it's incomplete.
long outputTile = ot;
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(filterX),
/* 3 */ rewriter.getIndexAttr(filterY)};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
}
}
}
}
// TODO: Tree reduction for compute reduction should be implemented.
// -------------------------------- //
// --- CREATE ALL COMPUTE UNITS --- //
// -------------------------------- //
// Keep track of input slicing operations to avoid duplication across
// all compute units (global slices).
map<long, Value> globalSlices;
// Keep track of all partial compute results.
map<long, Value> globalPartialResults;
// Use a weight subdivider to extract groups of weights for each compute
// unit. We'll keep extracting groups until no more weights are left.
WeightSubdivider weightSubdivider(weightsGroups);
while (!weightSubdivider.isEmpty()) {
// -------------------------------- //
// --- BEGIN A NEW COMPUTE UNIT --- //
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
// ------------------------------ //
// --- SLICE THE INPUT TENSOR --- //
// ------------------------------ //
// Note each tile's index in the compute unit arguments.
map<long, size_t> inputTileIndices;
map<long, size_t> outputTileIndices;
map<long, size_t> reductionTileIndices; // Incoming partial results.
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights)
computeWeights.push_back(weight);
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end())
continue;
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
// new one.
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
computeOperands.push_back(globalSlices[group.inputTile]);
localSlices[group.inputTile] = globalSlices[group.inputTile];
continue;
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(0),
/* 3 */ rewriter.getIndexAttr(0)};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
// Update slicing maps.
globalSlices[group.inputTile] = extractSliceOp;
localSlices[group.inputTile] = extractSliceOp;
// Update the input tile index.
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
}
// ------------------------------- //
// --- PREPARE THE OUTPUT TYPE --- //
// ------------------------------- //
// Fill the compute output's type by looking at the output tiles.
SmallVector<Type> computeOutputType;
for (TaggedWeights group : weightsGroups) {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
continue;
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
// ----------------------------- //
// --- FILL THE COMPUTE UNIT --- //
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
conv.getLoc(), computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands)
block->addArgument(operand.getType(), conv->getLoc());
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i);
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
for (auto weight : group.weights) {
// Assert that the weight is an extract_slice operation.
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
assert(extractSliceOp && "Weight is not an extract_slice operation.");
// Get the filter x and y positions from the extract_slice operation.
auto offsets = extractSliceOp.getStaticOffsets();
xKerPos.push_back(offsets[2]);
yKerPos.push_back(offsets[3]);
}
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(
conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
localPartialResults[group.outputTile] = result;
}
// Add a yield operation to the block by concatenating the partial
// results.
SmallVector<Value> applyFiltersResults;
for (size_t i = 0; i < computeOutputType.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
// Get that tile's partial result and add it to the list.
applyFiltersResults.push_back(localPartialResults[outputTile]);
}
// Create the yield operation with the given results.
rewriter.create<spatial::SpatYieldOp>(conv.getLoc(), applyFiltersResults);
// Update the global partial results map.
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
globalPartialResults[outputTile] = currentCompute.getResult(i);
}
// Move the rewrite cursor out of the block.
rewriter.setInsertionPointAfter(currentCompute);
}
// ------------------------------ //
// --- CONCATENATE THE OUTPUT --- //
// ------------------------------ //
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
outputValues.push_back(globalPartialResults[i]);
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
// If the conv's user is a ReLU...
if (conv->hasOneUse()) {
Operation* user = *conv->getUsers().begin();
if (auto relu = dyn_cast<ONNXReluOp>(user)) {
// ...then we can just replace the ReLU with the concatenation.
rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
// And erase the convolution.
rewriter.eraseOp(conv);
return success();
}
}
// Return the final output.
rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
return success();
}
};
/**
 * @brief Register the experimental convolution tiling pattern.
 *
 * Adds ExperimentalONNXConvOpTile to the given pattern set so the driver can
 * apply it during the ONNX-to-Spatial lowering.
 *
 * @param patterns Pattern set that receives the tiling pattern.
 * @param ctx MLIR context used to construct the pattern.
 */
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ExperimentalONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,365 +0,0 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdlib>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
 * @brief Conversion pattern that maps an ONNX GEMM onto spatial PIM compute
 * units.
 *
 * The weight matrix B is sliced into crossbar-sized tiles and the GEMM is
 * treated as a convolution with a 1x1 kernel: tiles are grouped into
 * spatial::SpatWeightedCompute units via a WeightSubdivider, partial results
 * for the same output tile are accumulated with spatial::SpatVAddOp (locally
 * in a unit, or across units by feeding previous results back in as
 * operands), and a final tensor::ConcatOp along axis 1 reassembles the
 * output.
 *
 * NOTE(review): depends on externally-declared compiler options
 * (`crossbarSize`, `crossbarCountInCore`, `coresCount`) and on
 * `WeightSubdivider` / `TaggedWeights` from Utils/WeightSubdivider.hpp —
 * confirm the options are initialized before pattern application.
 */
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
  ExperimentalGemmConversionPattern(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  /**
   * @brief Rewrite a GEMM (batch size 1, no replication) into tiled spatial
   * compute units.
   *
   * @param gemmOp The ONNX GEMM operation being converted.
   * @param adaptor Adaptor exposing the already-converted operands (A, B).
   * @param rewriter Conversion rewriter used to build the replacement IR.
   * @return success() always; preconditions are enforced with asserts.
   */
  LogicalResult
  matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
    // --------------------------------- //
    // --- READ OPERATION PARAMETERS --- //
    // --------------------------------- //
    // To get each crossbar's weights, we need to slice the weights tensor.
    // - Along the input tiles.
    // - Along the output tiles.
    // - Along the filter x position.
    // - Along the filter y position.
    ShapedType inputType = cast<ShapedType>(adaptor.getA().getType());
    ShapedType outputType = cast<ShapedType>(gemmOp.getY().getType());
    ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
    // TODO: Address bigger batches.
    assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
    // TODO: Address replication.
    assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
    // TODO: Address bias addition.
    assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
    // Tile counts along the input/output feature axes: quot = number of full
    // crossbar-sized tiles, rem = size of the trailing partial tile (if any).
    ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
    ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
    // A GEMM is handled as a convolution with a 1x1 kernel, so both kernel
    // dimensions are fixed to 1 and the filter loops below run exactly once.
    size_t kernelWidth = 1;
    size_t kernelHeight = 1;
    // Assert that the kernel is square.
    assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
    // -------------------------------- //
    // --- SLICE THE WEIGHTS TENSOR --- //
    // -------------------------------- //
    // The core idea of this stage is classifying the weights by input and
    // output tile. This is because we want the applyFilters operations to be
    // tile agnostic, to keep the subsequent lowering stages as simple as
    // possible. This data structure does this weight classification:
    // - The outer map is indexed by input tile.
    // - The inner map is indexed by output tile.
    // - The SmallVector contains the weights for the filter.
    map<long, map<long, SmallVector<Value>>> weightsGroups;
    // During all slicing operations within this stage, we'll use the same
    // strides for all dimensions.
    SmallVector<OpFoldResult> slicingStrides(2, rewriter.getIndexAttr(1));
    ldiv_t itc = inputTileCount;
    ldiv_t otc = outputTileCount;
    // - Slicing along the input tiles.
    // - Slicing along the output tiles.
    // Each loop covers the full tiles plus one extra partial tile when the
    // corresponding division has a remainder.
    for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
      long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
      for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
        long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
        // The loop above also sets the crossbar's used width and height,
        // checking if we're at the last crossbar and if it's incomplete.
        long outputTile = ot;
        long inputTile = it;
        // Create the slicing sizes. The trailing 4D (spatial) dimensions are
        // intentionally commented out: B is a 2D matrix here.
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
                                                /* 1 */ rewriter.getIndexAttr(crossbarWidth),
                                                /* 2 */ /* rewriter.getIndexAttr(1), */
                                                /* 3 */ /* rewriter.getIndexAttr(1) */};
        // - Slicing along the filter x position.
        // - Slicing along the filter y position.
        for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
          for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
            // Create the slicing offsets.
            SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
                                                      /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
                                                      /* 2 */ /* rewriter.getIndexAttr(filterX), */
                                                      /* 3 */ /* rewriter.getIndexAttr(filterY) */};
            // Create the slice extraction operation.
            auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
                gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
            // Add a note to the extractSliceOp, with the filterX and filterY.
            weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
          }
        }
      }
    }
    // TODO: Tree reduction for compute reduction should be implemented.
    // -------------------------------- //
    // --- CREATE ALL COMPUTE UNITS --- //
    // -------------------------------- //
    // Keep track of input slicing operations to avoid duplication across
    // all compute units (global slices).
    map<long, Value> globalSlices;
    // Keep track of all partial compute results.
    map<long, Value> globalPartialResults;
    // Use a weight subdivider to extract groups of weights for each compute
    // unit. We'll keep extracting groups until no more weights are left.
    WeightSubdivider weightSubdivider(weightsGroups);
    while (!weightSubdivider.isEmpty()) {
      // -------------------------------- //
      // --- BEGIN A NEW COMPUTE UNIT --- //
      // -------------------------------- //
      // Get the next group of weights for the compute unit.
      // NOTE: this local `weightsGroups` deliberately shadows the outer map
      // for the remainder of the loop body.
      SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
      SmallVector<Value> computeWeights;
      SmallVector<Value> computeOperands;
      // ------------------------------ //
      // --- SLICE THE INPUT TENSOR --- //
      // ------------------------------ //
      // Note each tile's index in the compute unit arguments.
      map<long, size_t> inputTileIndices;
      map<long, size_t> outputTileIndices;
      map<long, size_t> reductionTileIndices; // Incoming partial results.
      // Iterate over all weights groups for this compute unit.
      map<long, Value> localSlices; // WRT the current compute unit.
      for (auto group : weightsGroups) {
        for (Value weight : group.weights)
          computeWeights.push_back(weight);
        // There might be multiple weight groups for the same input tile, so if
        // we've already added the input tile, skip it.
        if (localSlices.find(group.inputTile) != localSlices.end())
          continue;
        // We might have already sliced the input tensor for some other compute
        // unit, so if we have, reuse the slicing operation without creating a
        // new one.
        if (globalSlices.find(group.inputTile) != globalSlices.end()) {
          computeOperands.push_back(globalSlices[group.inputTile]);
          localSlices[group.inputTile] = globalSlices[group.inputTile];
          continue;
        }
        // Create the input tensor slicing offsets.
        SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
                                                  /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
                                                  /* 2 */ /* rewriter.getIndexAttr(0), */
                                                  /* 3 */ /* rewriter.getIndexAttr(0) */};
        // Create the input tensor slicing sizes. The last tile may be
        // partial, in which case only `rem` elements are taken.
        size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
                                                /* 1 */ rewriter.getIndexAttr(tilingSize),
                                                /* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
                                                /* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
        // Create the slice extraction operation.
        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
            gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
        computeOperands.push_back(extractSliceOp);
        // Update slicing maps.
        globalSlices[group.inputTile] = extractSliceOp;
        localSlices[group.inputTile] = extractSliceOp;
        // Update the input tile index.
        inputTileIndices[group.inputTile] = computeOperands.size() - 1;
      }
      // ------------------------------- //
      // --- PREPARE THE OUTPUT TYPE --- //
      // ------------------------------- //
      // Fill the compute output's type by looking at the output tiles.
      SmallVector<Type> computeOutputType;
      for (TaggedWeights group : weightsGroups) {
        // There might be multiple weight groups for the same output tile, so if
        // we've already added the output tile, skip it.
        if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
          continue;
        // Additionally, after adding the input slices as operands, also add any
        // compatible partial results from previous compute units.
        if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
          computeOperands.push_back(globalPartialResults[group.outputTile]);
          reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
        }
        // Define the output shape for this group.
        long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
        // TODO: Address non-same padding.
        SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
                                               /* 1 */ outputTileSize,
                                               /* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
                                               /* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
        auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
        computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
        outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
      }
      // ----------------------------- //
      // --- FILL THE COMPUTE UNIT --- //
      // ----------------------------- //
      // Create the compute unit.
      spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
          gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
      // Create a new block for the compute unit and add the operands.
      Block* block = rewriter.createBlock(&currentCompute.getRegion());
      rewriter.setInsertionPointToStart(block);
      for (Value operand : computeOperands)
        block->addArgument(operand.getType(), gemmOp->getLoc());
      // Initialize a map of local partial results.
      map<long, Value> localPartialResults; // WRT the current compute unit.
      // If we have any reduction tiles, add them to the local partial results.
      for (auto reductionTileIndex : reductionTileIndices)
        localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
      // Add all the applyFilters operations to the block.
      for (TaggedWeights group : weightsGroups) {
        // Get the outputType for this group.
        Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
        // Create an apply filters operation.
        BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
        // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
        // ... As many weights as the size of group.weights.
        SmallVector<long> weightIndices;
        for (size_t i = 0; i < group.weights.size(); ++i)
          weightIndices.push_back(group.startingCrossbarIndex + i);
        SmallVector<int64_t> xKerPos;
        SmallVector<int64_t> yKerPos;
        for (auto weight : group.weights) {
          // Assert that the weight is an extract_slice operation.
          auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
          assert(extractSliceOp && "Weight is not an extract_slice operation.");
          // Get the filter x and y positions from the extract_slice operation.
          // For GEMM the kernel is 1x1, so both positions are always 0.
          xKerPos.push_back(0);
          yKerPos.push_back(0);
        }
        ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
        ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
        ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
        Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
            gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
        // Perform local reduction if necessary.
        if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
          result = rewriter.create<spatial::SpatVAddOp>(
              gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
        }
        // Update the partial results map.
        localPartialResults[group.outputTile] = result;
      }
      // Add a yield operation to the block by concatenating the partial
      // results.
      SmallVector<Value> applyFiltersResults;
      for (size_t i = 0; i < computeOutputType.size(); ++i) {
        // NOTE(review): left uninitialized on purpose? This relies on every
        // result index i having a matching entry in outputTileIndices;
        // otherwise an indeterminate value is read below — confirm.
        long outputTile;
        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }
        // Get that tile's partial result and add it to the list.
        applyFiltersResults.push_back(localPartialResults[outputTile]);
      }
      // Create the yield operation with the given results.
      rewriter.create<spatial::SpatYieldOp>(gemmOp.getLoc(), applyFiltersResults);
      // Update the global partial results map.
      for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
        // NOTE(review): same uninitialized reverse-lookup pattern as above —
        // assumes outputTileIndices always contains index i.
        long outputTile;
        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }
        globalPartialResults[outputTile] = currentCompute.getResult(i);
      }
      // Move the rewrite cursor out of the block.
      rewriter.setInsertionPointAfter(currentCompute);
    }
    // ------------------------------ //
    // --- CONCATENATE THE OUTPUT --- //
    // ------------------------------ //
    // Turn the values into a SmallVector.
    SmallVector<Value> outputValues;
    for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
      outputValues.push_back(globalPartialResults[i]);
    // Assert that the number of output values is correct.
    assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
    // Return the final output.
    // Concatenation axis 1 is the output-feature axis of the (1, N) result.
    rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
    return success();
  }
};
/**
 * @brief Register the experimental GEMM-as-convolution conversion pattern.
 *
 * Adds ExperimentalGemmConversionPattern to the given pattern set.
 *
 * @param patterns Pattern set that receives the conversion pattern.
 * @param ctx MLIR context used to construct the pattern.
 */
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ExperimentalGemmConversionPattern>(ctx);
}
} // namespace onnx_mlir

View File

@@ -2,18 +2,15 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cassert> #include <cassert>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -22,21 +19,48 @@
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace {
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> { struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
GemmToManyGemv(MLIRContext* ctx) using OpConversionPattern::OpConversionPattern;
: OpConversionPattern(ctx, 2) {}
LogicalResult LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
private:
static Value resolveONNXExpOpFromUseChain(Value startValue);
static LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
Value& softmaxChannel,
ConversionPatternRewriter& rewriter,
SpatialReducer& reducer,
ONNXGemmOp& gemmOp,
Location& loc);
};
} // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc(); Location loc = gemmOp.getLoc();
Value a = adaptor.getA(); Value a = gemmOpAdaptor.getA();
Value b = adaptor.getB(); Value b = gemmOpAdaptor.getB();
Value c = adaptor.getC(); Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !adaptor.getTransA()); assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp()); bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
@@ -67,7 +91,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult(); auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c; Value cSlice = c;
if (hasC) { if (hasC) {
@@ -76,13 +100,14 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))}; SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult(); cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
} }
else else
assert("C should be a vector" && isVectorShape(getTensorShape(c))); assert("C should be a vector" && isVectorShape(getTensorShape(c)));
} }
auto gemvOp = rewriter.create<ONNXGemmOp>(loc, auto gemvOp = ONNXGemmOp::create(rewriter,
loc,
outRowType, outRowType,
aSlice, aSlice,
b, b,
@@ -95,7 +120,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
} }
auto concatComputeOp = auto concatComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps); spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
auto* concatBlock = new Block(); auto* concatBlock = new Block();
for (auto gemvOp : gemvOps) for (auto gemvOp : gemvOps)
@@ -104,30 +129,26 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
rewriter.setInsertionPointToStart(concatBlock); rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments(); auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs); auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult()); spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();
} }
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> { LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
GemvToSpatialCompute(MLIRContext* ctx) ONNXGemmOpAdaptor gemmOpAdaptor,
: OpConversionPattern(ctx, 1) {} ConversionPatternRewriter& rewriter) const {
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location gemmLoc = gemmOp.getLoc(); Location gemmLoc = gemmOp.getLoc();
Value a = adaptor.getA(); Value a = gemmOpAdaptor.getA();
Value b = adaptor.getB(); Value b = gemmOpAdaptor.getB();
Value c = adaptor.getC(); Value c = gemmOpAdaptor.getC();
Value out = gemmOp.getY(); Value out = gemmOp.getY();
float alpha = adaptor.getAlpha().convertToFloat(); float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
float beta = adaptor.getBeta().convertToFloat(); float beta = gemmOpAdaptor.getBeta().convertToFloat();
bool transA = adaptor.getTransA(); bool transA = gemmOpAdaptor.getTransA();
bool transB = adaptor.getTransB(); bool transB = gemmOpAdaptor.getTransB();
auto aType = cast<RankedTensorType>(a.getType()); auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType()); auto bType = cast<RankedTensorType>(b.getType());
@@ -143,32 +164,32 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape())) if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv // Not a gemv
return failure(); return failure();
if (transA) { if (transA) {
auto aShape = aType.getShape(); auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
} }
if (transB) { if (transB) {
auto bShape = bType.getShape(); auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
} }
if (alpha != 1.0f) { if (alpha != 1.0f) {
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType()); auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha}); auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue); auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor); a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
} }
if (hasC && beta != 1.0f) { if (hasC && beta != 1.0f) {
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType()); auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue); auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor); c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
} }
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
@@ -215,7 +236,7 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
auto* computeBlock = new Block(); auto* computeBlock = new Block();
for (auto aHSlice : aHSlices[coreId]) for (auto aHSlice : aHSlices[coreId])
@@ -228,11 +249,11 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
vmmOutputs.reserve(computeArgs.size()); vmmOutputs.reserve(computeArgs.size());
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++) for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
vmmOutputs.push_back( vmmOutputs.push_back(
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter); Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
partialResults.push_back(computeOp.getResult(0)); partialResults.push_back(computeOp.getResult(0));
@@ -244,7 +265,7 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
} }
auto reduceComputeOp = auto reduceComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults); spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
auto* reduceBlock = new Block(); auto* reduceBlock = new Block();
for (auto partialResult : partialResults) for (auto partialResult : partialResults)
@@ -254,14 +275,14 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
auto blockArgs = reduceBlock->getArguments(); auto blockArgs = reduceBlock->getArguments();
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter); Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice); spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
rewriter.setInsertionPointAfter(reduceComputeOp); rewriter.setInsertionPointAfter(reduceComputeOp);
outHSlices.push_back(reduceComputeOp.getResult(0)); outHSlices.push_back(reduceComputeOp.getResult(0));
} }
auto concatComputeOp = auto concatComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices); spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
auto* concatBlock = new Block(); auto* concatBlock = new Block();
for (auto outHSlice : outHSlices) for (auto outHSlice : outHSlices)
@@ -270,24 +291,14 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
rewriter.setInsertionPointToStart(concatBlock); rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments(); auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs); auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult()); spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();
} }
private: Value GemvToSpatialCompute::resolveONNXExpOpFromUseChain(Value startValue) {
/**
* Resolves the ONNXExpOp from the use chain of the given start value.
*
* This function traverses the use chain of the start value until it finds an
* ONNXExpOp. It returns the value of the ONNXExpOp.
*
* @param startValue The starting value of the use chain.
* @return The value of the ONNXExpOp found in the use chain.
*/
static Value resolveONNXExpOpFromUseChain(Value startValue) {
Value walker = startValue; Value walker = startValue;
while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) { while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
@@ -306,18 +317,12 @@ private:
return walker; return walker;
} }
// Softmax is a special case, as it requires another reduction after the LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
// first one. In the cores, `applyReducePattern` already applied
// f(x) = exp(x) to each tile. This mean that now we just need to
// reduce-sum these tiles, and then divide each tile by the reduced sum,
// which is propagated back to the cores via a broadcast channel.
LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
Value& softmaxChannel, Value& softmaxChannel,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
SpatialReducer& reducer, SpatialReducer& reducer,
ONNXGemmOp& gemmOp, ONNXGemmOp& gemmOp,
Location& loc) const { Location& loc) {
// TODO: Check case with one compute op // TODO: Check case with one compute op
// Cast vector of Value into vector of ComputeOp // Cast vector of Value into vector of ComputeOp
@@ -331,9 +336,9 @@ private:
reducer.applyReducePattern( reducer.applyReducePattern(
softmaxOpsToReduce, softmaxOpsToReduce,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); }, [&](Value a, Value b) { return spatial::SpatVAddOp::create(rewriter, loc, scalarTensorType, a, b); },
/* preprocess = */ /* preprocess = */
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); }, [&](Value a) { return spatial::SpatSumOp::create(rewriter, loc, scalarTensorType, a); },
[&](Value softmaxDivisor) { [&](Value softmaxDivisor) {
// Signal that this is the compute with the softmax divisor // Signal that this is the compute with the softmax divisor
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp()); auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
@@ -341,7 +346,7 @@ private:
// Broadcast the divisor to all the cores // Broadcast the divisor to all the cores
rewriter.setInsertionPointAfterValue(softmaxDivisor); rewriter.setInsertionPointAfterValue(softmaxDivisor);
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor); spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, softmaxChannel, softmaxDivisor);
/* /*
* softmaxDividend = onnx.exp (...) * softmaxDividend = onnx.exp (...)
@@ -391,7 +396,7 @@ private:
} }
else { else {
rewriter.setInsertionPoint(yieldOp); rewriter.setInsertionPoint(yieldOp);
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel); divisor = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, loc, scalarTensorType, softmaxChannel);
} }
// Walk the chain of operations until we find the ONNXExpOp: this is // Walk the chain of operations until we find the ONNXExpOp: this is
@@ -401,7 +406,7 @@ private:
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second)); Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
rewriter.setInsertionPoint(yieldOp); rewriter.setInsertionPoint(yieldOp);
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor); Value newOutputTile = spatial::SpatVSDivOp::create(rewriter, loc, oldOutputTile.getType(), oldOutputTile, divisor);
auto yieldOperandNum = yieldOp->getNumOperands(); auto yieldOperandNum = yieldOp->getNumOperands();
yieldOp->insertOperands(yieldOperandNum, newOutputTile); yieldOp->insertOperands(yieldOperandNum, newOutputTile);
@@ -410,7 +415,6 @@ private:
return success(); return success();
} }
};
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToManyGemv>(ctx); patterns.insert<GemmToManyGemv>(ctx);

View File

@@ -1,300 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename PoolOp>
bool hasPostProcessExperimentalPoolingWindow() {
return false;
}
template <>
bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename ReductionOp>
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
if (inputTiles.size() == 1)
return inputTiles[0];
if (inputTiles.size() == 2) {
return rewriter.create<spatial::SpatVMaxOp>(
inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
}
SmallVector<Value> left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
SmallVector<Value> right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
ExperimentalPoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize);
// Assert that the input is a tensor.ConcatOp.
auto concat = X.getDefiningOp<tensor::ConcatOp>();
if (!concat)
return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
// Create a [channel_tile][x][y] array to store the input tiles.
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
// For each argument of the tensor.ConcatOp, resolve the input tiles.
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
/* 1 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(x),
/* 3 */ rewriter.getIndexAttr(y)};
SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// Get the concat's operand that we want to slice.
Value concatInput = concat.getOperand(it);
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
inputTiles[it][x][y] = slicedTile;
}
}
}
// Prepare the shape of the compute's output.
ldiv_t itc = tileCount;
SmallVector<Type> outputTileTypes;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */
cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
/* 2 */ 1,
/* 3 */ 1};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
}
}
}
// Create a plain value list of the input tiles.
SmallVector<Value> inputTilesList;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x)
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
inputTilesList.push_back(inputTiles[it][y][x]);
}
// Create a single compute to calculate the output.
auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&computeOp.getRegion());
// Fill the block arguments and keep a reference to them.
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
}
}
}
// Begin writing in the block.
rewriter.setInsertionPointToStart(block);
// Go through all pooling blocks.
SmallVector<Value> outputTiles;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
size_t start_x = x * stride_x;
size_t start_y = y * stride_y;
size_t end_x = std::min(start_x + krn_w, input_w);
size_t end_y = std::min(start_y + krn_h, input_h);
SmallVector<Value> inputTilesToReduce;
for (size_t ky = start_y; ky < end_y; ++ky)
for (size_t kx = start_x; kx < end_x; ++kx)
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
// If the reduce op is add, we need to divide the result by the
// number of elements in the pooling window.
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
// Add a spat.const before the computeOp.
rewriter.setInsertionPoint(computeOp);
auto divisorValue =
rewriter.create<spatial::SpatConstantOp>(loc,
RankedTensorType::get({1}, rewriter.getF32Type()),
rewriter.getI64IntegerAttr(krn_w * krn_h),
rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
reduceResult =
rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
}
outputTiles.push_back(reduceResult);
}
}
}
// Create a YieldOp to return the output tiles.
rewriter.create<spatial::SpatYieldOp>(loc, outputTiles);
// Set the rewrite cursor right after the computeOp.
rewriter.setInsertionPointAfter(computeOp);
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> computeOutput;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
}
}
}
// We'll now create spat.img.concat ops to concatenate the output tiles.
SmallVector<Value> outputTilesList;
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<Value> imgConcatTiles;
for (size_t y = 0; y < output_h; ++y)
for (size_t x = 0; x < output_w; ++x)
imgConcatTiles.push_back(computeOutput[it][y][x]);
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ (long) tilingSize,
/* 2 */ (long) output_w,
/* 3 */ (long) output_h};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
}
// Create a new tensor.ConcatOp to concatenate the output tiles.
Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
rewriter.replaceOp(poolOp, outputTensor);
return success();
}
};
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<
ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
ctx);
}
} // namespace onnx_mlir

View File

@@ -15,7 +15,7 @@
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
@@ -26,8 +26,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce, Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce, std::function<Value(const Value&, const Value&)> reduce,
@@ -113,15 +111,15 @@ Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
// 1. Add a channel before the first computeOp // 1. Add a channel before the first computeOp
rewriter.setInsertionPoint(firstCompute); rewriter.setInsertionPoint(firstCompute);
auto channel = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType); auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Add a sendOp after the first value // 2. Add a sendOp after the first value
rewriter.setInsertionPointAfterValue(firstValue); rewriter.setInsertionPointAfterValue(firstValue);
rewriter.create<spatial::SpatChannelSendOp>(loc, channel, firstValue); spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
// 3. Add a receiveOp after the second value // 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue); rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel); auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value // 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue); rewriter.setInsertionPointAfterValue(receivedValue);
@@ -190,13 +188,14 @@ Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rew
// directly under func.func (i.e. alongside ComputeOps) // directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp()); auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp); rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, auto divisorValue = spatial::SpatConstantOp::create(rewriter,
loc,
scalarTensor, scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber), rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true)); /* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide); rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue); return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
} }
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp> template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
@@ -225,12 +224,12 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
Location loc = poolOp.getLoc(); Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape); size_t input_h = getImageHeight(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape); size_t input_w = getImageWidth(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape); size_t output_h = getImageHeight(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape); size_t output_w = getImageWidth(yShape);
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize; size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
// 1: Tile the input tensor // 1: Tile the input tensor
// Input tiles need to be indexed by: // Input tiles need to be indexed by:
@@ -259,7 +258,8 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) { if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc(); Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc, auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
tileLoc,
extractSliceOp.getResultType(), extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(), /* xbarWeights =*/ValueRange(),
extractSliceOp.getResult()); extractSliceOp.getResult());
@@ -269,7 +269,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc); auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock); rewriter.setInsertionPointToStart(tempComputeOpBlock);
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg); spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
rewriter.setInsertionPointAfter(tempComputeOp); rewriter.setInsertionPointAfter(tempComputeOp);
inputTiles[t][x][y] = tempComputeOp.getResult(0); inputTiles[t][x][y] = tempComputeOp.getResult(0);
} }
@@ -358,7 +358,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
Value reducedWithinCompute = applyReducePatternNew( Value reducedWithinCompute = applyReducePatternNew(
valuesToPool, valuesToPool,
rewriter, rewriter,
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); }, [&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
nullptr, nullptr,
postProcessFn); postProcessFn);
@@ -371,16 +371,16 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Create a new channel before the computeOp // Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced); rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel = auto reduceChannel =
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext())); spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel // Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute); rewriter.setInsertionPointAfterValue(reducedWithinCompute);
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute); spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp // Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced); rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue = auto receivedValue =
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel); spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue; outputTiles[outTile][outX][outY] = receivedValue;
} }

View File

@@ -63,7 +63,8 @@ struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV1
/*elementType=*/inputTensorType.getElementType()); /*elementType=*/inputTensorType.getElementType());
// Create the ONNXAveragePoolOp. // Create the ONNXAveragePoolOp.
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(), auto averagePool = ONNXAveragePoolOp::create(rewriter,
reduceMean.getLoc(),
resultType, resultType,
inputTensor, inputTensor,
/*auto_pad=*/"NOTSET", /*auto_pad=*/"NOTSET",

View File

@@ -13,9 +13,7 @@ def onnxToArithConstantOp : Pat<
(Arith_ConstantOp $value) (Arith_ConstantOp $value)
>; >;
//===----------------------------------------------------------------------===//
// ONNXMatMulOp to ONNXGemmOp patterns // ONNXMatMulOp to ONNXGemmOp patterns
//===----------------------------------------------------------------------===//
def matMulAddToGemmPattern : Pat< def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C), (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
@@ -31,7 +29,7 @@ def matMulToGemmPattern : Pat<
(ONNXMatMulOp:$matmulres $A, $B), (ONNXMatMulOp:$matmulres $A, $B),
( (
ONNXGemmOp $A, $B, ONNXGemmOp $A, $B,
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">), /* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">), /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
@@ -39,9 +37,7 @@ def matMulToGemmPattern : Pat<
) )
>; >;
//===----------------------------------------------------------------------===//
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
//===----------------------------------------------------------------------===//
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single // This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias. // ONNXConvOp with a bias.
@@ -55,9 +51,7 @@ def convAddToConvWithBiasPatternRight : Pat<
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides) (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>; >;
//===----------------------------------------------------------------------===//
// Operation to ignore (i.e. remove) // Operation to ignore (i.e. remove)
//===----------------------------------------------------------------------===//
def replaceWithOperationOfValue : NativeCodeCall<"$0">; def replaceWithOperationOfValue : NativeCodeCall<"$0">;

View File

@@ -47,7 +47,7 @@ SmallVector<Value> sliceTensor(
if (i == numSlices - 1 && lastSliceSize != 0) if (i == numSlices - 1 && lastSliceSize != 0)
sizes[axis] = rewriter.getIndexAttr(lastSliceSize); sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
Value slice = rewriter.create<tensor::ExtractSliceOp>(loc, tensorToSlice, offsets, sizes, strides); Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
slices.push_back(slice); slices.push_back(slice);
} }
@@ -100,11 +100,11 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
int64_t shape[2] = {1, length}; int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType); Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult(); auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero); SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = rewriter.create<tensor::ExtractOp>(loc, scalarToBroadcast, index).getResult(); auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
return rewriter.create<tensor::SplatOp>(loc, type, elementValue); return tensor::SplatOp::create(rewriter, loc, type, elementValue);
} }
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) { Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
@@ -122,7 +122,7 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value a = (*currTensors)[i]; Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1]; Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b); rewriter.setInsertionPointAfterValue(b);
auto addedValue = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue); nextTensors->push_back(addedValue);
} }
if (currTensors->size() % 2 == 1) if (currTensors->size() % 2 == 1)
@@ -137,10 +137,10 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) { Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
switch (mapOp) { switch (mapOp) {
case MapOperations::None: assert(false && "Invalid map operation during map operation creation."); case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
case MapOperations::ONNXSoftmaxOp: return rewriter.create<ONNXSoftmaxOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXReluOp: return rewriter.create<ONNXReluOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXLeakyReluOp: return rewriter.create<ONNXLeakyReluOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXExpOp: return rewriter.create<ONNXExpOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
} }
} }
@@ -180,10 +180,10 @@ void tileImageTensorByChannel(Value imageTensor,
ConversionPatternRewriter& rewriter) { ConversionPatternRewriter& rewriter) {
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType()); ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
size_t input_h = GET_IMAGE_HEIGHT(imageShape); size_t input_h = getImageHeight(imageShape);
size_t input_w = GET_IMAGE_WIDTH(imageShape); size_t input_w = getImageWidth(imageShape);
size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize); size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize; size_t tileRest = getImageChannel(imageShape) % tileSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
@@ -201,7 +201,7 @@ void tileImageTensorByChannel(Value imageTensor,
offsets[2] = rewriter.getIndexAttr(x); offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y); offsets[3] = rewriter.getIndexAttr(y);
tiles[i][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, imageTensor, offsets, sizes, strides); tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides);
} }
} }
} }
@@ -225,7 +225,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(outputTiles[outTile][outX][outY]); tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
return rewriter.create<spatial::SpatImgConcatOp>(loc, outputType, tilesToConcat); return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat);
} }
LogicalResult LogicalResult
@@ -271,7 +271,7 @@ Value createExtractSliceImg(Value valToSlice,
offsets[2] = rewriter.getIndexAttr(x); offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y); offsets[3] = rewriter.getIndexAttr(y);
return rewriter.create<tensor::ExtractSliceOp>(valToSlice.getLoc(), valToSlice, offsets, sizes, strides); return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
} }
Value indexImgValue(Value v, Value indexImgValue(Value v,
@@ -384,7 +384,7 @@ void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
offsets[2] = rewriter.getIndexAttr(x); offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y); offsets[3] = rewriter.getIndexAttr(y);
inputTiles[t][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, wholeInputTensor, offsets, sizes, strides); inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides);
} }
} }
} }
@@ -452,7 +452,7 @@ LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)}; SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type()); auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
Value shapeTensor = Value shapeTensor =
rewriter.create<arith::ConstantOp>(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals)); arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType()); auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor); auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);

View File

@@ -9,24 +9,55 @@
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#define DEFINE_MAP_OP(opname) opname, #define DEFINE_MAP_OP(opname) opname,
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1)
#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0)
#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0)
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
const StringRef REPLICATION_ATTR_NAME = "replication_factor"; template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
inline constexpr mlir::StringRef REPLICATION_ATTR_NAME = "replication_factor";
using HSliceId = size_t; using HSliceId = size_t;
using CoreId = size_t; using CoreId = size_t;
@@ -58,51 +89,64 @@ constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
} }
template <class T> template <class T>
bool isVectorShape(const ArrayRef<T> shape) { bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
} }
template <class T> template <class T>
bool isMatrixShape(const ArrayRef<T> shape) { bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2; return shape.size() == 2;
} }
template <class T> template <class T>
bool isHVectorShape(const ArrayRef<T> shape) { bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1; return shape.size() == 2 && shape[0] == 1;
} }
template <class T> template <class T>
bool isVVectorShape(const ArrayRef<T> shape) { bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1; return shape.size() == 2 && shape[1] == 1;
} }
template <class T> template <class T>
T getVectorLength(const ArrayRef<T> shape) { T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape)); assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1]; return shape[0] != 1 ? shape[0] : shape[1];
} }
inline auto getTensorShape(const Value tensor) { return cast<RankedTensorType>(tensor.getType()).getShape(); } inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
SmallVector<Value> sliceTensor( llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
SmallVector<Value> llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
DenseMap<CoreId, SmallVector<Value>> llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc); const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix( llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc); tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
tensor::SplatOp mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc); int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter); mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input); mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input);
/** /**
* Unpacks an optional pair vector into two size_t values. * Unpacks an optional pair vector into two size_t values.
@@ -126,7 +170,8 @@ void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t
* *
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid * @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
*/ */
std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y); std::optional<llvm::Twine>
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
/** /**
* Tiles the image tensor by channel. * Tiles the image tensor by channel.
@@ -140,10 +185,10 @@ std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> val
* @param tileSize The size of each tile. * @param tileSize The size of each tile.
* @param rewriter The ConversionPatternRewriter used for creating operations. * @param rewriter The ConversionPatternRewriter used for creating operations.
*/ */
void tileImageTensorByChannel(Value imageTensor, void tileImageTensorByChannel(mlir::Value imageTensor,
SmallVector<SmallVector<SmallVector<Value>>>& tiles, llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& tiles,
size_t tileSize, size_t tileSize,
ConversionPatternRewriter& rewriter); mlir::ConversionPatternRewriter& rewriter);
/** /**
* Creates an ImgConcatOp based on the given tiles. * Creates an ImgConcatOp based on the given tiles.
@@ -159,10 +204,10 @@ void tileImageTensorByChannel(Value imageTensor,
* *
* @return The created ImgConcatOp. * @return The created ImgConcatOp.
*/ */
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles, mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& outputTiles,
ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
Location& loc, mlir::Location& loc,
Type outputType); mlir::Type outputType);
/** /**
* @brief Verifies if the given input coordinates and padding values are within * @brief Verifies if the given input coordinates and padding values are within
@@ -177,7 +222,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
* @return LogicalResult Returns success if the coordinates and padding are * @return LogicalResult Returns success if the coordinates and padding are
* within bounds, failure otherwise. * within bounds, failure otherwise.
*/ */
LogicalResult mlir::LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y); verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
/** /**
@@ -207,8 +252,9 @@ verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY,
* @return std::optional<llvm::Twine> An error message if the input tensor could * @return std::optional<llvm::Twine> An error message if the input tensor could
* not be resolved into tiles. * not be resolved into tiles.
*/ */
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor, std::optional<llvm::Twine>
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles, resolveImgInputTiles(mlir::Value wholeInputTensor,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& inputTiles,
size_t channelTileCount, size_t channelTileCount,
size_t channelTileRest, size_t channelTileRest,
size_t input_w, size_t input_w,
@@ -258,6 +304,6 @@ void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcom
* @return The index of the result of the operation that produces the specified * @return The index of the result of the operation that produces the specified
* value. * value.
*/ */
int getResultIndex(Operation* op, Value v); int getResultIndex(mlir::Operation* op, mlir::Value v);
}; // namespace onnx_mlir }; // namespace onnx_mlir

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@@ -8,20 +9,41 @@
#include <fstream> #include <fstream>
#include "Common/PIMCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "ONNXToSpatialPass.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { bool haveSameStaticShape(Value lhs, Value rhs);
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
@@ -39,15 +61,19 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp); IRRewriter rewriter(moduleOp);
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin(); auto entryFunc = getPimEntryFunc(moduleOp);
if (annotateReplication(funcOp, rewriter).failed()) { if (failed(entryFunc)) {
signalPassFailure();
return;
}
if (annotateReplication(*entryFunc, rewriter).failed()) {
llvm::dbgs() << "Failed during annotation for replication analysis\n"; llvm::dbgs() << "Failed during annotation for replication analysis\n";
signalPassFailure(); signalPassFailure();
return; return;
} }
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<ONNXDialect, SpatialDialect, tensor::TensorDialect, arith::ArithDialect, tosa::TosaDialect>(); target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addIllegalOp<ONNXMatMulOp>(); target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXGemmOp>(); target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>(); target.addIllegalOp<ONNXConvOp>();
@@ -61,16 +87,9 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx); patterns.add<removeLRNPattern>(ctx);
if (useExperimentalConvImpl) { populateConvOpPatterns(patterns, ctx);
populateExperimentalTilingConvOpPattern(patterns, ctx);
populateExperimentalPoolingTilingPattern(patterns, ctx);
populateGemmToConvConversionPattern(patterns, ctx);
}
else {
populateTilingConvOpPattern(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx); populatePoolingTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx); populateOnnxGemmOpPatterns(patterns, ctx);
}
populateONNXConcatToTensorConcatPattern(patterns, ctx); populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx); populateReduceMeanConversionPattern(patterns, ctx);
@@ -83,8 +102,8 @@ void ONNXToSpatialPass::runOnOperation() {
// Count the number of compute ops and check they do not exceed the core count // Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) { if (coresCount != -1) {
int computeOpsCount = 0; int computeOpsCount = 0;
for (auto& op : funcOp.getFunctionBody().front().getOperations()) for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<SpatWeightedCompute>(op)) if (isa<spatial::SpatWeightedCompute>(op))
computeOpsCount++; computeOpsCount++;
if (computeOpsCount > coresCount) { if (computeOpsCount > coresCount) {
@@ -101,22 +120,21 @@ void ONNXToSpatialPass::runOnOperation() {
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns)))) if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n"; llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(funcOp); annotateWeightsConstants(*entryFunc);
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "spatial"); dumpModule(moduleOp, "spatial");
} }
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](arith::ConstantOp constantOp) { funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight = bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<SpatWeightedCompute>(user); }); llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
if (isAlwaysWeight) if (isAlwaysWeight)
constantOp->setAttr("weightAlways", UnitAttr::get(ctx)); markWeightAlways(constantOp);
}); });
} }
} // namespace spatial std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,34 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
using namespace mlir;
extern bool haveSameStaticShape(Value lhs, Value rhs);
namespace spatial {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace spatial
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<spatial::ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -1,28 +1,20 @@
#pragma once #pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir { namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
// Experimental patterns.
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -10,7 +10,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy> template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : public OpRewritePattern<OpTy> { struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx) RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {} : OpRewritePattern<OpTy>(ctx) {}

View File

@@ -49,11 +49,11 @@ LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& r
ShapedType xShape = mlir::cast<ShapedType>(X.getType()); ShapedType xShape = mlir::cast<ShapedType>(X.getType());
ShapedType wShape = mlir::cast<ShapedType>(W.getType()); ShapedType wShape = mlir::cast<ShapedType>(W.getType());
size_t input_w = GET_IMAGE_WIDTH(xShape); size_t input_w = getImageWidth(xShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape); size_t krn_h = getKernelHeight(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape); size_t krn_w = getKernelWidth(wShape);
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); size_t inputTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue()); size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());
auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount; auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;

View File

@@ -15,21 +15,21 @@
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced; llvm::SmallPtrSet<mlir::Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum, ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun, std::function<mlir::Value(const mlir::Value&)> processFun,
ConversionPatternRewriter& rewriter) { mlir::ConversionPatternRewriter& rewriter) {
assert(processFun); assert(processFun);
auto computeOp = GET_COMP(computeOpAndResNum); auto computeOp = GET_COMP(computeOpAndResNum);
auto resultNum = GET_RES_NUM(computeOpAndResNum); auto resultNum = GET_RES_NUM(computeOpAndResNum);
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator()); spatial::SpatYieldOp yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value result = yieldOp->getOperand(resultNum); mlir::Value result = yieldOp->getOperand(resultNum);
rewriter.setInsertionPointAfterValue(result); rewriter.setInsertionPointAfterValue(result);
Value processedResult = processFun(result); mlir::Value processedResult = processFun(result);
if (processedResult == result) { if (processedResult == result) {
// Sometimes we want processedResult to return the same value but do // Sometimes we want processedResult to return the same value but do
// something else with it (e.g. in softmax we want to broadcast the value // something else with it (e.g. in softmax we want to broadcast the value
@@ -42,10 +42,11 @@ ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum
return yieldOp.getNumOperands() - 1; return yieldOp.getNumOperands() - 1;
} }
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum, OpAndResNum
std::function<Value(const Value&, const Value&)> reduce, SpatialReducer::applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&)> preprocess, std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
std::function<Value(const Value&)> postprocess) { std::function<mlir::Value(const mlir::Value&)> preprocess,
std::function<mlir::Value(const mlir::Value&)> postprocess) {
if (preprocess) if (preprocess)
for (auto& computeOpAndResNum : computeOpsAndResNum) for (auto& computeOpAndResNum : computeOpsAndResNum)
@@ -55,18 +56,18 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// computeOp. In this case, we need to apply the reduction within-computef // computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction // Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute; std::unordered_map<mlir::Operation*, mlir::Value> lastValueForCompute;
for (auto& computeOpAndResNum : computeOpsAndResNum) { for (auto& computeOpAndResNum : computeOpsAndResNum) {
auto computeOp = GET_COMP(computeOpAndResNum); auto computeOp = GET_COMP(computeOpAndResNum);
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator()); auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum)); mlir::Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto it = lastValueForCompute.find(computeOp.getOperation()); auto it = lastValueForCompute.find(computeOp.getOperation());
if (it != lastValueForCompute.end()) { if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction // If we have already seen this computeOp, apply the reduction
// within-compute // within-compute
Value lastWithinComputeValue = it->second; mlir::Value lastWithinComputeValue = it->second;
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp()); assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
@@ -85,12 +86,12 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
computeOpsAndResNum.clear(); computeOpsAndResNum.clear();
computeOpsAndResNum.reserve(lastValueForCompute.size()); computeOpsAndResNum.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute) { for (auto& entry : lastValueForCompute) {
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first); auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(entry.first);
auto valueWithinCompute = entry.second; auto valueWithinCompute = entry.second;
// We check if `valueWithinCompute` is already used by the yieldOp, in that // We check if `valueWithinCompute` is already used by the yieldOp, in that
// case no need to add it // case no need to add it
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator()); auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
bool yieldOpUseFound = false; bool yieldOpUseFound = false;
for (auto& use : valueWithinCompute.getUses()) { for (auto& use : valueWithinCompute.getUses()) {
if (use.getOwner() == yieldOp.getOperation()) { if (use.getOwner() == yieldOp.getOperation()) {
@@ -110,7 +111,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
computeOpsAndResNum.push_back({computeOp, resultNum}); computeOpsAndResNum.push_back({computeOp, resultNum});
} }
Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc(); mlir::Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
// Recursive algorithm to reduce the inputs to a single one: // Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating // - Take two inputs at a time, and reduce them into a single one, updating
@@ -118,7 +119,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// - Repeat until there is only one input left. // - Repeat until there is only one input left.
llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum); llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
while (computeOpsRef.size() > 1) { while (computeOpsRef.size() > 1) {
SmallVector<ComputeAndResNum> nextComputeOps; llvm::SmallVector<ComputeAndResNum> nextComputeOps;
nextComputeOps.reserve(computeOpsRef.size() / 2); nextComputeOps.reserve(computeOpsRef.size() / 2);
for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) { for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
auto [firstCompute, firstResultNum] = computeOpsRef[i]; auto [firstCompute, firstResultNum] = computeOpsRef[i];
@@ -135,23 +136,23 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// the number of results) // the number of results)
// See below `reducerChanges.push_back` and `finalizeReduceUpdates` // See below `reducerChanges.push_back` and `finalizeReduceUpdates`
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator()); auto yieldOpFirstCompute = mlir::cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
// Add a new operand to the block of the second computeOp // Add a new operand to the block of the second computeOp
Block& secondBlock = secondCompute.getBody().front(); mlir::Block& secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc); mlir::Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
auto secondComputeWeightsNum = auto secondComputeWeightsNum =
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0]; secondCompute->getAttrOfType<mlir::DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1; auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
// Take the "former-result" from the second computeOp // Take the "former-result" from the second computeOp
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator()); spatial::SpatYieldOp secondYield = mlir::cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
Value formerRes2 = secondYield.getOperand(secondResultNum); mlir::Value formerRes2 = secondYield.getOperand(secondResultNum);
// Apply reduction operation // Apply reduction operation
rewriter.setInsertionPoint(secondYield); rewriter.setInsertionPoint(secondYield);
Value reduced = reduce(formerRes2, formerRes1); mlir::Value reduced = reduce(formerRes2, formerRes1);
// Unfortunately, it is not possible to update the result in place, // Unfortunately, it is not possible to update the result in place,
// because we may have already referenced it by <computeOp, resultNum> // because we may have already referenced it by <computeOp, resultNum>
@@ -219,7 +220,7 @@ void SpatialReducer::finalizeReduceUpdates() {
// `opToReplacedCompute` // `opToReplacedCompute`
auto toComputeOp = opToReplacedCompute[toOp]; auto toComputeOp = opToReplacedCompute[toOp];
if (!toComputeOp) if (!toComputeOp)
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp); toComputeOp = mlir::cast<spatial::SpatWeightedCompute>(toOp);
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!"); assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
@@ -234,31 +235,31 @@ void SpatialReducer::finalizeReduceUpdates() {
} }
} }
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) { mlir::Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates."); assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
Operation* opToCast; mlir::Operation* opToCast;
auto it = opToReplacedCompute.find(opAndResNum.first); auto it = opToReplacedCompute.find(opAndResNum.first);
if (it != opToReplacedCompute.end()) if (it != opToReplacedCompute.end())
opToCast = it->second; opToCast = it->second;
else else
opToCast = opAndResNum.first; opToCast = opAndResNum.first;
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast); auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(opToCast);
return computeOp.getResult(opAndResNum.second); return computeOp.getResult(opAndResNum.second);
} }
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) { void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) {
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) { if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
// If we have already replaced the fromOp, we do not need to do it again // If we have already replaced the fromOp, we do not need to do it again
return; return;
} }
auto oldComputeOp = cast<spatial::SpatWeightedCompute>(computeOp); auto oldComputeOp = mlir::cast<spatial::SpatWeightedCompute>(computeOp);
auto oldComputeOpNum = oldComputeOp->getNumOperands(); auto oldComputeOpNum = oldComputeOp->getNumOperands();
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator()); auto yieldOp = mlir::cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) { if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
// No result was added, just add itself to the map // No result was added, just add itself to the map
@@ -271,8 +272,8 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
// Create a new ComputeOp with the new result type, but same operands // Create a new ComputeOp with the new result type, but same operands
rewriter.setInsertionPoint(oldComputeOp); rewriter.setInsertionPoint(oldComputeOp);
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>( auto newComputeOp = spatial::SpatWeightedCompute::create(
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs()); rewriter, oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
newComputeOp.getBody().takeBody(oldComputeOp.getBody()); newComputeOp.getBody().takeBody(oldComputeOp.getBody());
@@ -283,8 +284,8 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
// Since we replaced the old ComputeOp with a new one, we need to replace // Since we replaced the old ComputeOp with a new one, we need to replace
// all its results' uses // all its results' uses
for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) { for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
Value oldResult = oldComputeOp.getResult(i); mlir::Value oldResult = oldComputeOp.getResult(i);
Value newResult = newComputeOp.getResult(i); mlir::Value newResult = newComputeOp.getResult(i);
// Replace the uses, except the uses of the compute ops which got deleted // Replace the uses, except the uses of the compute ops which got deleted
// previously // previously
@@ -298,9 +299,10 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
rewriter.eraseOp(oldComputeOp); rewriter.eraseOp(oldComputeOp);
} }
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles, mlir::Value
Location& loc, SpatialReducer::createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
Type outputType) { mlir::Location& loc,
mlir::Type outputType) {
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates."); assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
@@ -309,8 +311,8 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
auto width = outputTiles[0].size(); auto width = outputTiles[0].size();
auto height = outputTiles[0][0].size(); auto height = outputTiles[0][0].size();
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles( llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>> remappedOutputTiles(
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height))); tilesCount, llvm::SmallVector<llvm::SmallVector<mlir::Value>>(width, llvm::SmallVector<mlir::Value>(height)));
for (size_t t = 0; t < tilesCount; t++) for (size_t t = 0; t < tilesCount; t++)
for (size_t x = 0; x < width; x++) for (size_t x = 0; x < width; x++)
@@ -320,25 +322,25 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType); return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
} }
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps, OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
Value biasTile, mlir::Value biasTile,
MapOperations mapOp) { MapOperations mapOp) {
std::function<Value(const Value&)> postprocessing = nullptr; std::function<mlir::Value(const mlir::Value&)> postprocessing = nullptr;
if (mapOp != MapOperations::None) { if (mapOp != MapOperations::None) {
postprocessing = [&](const Value a) { postprocessing = [&](const mlir::Value a) {
Value mapOperand = a; mlir::Value mapOperand = a;
if (biasTile) if (biasTile)
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile); mapOperand = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, biasTile);
return createMapOperation(rewriter, mapOp, mapOperand); return createMapOperation(rewriter, mapOp, mapOperand);
}; };
} }
return this->applyReducePattern( return this->applyReducePattern(
computeOps, computeOps,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); }, [&](mlir::Value a, mlir::Value b) { return spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b); },
/* preprocess = */ nullptr, /* preprocess = */ nullptr,
postprocessing); postprocessing);
} }

View File

@@ -3,6 +3,10 @@
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include <functional>
#include <unordered_map>
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -13,28 +17,28 @@ using ResNum = unsigned int;
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>; using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
struct SpatialReducerChange { struct SpatialReducerChange {
Operation* fromOp; mlir::Operation* fromOp;
unsigned int fromOpResNum; unsigned int fromOpResNum;
Operation* toOp; mlir::Operation* toOp;
unsigned int toOpOperandNum; unsigned int toOpOperandNum;
}; };
using OpAndResNum = std::pair<Operation*, ResNum>; using OpAndResNum = std::pair<mlir::Operation*, ResNum>;
class SpatialReducer { class SpatialReducer {
public: public:
SpatialReducer(ConversionPatternRewriter& rewriter) SpatialReducer(mlir::ConversionPatternRewriter& rewriter)
: rewriter(rewriter) {} : rewriter(rewriter) {}
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum, OpAndResNum applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce, std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
std::function<Value(const Value&)> preprocess, std::function<mlir::Value(const mlir::Value&)> preprocess,
std::function<Value(const Value&)> postprocess); std::function<mlir::Value(const mlir::Value&)> postprocess);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps, OpAndResNum applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
Value biasTile, mlir::Value biasTile,
MapOperations mapOp); MapOperations mapOp);
void finalizeReduceUpdates(); void finalizeReduceUpdates();
@@ -44,17 +48,17 @@ public:
finalizeReduceUpdates(); finalizeReduceUpdates();
} }
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles, mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
Location& loc, mlir::Location& loc,
Type outputType); mlir::Type outputType);
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum); mlir::Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
private: private:
[[nodiscard("computeOp result number gets updated")]] ResNum [[nodiscard("computeOp result number gets updated")]] ResNum
applyResultProcessing(ComputeAndResNum computeOpAndResNum, applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun, std::function<mlir::Value(const mlir::Value&)> processFun,
ConversionPatternRewriter& rewriter); mlir::ConversionPatternRewriter& rewriter);
/** /**
* @brief Update the results of a ComputeOp. * @brief Update the results of a ComputeOp.
@@ -66,19 +70,19 @@ private:
* *
* @param computeOp The ComputeOp to update the results of. * @param computeOp The ComputeOp to update the results of.
*/ */
void updateResultsOfCompute(Operation* computeOp); void updateResultsOfCompute(mlir::Operation* computeOp);
ConversionPatternRewriter& rewriter; mlir::ConversionPatternRewriter& rewriter;
bool reducesFinalized = false; bool reducesFinalized = false;
// List of changes to be applied after the reduction is finalized // List of changes to be applied after the reduction is finalized
SmallVector<SpatialReducerChange, 4> reducerChanges; llvm::SmallVector<SpatialReducerChange, 4> reducerChanges;
// List of computeOps that need to be replaced with new results // List of computeOps that need to be replaced with new results
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate; llvm::SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute; std::unordered_map<mlir::Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced; static llvm::SmallPtrSet<mlir::Operation*, 16> oldComputeOpsReplaced;
}; };
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
namespace onnx_mlir { namespace onnx_mlir {
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights) WeightSubdivider::WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights)
: weights(std::move(weights)) {} : weights(std::move(weights)) {}
bool WeightSubdivider::isEmpty() const { return weights.empty(); } bool WeightSubdivider::isEmpty() const { return weights.empty(); }
@@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
assert(!weights.empty() && "No weights to extract."); assert(!weights.empty() && "No weights to extract.");
auto it = weights.begin(); auto it = weights.begin();
SmallVector<Value>& values = it->second.begin()->second; llvm::SmallVector<mlir::Value>& values = it->second.begin()->second;
long inputTile = it->first; long inputTile = it->first;
long outputTile = it->second.begin()->first; long outputTile = it->second.begin()->first;
@@ -21,7 +21,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
size_t n = std::min(amount, values.size()); size_t n = std::min(amount, values.size());
crossbarsUsed += n; crossbarsUsed += n;
SmallVector<Value> result; llvm::SmallVector<mlir::Value> result;
result.assign(values.begin(), values.begin() + n); result.assign(values.begin(), values.begin() + n);
if (n < values.size()) { if (n < values.size()) {
@@ -36,9 +36,9 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
return {inputTile, outputTile, crossbarsUsed - n, result}; return {inputTile, outputTile, crossbarsUsed - n, result};
} }
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) { llvm::SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
crossbarsUsed = 0; crossbarsUsed = 0;
SmallVector<TaggedWeights> result; llvm::SmallVector<TaggedWeights> result;
size_t remaining = n; size_t remaining = n;
while (remaining > 0 && !weights.empty()) { while (remaining > 0 && !weights.empty()) {

View File

@@ -4,11 +4,9 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <map> #include <map>
using namespace mlir;
using namespace std;
namespace onnx_mlir { namespace onnx_mlir {
/** /**
@@ -19,7 +17,7 @@ struct TaggedWeights {
long inputTile; long inputTile;
long outputTile; long outputTile;
size_t startingCrossbarIndex; size_t startingCrossbarIndex;
SmallVector<Value> weights; llvm::SmallVector<mlir::Value> weights;
}; };
/** /**
@@ -33,16 +31,16 @@ struct TaggedWeights {
*/ */
class WeightSubdivider { class WeightSubdivider {
private: private:
map<long, map<long, SmallVector<Value>>> weights; std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights;
size_t crossbarsUsed = 0; size_t crossbarsUsed = 0;
TaggedWeights popGroup(size_t amount); TaggedWeights popGroup(size_t amount);
public: public:
WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights); WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights);
bool isEmpty() const; bool isEmpty() const;
SmallVector<TaggedWeights> popGroups(size_t n); llvm::SmallVector<TaggedWeights> popGroups(size_t n);
}; };
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -5,7 +5,7 @@ add_onnx_mlir_library(OMSpatialToGraphviz
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
OMCompilerOptions OMCompilerOptions
OMPIMCommon OMPimCommon
OMONNXOps OMONNXOps
SpatialOps SpatialOps

View File

@@ -10,6 +10,7 @@
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -199,12 +200,12 @@ private:
void SpatialToGraphvizPass::runOnOperation() { void SpatialToGraphvizPass::runOnOperation() {
ModuleOp module = getOperation(); ModuleOp module = getOperation();
// Get the first OP, must be a FuncOp auto entryFunc = getPimEntryFunc(module);
func::FuncOp func = *module.getOps<func::FuncOp>().begin(); if (failed(entryFunc)) {
if (!func) {
module->emitError("No FuncOp found in the begin of module");
signalPassFailure(); signalPassFailure();
return;
} }
func::FuncOp func = *entryFunc;
os << "digraph G {\n" os << "digraph G {\n"
<< "\tnode [style=filled,color=white];\n"; << "\tnode [style=filled,color=white];\n";

View File

@@ -1,21 +0,0 @@
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPIMIncGen)
add_onnx_mlir_library(OMSpatialToPIM
SpatialToPIMPass.hpp
SpatialToPIMPass.cpp
SpatialToPIMCommon.cpp
DEPENDS
SpatialToPIMIncGen
LINK_LIBS PUBLIC
OMCompilerOptions
OMPIMCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -1,108 +0,0 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
*
* The static offsets represent the starting position of the slice in each
* dimension, while the static tensor input gives its dimension size.
*
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
* calculated.
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape);
template <class T>
size_t rangeLength(const iterator_range<T> range) {
return std::distance(range.begin(), range.end());
}
/**
* Retrieves the earliest operation that uses the given value within the value's
* block.
*
* @param value The value for which to find the earliest user operation.
* @return The earliest user operation that uses the given value within the
* current block.
*/
Operation* getEarliestUserWithinBlock(Value value);
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation);
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation);
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
const ArrayRef<int64_t> offsets,
const ArrayRef<int64_t> sizes,
const ArrayRef<int64_t> strides) {
// Check that all strides are 1
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
// Check offsets from right to left:
// The first offset_n at position n different from 0:
// - limits all sizes to the left to 1
// - limits size_n to dimension_n - offset_n
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
// Check sizes from right to left:
// The first size_n at position n different from shape_n limits all sizes to the left to 1
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) {
return rewriter.create<tensor::EmptyOp>(loc, shapedType.getShape(), shapedType.getElementType());
}
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op); }
} // namespace onnx_mlir

View File

@@ -1,60 +0,0 @@
#pragma once
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace spat_to_pim {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace spat_to_pim
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<spat_to_pim::SpatialToPIMPass>(); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,20 @@
set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen)
add_onnx_mlir_library(OMSpatialToPim
SpatialToPimPass.cpp
SpatialToPimCommon.cpp
DEPENDS
SpatialToPimIncGen
LINK_LIBS PUBLIC
OMCompilerOptions
OMPimCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -3,10 +3,18 @@
#ifndef OP_BASE #ifndef OP_BASE
include "mlir/IR/PatternBase.td" include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def onnxToPimTransposeOp : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVMMOp : Pat< def spatToPimVMMOp : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector), (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector, (PimVMMOp $weightIndex, $vector,

View File

@@ -5,9 +5,10 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include "SpatialToPIMCommon.hpp" #include "SpatialToPimCommon.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -53,7 +54,7 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue; return returnValue;
} }
Operation* getEarliestUserWithinBlock(Value value) { Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers(); auto users = value.getUsers();
assert(!users.empty()); assert(!users.empty());
@@ -66,23 +67,24 @@ Operation* getEarliestUserWithinBlock(Value value) {
return earliestUser; return earliestUser;
} }
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) { SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> { auto operandsAndUses =
map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
return {operand, std::distance(operand.use_begin(), operand.use_end())}; return {operand, std::distance(operand.use_begin(), operand.use_end())};
}); });
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; }); sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; }); return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
} }
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) { mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1); assert("Only support operations with a single result" && operation->getNumResults() == 1);
Value result = operation->getResult(0); mlir::Value result = operation->getResult(0);
auto resultType = result.getType(); auto resultType = result.getType();
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType)); assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<Value> operands = getOpOperandsSortedByUses(operation); SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands = auto validOperands =
make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; }); make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
auto bestOperand = validOperands.begin(); auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end()) if (bestOperand != validOperands.end())
@@ -90,8 +92,8 @@ Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Opera
auto resultShapedType = cast<ShapedType>(resultType); auto resultShapedType = cast<ShapedType>(resultType);
rewriter.setInsertionPoint(operation); rewriter.setInsertionPoint(operation);
return rewriter.create<tensor::EmptyOp>( return tensor::EmptyOp::create(
operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType()); rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -0,0 +1,52 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
*
* The static offsets represent the starting position of the slice in each
* dimension, while the static tensor input gives its dimension size.
*
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
* calculated.
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
/**
 * \brief Count the elements of an iterator range by walking it.
 *
 * Needed because ranges such as mlir::Value's use list expose only
 * begin()/end() and no size() method. Cost is linear in the range length.
 */
template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end());
}
/**
* Retrieves the earliest operation that uses the given value within the value's
* block.
*
* @param value The value for which to find the earliest user operation.
* @return The earliest user operation that uses the given value within the
* current block.
*/
mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
/**
 * \brief Build a tensor.empty op with the same shape and element type as
 * \p shapedType, inserted at the rewriter's current insertion point.
 */
inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
}
/**
 * \brief True if \p op is a concatenation op: either the upstream
 * tensor.concat or the Spatial dialect's image concat.
 */
inline bool isAConcatOp(mlir::Operation* op) {
  // llvm::isa accepts multiple candidate types: one call instead of two
  // isa<> checks OR-ed together, with identical semantics.
  return llvm::isa<mlir::tensor::ConcatOp, spatial::SpatImgConcatOp>(op);
}
} // namespace onnx_mlir

View File

@@ -1,8 +1,10 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@@ -16,20 +18,101 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "SpatialToPIMPass.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace pim; using namespace pim;
using namespace spat_to_pim;
void SpatialToPIMPass::runOnOperation() { namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
// Pass lowering the Spatial dialect (spat.compute bodies, channel ops) into
// the Pim dialect (pim.core regions plus explicit host<->device memory
// copies) on the module's PIM entry function.
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
// Deliberately does not copy per-run state; members are re-initialized in
// runOnOperation (e.g. coreId is set to 1 there).
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
// Host-side destination tensor for each operand of the entry function's
// func.return, filled by addResultBuffer.
SmallVector<Value> outputTensors;
// Next PIM core id to hand out when creating pim.core ops.
size_t coreId = 0;
// Ops superseded during lowering, collected for later erasure.
SmallVector<Operation*> operationsToRemove;
// Record (or reuse) a host result buffer per func.return operand.
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
// Rewrite tensor function arguments into memref arguments bridged by
// bufferization.to_tensor, and insert pim host->device copies for inputs.
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
// Lower a standalone spat.channel_receive: re-wires its users onto channel
// receive ops and may upgrade the matching send to a broadcast send.
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
// Insert receive ops in every compute op that (possibly through view-like
// ops) consumes channelSourceOp's value.
void addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel,
bool useBroadcastOp,
IRRewriter& rewriter);
// Replace a compute block argument with a channel receive, replaying the
// view-op chain from channelSourceOp to consumerValue inside the compute.
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel,
bool useBroadcastOp,
IRRewriter& rewriter);
// Queue op for erasure, avoiding duplicates in operationsToRemove.
void markOpToRemove(Operation* op);
// Convert one spat.compute into a pim.core, wiring results through channels
// or device->host copies.
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
// Grow VMM output tensors to the crossbar size and slice the original shape
// back out for downstream users.
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
// Rewrite func.return operands to refer to the host-side result buffers.
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace
// View-like ops that are allowed to sit between a channel-carried value and
// the compute ops that ultimately consume it; anything else encountered while
// walking a use chain is a hard error (llvm_unreachable in the walkers).
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
op);
}
// Count how many spat.compute ops ultimately consume `value`, looking through
// chains of the view-like ops accepted by isChannelUseChainOp. Callers use
// the count to decide whether a channel needs broadcast semantics (> 1
// compute consumer).
static size_t countComputeLeafUsers(Value value) {
size_t leafUserCount = 0;
// Recursive lambda: it receives itself as the second argument so it can
// recurse without a separately named function.
auto walkUses = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
// A compute op is a leaf consumer.
if (isa<spatial::SpatWeightedCompute>(owner)) {
leafUserCount++;
continue;
}
if (!isChannelUseChainOp(owner))
llvm_unreachable("Channel use chain contains unsupported op");
// Single-result invariant lets us keep walking through the view op.
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
self(owner->getResult(0), self);
}
};
walkUses(value, walkUses);
return leafUserCount;
}
void SpatialToPimPass::runOnOperation() {
coreId = 1; coreId = 1;
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect>(); target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
populateWithGenerated(patterns); populateWithGenerated(patterns);
@@ -39,15 +122,21 @@ void SpatialToPIMPass::runOnOperation() {
return; return;
} }
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin(); auto entryFunc = getPimEntryFunc(moduleOp);
if (!funcOp) if (failed(entryFunc)) {
llvm_unreachable("No FuncOp found in the begin of module"); signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addResultBuffer(returnOp, rewriter); addResultBuffer(returnOp, rewriter);
allocateAndInitializeCoreLocalVariables(funcOp, rewriter); if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
signalPassFailure();
return;
}
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) { for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
operationsToRemove.push_back(receiveOp); operationsToRemove.push_back(receiveOp);
@@ -73,10 +162,10 @@ void SpatialToPIMPass::runOnOperation() {
} }
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "pim"); dumpModule(moduleOp, "pim0");
} }
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc(); Location loc = computeOp->getLoc();
auto& block = computeOp.getRegion().front(); auto& block = computeOp.getRegion().front();
@@ -124,7 +213,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Store to global memory // Store to global memory
Value outputTensor = outputTensors[resultIndexInReturn]; Value outputTensor = outputTensors[resultIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>(loc, PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
yieldValue, yieldValue,
@@ -155,7 +245,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Store to global memory // Store to global memory
Value outputTensor = outputTensors[concatIndexInReturn]; Value outputTensor = outputTensors[concatIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>( PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
@@ -174,23 +264,20 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// 1. Create a new ChannelOp // 1. Create a new ChannelOp
rewriter.setInsertionPoint(computeOp); rewriter.setInsertionPoint(computeOp);
auto channelType = spatial::SpatChannelType::get(computeOp.getContext()); auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType); auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Receive value through the channel // 2. Receive value through the channel. Broadcast is needed whenever the
// If this result is used by more than one user, then use a "Broadcast" // value eventually reaches more than one compute consumer, even through a
// channel operation. However, there is a special case: we have a single // chain of view-like ops.
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this bool useBroadcastOp = countComputeLeafUsers(result) > 1;
// case, we need to use a "Broadcast" channel operation. `addReceiveOps` addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
// will detect this case and update `useBroadcastOp` accordingly.
bool useBroadcastOp = (numResultUses > 1);
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
// 3. Send the value through the channel // 3. Send the value through the channel
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
if (useBroadcastOp) if (useBroadcastOp)
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue); spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
else else
rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue); spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
} }
// Use `HaltOp` instead of `YieldOp` // Use `HaltOp` instead of `YieldOp`
@@ -199,17 +286,17 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Replace `spat.compute` with `pim.core` // Replace `spat.compute` with `pim.core`
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++)); auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
auto& coreOpBlocks = coreOp.getBody().getBlocks(); auto& coreOpBlocks = coreOp.getBody().getBlocks();
block.eraseArguments(0, block.getNumArguments()); block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block(); Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock); computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock); rewriter.setInsertionPointToEnd(tempComputeBlock);
rewriter.create<PimHaltOp>(computeOp.getLoc()); PimHaltOp::create(rewriter, computeOp.getLoc());
} }
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void { auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
auto* definingOp = value.getDefiningOp(); auto* definingOp = value.getDefiningOp();
if (!definingOp) if (!definingOp)
@@ -246,20 +333,20 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr}; SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr}; SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
rewriter.setInsertionPointAfter(vmmOp); rewriter.setInsertionPointAfter(vmmOp);
auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides); auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp}; SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions); resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
} }
}); });
} }
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands()); outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock()); rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) { for (auto returnValue : returnOp->getOperands()) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp(); Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) { if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!returnValueDefiningOp->hasAttr("weightAlways")); assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(returnValue); outputTensors.push_back(returnValue);
} }
else { else {
@@ -270,7 +357,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
} }
} }
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
@@ -279,9 +366,10 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>( auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
rewriter,
loc, loc,
tensorType, tensorType,
deviceTensor, deviceTensor,
@@ -301,16 +389,19 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType()); ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType()); MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc); if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
BlockArgument memRefArg = funcOp.getArgument(i + 1); BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front(); Block& block = funcOp.getBody().front();
rewriter.setInsertionPoint(&block.front()); rewriter.setInsertionPoint(&block.front());
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); auto toTensorOp =
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
inputTensors.push_back(toTensorOp); inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp); tensorArg.replaceAllUsesWith(toTensorOp);
funcOp.eraseArgument(i); if (failed(funcOp.eraseArgument(i)))
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
} }
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove; llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
@@ -324,6 +415,9 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) { if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource()); tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape(); ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets(); ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes(); ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
@@ -357,12 +451,15 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
for (auto sliceOp : sliceOpsToRemove) for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty()) if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp); rewriter.eraseOp(sliceOp);
return success();
} }
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp, bool useBroadcastOp,
IRRewriter& rewriter) { IRRewriter& rewriter) {
auto& computeBlock = computeOp.getRegion().front(); auto& computeBlock = computeOp.getRegion().front();
@@ -375,71 +472,71 @@ void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
Value receivedValue; Value receivedValue;
if (useBroadcastOp) if (useBroadcastOp)
receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel); receivedValue =
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
else else
receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel); receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
blockArg.replaceAllUsesWith(receivedValue); Value replacementValue = receivedValue;
if (consumerValue != channelSourceOp) {
SmallVector<Operation*> clonedChain;
Value currentValue = consumerValue;
while (currentValue != channelSourceOp) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || !isChannelUseChainOp(definingOp))
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
clonedChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
} }
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp, IRMapping mapping;
mapping.map(channelSourceOp, receivedValue);
for (Operation* op : llvm::reverse(clonedChain)) {
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
markOpToRemove(op);
}
replacementValue = cast<Value>(mapping.lookup(consumerValue));
}
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
blockArg.replaceAllUsesWith(replacementValue);
}
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& channelTensorType, bool useBroadcastOp,
bool& useBroadcastOp,
IRRewriter& rewriter) { IRRewriter& rewriter) {
auto sourceOpUses = channelSourceOp.getUses(); auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users Operation* owner = use.getOwner();
if (useBroadcastOp == false) { if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
// if useBroadcastOp is false, then sourceOp must have only one user
assert(rangeLength(sourceOpUses) == 1);
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
auto reshapeOpUses = reshapeOp.getOutput().getUses();
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
if (reshapeOpUsesCount > 1)
useBroadcastOp = true;
}
}
for (auto& resultUse : sourceOpUses) {
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
if (computeUser) {
replaceBlockArgumentWithRecvOp( replaceBlockArgumentWithRecvOp(
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter); computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
continue; continue;
} }
if (!computeUser) { if (!isChannelUseChainOp(owner))
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner()); llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
if (!reshapeOp) {
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump(); markOpToRemove(owner);
resultUse.getOwner()->dump(); assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp"); self(owner->getResult(0), self);
}
};
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
} }
// The tensorType now becomes the one of the reshapeOp void SpatialToPimPass::markOpToRemove(Operation* op) {
channelTensorType = reshapeOp.getResult().getType(); if (!llvm::is_contained(operationsToRemove, op))
operationsToRemove.push_back(op);
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
if (!computeUser)
llvm_unreachable("ReshapeOp users must be ComputeOps");
replaceBlockArgumentWithRecvOp(
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
} }
// Remove the reshapeOp, so that the sourceOp has no users void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
operationsToRemove.push_back(reshapeOp);
}
}
}
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
for (auto it : llvm::enumerate(returnOp.getOperands())) { for (auto it : llvm::enumerate(returnOp.getOperands())) {
Operation* returnOperand = it.value().getDefiningOp(); Operation* returnOperand = it.value().getDefiningOp();
@@ -458,7 +555,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
} }
} }
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp()); auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
@@ -468,15 +565,10 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt); auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
auto tensorType = receiveOp.getType();
Value receiveRes = receiveOp.getResult(); Value receiveRes = receiveOp.getResult();
// Check if the receiveOp value has more than one user bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
auto receiveUses = receiveRes.getUses(); addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
auto receiveUsesCount = rangeLength(receiveUses);
assert(receiveUsesCount > 0);
bool useBroadcastOp = receiveUsesCount > 1;
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
if (useBroadcastOp) { if (useBroadcastOp) {
// When receiving, we actually noticed that the value has more than one // When receiving, we actually noticed that the value has more than one
@@ -486,3 +578,7 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData()); rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
} }
} }
/// Factory for the Spatial-to-Pim lowering pass (CLI: "convert-spatial-to-pim").
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
} // namespace onnx_mlir

View File

@@ -1,2 +1,2 @@
add_subdirectory(PIM) add_subdirectory(Pim)
add_subdirectory(Spatial) add_subdirectory(Spatial)

View File

@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void registerOpBufferizationInterfaces(DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -1,34 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "Dialect/PIM/PimOps.hpp"
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace pim {
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace pim
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<pim::PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -175,7 +175,33 @@ def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
}]; }];
} }
// Computation // Algebra
// Tensor transpose according to the `perms` attribute, written into the
// caller-provided destination buffer (destination-passing style).
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
let description = [{
Matrix transpose
}];
let arguments = (ins
PimTensor: $data,
I64ArrayAttr: $perms,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
// DestinationStyleOpInterface hook: outBuf is the DPS init tied to outRes.
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
// `perms` is an attribute, so it is printed via attr-dict; only the two
// tensor operands appear positionally.
let assemblyFormat = [{
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{ let description = [{
@@ -197,6 +223,10 @@ def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
return getOutBufMutable(); return getOutBufMutable();
} }
}]; }];
let assemblyFormat = [{
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
}];
} }
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> { def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {

View File

@@ -10,7 +10,7 @@
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -20,7 +20,7 @@ namespace pim {
void PimDialect::initialize() { void PimDialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
>(); >();
} }
@@ -45,5 +45,5 @@ POPULATE_DEPENDENCIES(PimVExpOp)
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"

View File

@@ -12,7 +12,7 @@
#include <string> #include <string>
/// Include the auto-generated header files containing the declarations /// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"

View File

@@ -3,7 +3,6 @@ mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen) add_public_tablegen_target(PimBufferizationIncGen)
add_onnx_mlir_library(OMPimBufferization add_onnx_mlir_library(OMPimBufferization
PimBufferizationPass.hpp
PimBufferizationPass.cpp PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp OpBufferizationInterfaces.cpp
@@ -14,7 +13,7 @@ add_onnx_mlir_library(OMPimBufferization
PimBufferizationIncGen PimBufferizationIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
OMPIMCommon OMPimCommon
PimOps PimOps
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE

View File

@@ -1,4 +1,4 @@
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -2,12 +2,10 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref); mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp" #include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
using namespace bufferization; using namespace bufferization;
@@ -76,6 +76,32 @@ struct MemCopyDevToHostOpInterface
} }
}; };
// External bufferization model for pim.transpose: resolves the tensor
// operands to buffers and re-creates the op on memrefs.
struct TransposeOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
// Every non-init operand (i.e. $data) is a read; the DPS init ($outBuf)
// is only written.
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
// Replace the tensor-level op with a memref-level pim.transpose whose
// result type is the bufferized out-buffer's type.
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto transposeOp = cast<PimTransposeOp>(op);
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
if (failed(dataOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
// The perms attribute is carried over unchanged.
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
return success();
}
};
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> { struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -176,6 +202,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx); PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx); PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx); PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx); PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);

View File

@@ -0,0 +1,13 @@
#pragma once

#include "mlir/IR/DialectRegistry.h"

#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

namespace onnx_mlir {
namespace pim {

/// Attaches the external bufferization interface models for the Pim dialect
/// ops to the given registry (called during dialect registration).
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);

} // namespace pim
} // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
#ifndef OP_BASE #ifndef OP_BASE
include "mlir/IR/PatternBase.td" include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td" include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def memrefCopyToPimMemCopyOp : Pat< def memrefCopyToPimMemCopyOp : Pat<

View File

@@ -5,14 +5,39 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "Common/PIMCommon.hpp" #include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "PimBufferizationPass.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace pim; using namespace pim;
namespace onnx_mlir {
namespace {
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
// Module pass driving one-shot bufferization of Pim/Spatial tensor ops and
// annotating constant-weight memrefs afterwards (see runOnOperation).
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
// Copy ctor required by PassWrapper cloning; the pass holds no state to copy.
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;

private:
// Marks memref.get_global / memref.global pairs used exclusively inside
// pim.core regions as always-resident weights.
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace
void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation(); auto moduleOp = getOperation();
@@ -64,19 +89,22 @@ void PimBufferizationPass::runOnOperation() {
annotateWeightsMemrefs(moduleOp, funcOp); annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "pim_buf"); dumpModule(moduleOp, "pim1_buff");
} }
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = !getGlobalOp->getUsers().empty() bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }); && all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
if (isAlwaysWeight) { if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
assert("Weights must be constants" && globalMemrefOp.getConstant()); assert("Weights must be constants" && globalMemrefOp.getConstant());
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx)); markWeightAlways(getGlobalOp);
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx)); markWeightAlways(globalMemrefOp);
} }
}); });
} }
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -7,6 +7,7 @@ add_onnx_mlir_library(SpatialOps
Transforms/SpatialBufferizableOpInterface.cpp Transforms/SpatialBufferizableOpInterface.cpp
DEPENDS DEPENDS
OMONNXIncGen
OMSpatialIncGen OMSpatialIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC

View File

@@ -25,7 +25,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() {
LogicalResult SpatImgConcatOp::verify() { LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType()); auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape); size_t img_w = getImageWidth(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape); size_t img_h = getImageHeight(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape); size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize; size_t channelTileRest = img_c % crossbarSize;
@@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() {
return emitError("Invalid input type, must be ShapedType"); return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1 // N == W == H == 1
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1) if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1"); return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape); size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct: // Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that // - CASE1: last tile of pixel, if there is some rest it must match that
@@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() {
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) { Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands(); auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType()); auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape); size_t img_w = getImageWidth(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape); size_t img_h = getImageHeight(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape); size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());

View File

@@ -17,8 +17,8 @@
#include <cstdint> #include <cstdint>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -34,7 +34,7 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType()); auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
// Alloc an output memref // Alloc an output memref
return rewriter.create<memref::AllocOp>(loc, memrefResultType); return memref::AllocOp::create(rewriter, loc, memrefResultType);
} }
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
@@ -134,7 +134,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
memrefOperands.push_back(outputTensor); memrefOperands.push_back(outputTensor);
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes(); Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue); replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -169,10 +169,12 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
// Alloc an output memref // Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
Value newValue = Value newValue = ToTy::create(rewriter,
rewriter op->getLoc(),
.create<ToTy>( outputTensor.getType(),
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor) cast<OpTy>(op).getWeightIndexAttr(),
memrefOperand,
outputTensor)
.getOutRes(); .getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue); replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -213,8 +215,8 @@ struct ChannelReceiveOpInterface
if (failed(srcCoreId)) if (failed(srcCoreId))
return failure(); return failure();
Value newValue = rewriter Value newValue = pim::PimReceiveOp::create(rewriter,
.create<pim::PimReceiveOp>(op->getLoc(), op->getLoc(),
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize), rewriter.getI32IntegerAttr(numElements * elementSize),
@@ -300,7 +302,8 @@ struct ChannelBroadcastReceiveOpInterface
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements(); auto outputType = cast<ShapedType>(outputTensor.getType());
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>(); auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) { if (!channelNewOp) {
@@ -323,7 +326,8 @@ struct ChannelBroadcastReceiveOpInterface
} }
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(), auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
bufferAllocation, bufferAllocation,
@@ -356,7 +360,8 @@ struct ChannelBroadcastSendOpInterface
} }
/* /*
* Turn the channel send to pim.send * Turn the channel send into a device-to-host copy into the shared
* broadcast buffer that receive ops load from later.
*/ */
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
RewriterBase& rewriter, RewriterBase& rewriter,
@@ -389,8 +394,19 @@ struct ChannelBroadcastSendOpInterface
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter); bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
} }
auto srcType = cast<ShapedType>(srcTensor.getType());
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef}); pim::PimMemCopyDevToHostOp::create(rewriter,
op->getLoc(),
bufferAllocation.getType(),
bufferAllocation,
srcMemRef,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes));
rewriter.eraseOp(op);
return success(); return success();
} }
}; };
@@ -469,7 +485,8 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr(); auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr(); auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(), Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
op->getLoc(),
outputTensor.getType(), outputTensor.getType(),
weightIndices, weightIndices,
xKernelPositions, xKernelPositions,

View File

@@ -4,14 +4,12 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry); void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry); void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,7 +1,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"

View File

@@ -1,6 +1,6 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "Common/PIMCommon.hpp" #include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -0,0 +1,618 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Follows a chain of memref.cast ops back to the original (uncast) value.
static Value stripMemRefCasts(Value value) {
  Value current = value;
  while (true) {
    auto castOp = current.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return current;
    current = castOp.getSource();
  }
}
// Creates a private, constant memref.global at the top of the module holding
// `denseAttr`, with a symbol name derived from `nameStem` (suffixed with
// "_N" until unique). Returns the created global.
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment = {}) {
// Find a module-unique symbol name: nameStem, nameStem_1, nameStem_2, ...
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
// NOTE(review): this uses a fresh OpBuilder rather than the surrounding
// PatternRewriter, so the new global is not tracked by the rewriter —
// fine for globals outside the matched region, but worth confirming.
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(moduleBuilder,
loc,
globalName,
visibility,
globalType,
denseAttr,
/*constant=*/true,
alignment);
}
// If `value` (possibly behind memref.cast chains) comes from a
// memref.get_global whose global is a constant with a dense initializer,
// returns that DenseElementsAttr; otherwise fails.
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return failure();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
// Only constant globals with an initial value can be folded.
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
return denseAttr;
}
// Constant-folds a transpose of a dense ranked-tensor attribute: returns a
// new DenseElementsAttr whose dim d equals the input's dim perms[d]. Fails
// on unranked types or invalid/non-bijective permutations.
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
// Validate that perms is a permutation of [0, rank) while building the
// transposed shape.
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
// A splat is invariant under transposition: just retype it.
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
// Row-major strides for both layouts (innermost stride is 1).
SmallVector<int64_t> originalStrides(rank, 1);
SmallVector<int64_t> transposedStrides(rank, 1);
for (int64_t dim = rank - 2; dim >= 0; --dim) {
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
}
SmallVector<int64_t> originalIndices(rank);
SmallVector<int64_t> transposedIndices(rank);
// For each element: delinearize its index in the original layout, permute
// the index vector, and re-linearize it in the transposed layout.
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
// out[i0,...] = in[perms[0] index, ...]: output dim d reads input dim perms[d].
for (int64_t dim = 0; dim < rank; ++dim)
transposedIndices[dim] = originalIndices[perms[dim]];
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
// Records one memref.copy of constant data into a subview of an alloc:
// the constant source, the subview's static offsets/strides, and the copy op
// itself (so the caller can erase it after folding).
struct ConstantSubviewCopy {
// Dense constant being copied in.
DenseElementsAttr source;
// Static offsets of the target subview into the alloc.
SmallVector<int64_t> offsets;
// Static strides of the target subview.
SmallVector<int64_t> strides;
// The memref.copy op that performed the write.
Operation* copyOp = nullptr;
};
/// If `mapOp` takes no inputs and its body yields a single constant value,
/// returns that constant attribute; otherwise fails.
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
  // A map with inputs cannot be a pure fill.
  if (!mapOp.getInputs().empty())
    return failure();
  Operation* terminator = mapOp.getMapper().front().getTerminator();
  auto yieldOp = dyn_cast<linalg::YieldOp>(terminator);
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return failure();
  Attribute constantAttr;
  if (matchPattern(yieldOp.getValues().front(), m_Constant(&constantAttr)))
    return constantAttr;
  return failure();
}
/// Folds a constant-yielding linalg.map (a buffer fill, typically produced
/// by lowering tensor.pad) that lives inside a pim.core region: the fill
/// value is materialized as a constant memref.global on the host, and the
/// map is replaced by a whole-buffer host-to-device pim.memcp.
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
    // Only rewrite fills that execute inside a pim.core region.
    auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
    if (!coreOp)
      return failure();

    auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
    if (!initType || !initType.hasStaticShape())
      return failure();

    // The map must yield a single constant (no inputs).
    auto fillValue = getConstantMapYield(mapOp);
    if (failed(fillValue))
      return failure();

    auto moduleOp = mapOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // BUGFIX: all match checks must precede any IR creation. The previous
    // version created the memref.global and memref.get_global first and only
    // then bailed out on sub-byte element types, leaving orphaned ops behind
    // on failure() — the greedy pattern driver forbids mutating IR when a
    // pattern does not succeed.
    size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
    if (elementByteWidth == 0)
      return failure();
    size_t totalBytes = initType.getNumElements() * elementByteWidth;

    // Materialize the fill as a constant global on the host.
    auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
    DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
    auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");

    // The get_global must sit on the host side, i.e. before the core op.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(coreOp);
    auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());

    // Replace the in-core fill with a host-to-device copy of the constant.
    rewriter.setInsertionPoint(mapOp);
    pim::PimMemCopyHostToDevOp::create(rewriter,
        mapOp.getLoc(),
        initType,
        mapOp.getInit(),
        getGlobalOp.getResult(),
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
    rewriter.eraseOp(mapOp);
    return success();
  }
};
// Fully-static description of a memref.subview: the underlying source memref
// (with casts stripped), its shape, and the subview's constant offsets,
// sizes, and strides.
struct StaticSubviewInfo {
// Source memref the subview slices into.
Value source;
// Static shape of the source memref.
SmallVector<int64_t> sourceShape;
// Constant per-dimension offsets of the subview.
SmallVector<int64_t> offsets;
// Constant per-dimension sizes of the subview.
SmallVector<int64_t> sizes;
// Constant per-dimension strides of the subview.
SmallVector<int64_t> strides;
};
/// If `value` is produced by a memref.subview with fully static source and
/// result types and all-constant offsets/sizes/strides, returns its
/// StaticSubviewInfo; otherwise fails.
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
  if (!subviewOp)
    return failure();

  auto source = stripMemRefCasts(subviewOp.getSource());
  auto sourceType = dyn_cast<MemRefType>(source.getType());
  auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
  if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
    return failure();

  StaticSubviewInfo info;
  info.source = source;
  info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());

  // Extract the constant value of every OpFoldResult in `mixed`; fail on
  // any dynamic entry.
  auto collectStatic = [](ArrayRef<OpFoldResult> mixed, SmallVector<int64_t>& out) -> bool {
    for (OpFoldResult entry : mixed) {
      auto constant = getConstantIntValue(entry);
      if (!constant)
        return false;
      out.push_back(*constant);
    }
    return true;
  };

  if (!collectStatic(subviewOp.getMixedOffsets(), info.offsets)
      || !collectStatic(subviewOp.getMixedSizes(), info.sizes)
      || !collectStatic(subviewOp.getMixedStrides(), info.strides))
    return failure();
  return info;
}
// Byte offset (into the subview's source memref) of the contiguous innermost
// slice addressed by `outerIndices` — one index per dimension except the
// last. For each outer dim the source index is offset + index * stride; the
// innermost dim uses only its offset (callers only split non-contiguous
// copies with unit strides, so the slice starts at offsets.back()).
// NOTE(review): assumes outerIndices.size() == sourceShape.size() - 1 and
// row-major source layout — confirm against the callers.
static int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
// Splits a pim.memcp inside a pim.core whose source and/or destination is a
// non-contiguous static subview into one pim.memcp per contiguous innermost
// slice, addressing the underlying source memref with explicit byte offsets.
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
// Only sides backed by a fully-static subview can be split.
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
// Contiguous on both sides: nothing to split.
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
// Only unit-stride subviews are supported (each innermost row is then a
// single contiguous run).
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
// Both sides must agree on the logical copy shape.
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
// The original copy must cover exactly the whole subview.
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (copyOp.getSize() != totalBytes)
return failure();
// Bytes per contiguous innermost slice.
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
// Iterate over all outer-index combinations; each yields one slice copy.
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
// Split sides address the underlying source memref with a computed byte
// offset; contiguous sides just advance linearly slice by slice.
const int64_t srcByteOffset = copyOp.getSrcOffset()
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = copyOp.getDstOffset()
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : copyOp.getDst(),
splitSrc ? srcSubview->source : copyOp.getSrc(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
}
// The original (whole-subview) copy is subsumed by the per-slice copies.
rewriter.replaceOp(copyOp, copyOp.getDst());
return success();
}
};
/// Attempts to prove that the buffer created by `allocOp` holds a compile-time
/// constant value and, if so, computes that value as a DenseElementsAttr.
///
/// Pattern recognized (over all users of the alloc and of its memref.cast
/// aliases):
///   - linalg.map ops initializing the buffer with a single constant yield
///     (all such maps must agree on the same fill value),
///   - memref.subview ops with static offsets/strides whose only users are
///     memref.copy ops copying from constant globals into the subview,
///   - pim.core and memref.dealloc users (pure consumers; ignored),
///   - memref.cast users (followed transitively as additional aliases).
/// Any other user aborts the fold. The result is the fill value splatted over
/// the whole buffer with each constant subview copy overlaid in program order
/// (later copies overwrite earlier ones).
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
  auto allocType = dyn_cast<MemRefType>(allocOp.getType());
  if (!allocType || !allocType.hasStaticShape())
    return failure();
  // Model the folded contents as a ranked tensor with the alloc's shape.
  auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
  const int64_t numElements = resultTensorType.getNumElements();
  if (numElements < 0)
    return failure();
  Attribute fillValue;
  SmallVector<ConstantSubviewCopy> copies;
  // Worklist over the alloc and any memref.cast aliases; visitedAliases
  // guards against analyzing the same user twice.
  llvm::SmallPtrSet<Operation*, 8> visitedAliases;
  SmallVector<Value> pendingAliases;
  pendingAliases.push_back(allocOp.getResult());
  while (!pendingAliases.empty()) {
    Value alias = pendingAliases.pop_back_val();
    for (Operation* user : alias.getUsers()) {
      if (!visitedAliases.insert(user).second)
        continue;
      if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
        // The map must write into the buffer (init operand), not read it.
        if (mapOp.getInit() != alias)
          return failure();
        auto maybeFillValue = getConstantMapYield(mapOp);
        if (failed(maybeFillValue))
          return failure();
        // Multiple fills are only foldable if they agree on one value.
        if (fillValue && fillValue != *maybeFillValue)
          return failure();
        fillValue = *maybeFillValue;
        continue;
      }
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
        // Only statically known offsets/strides can be folded.
        SmallVector<int64_t> offsets;
        SmallVector<int64_t> strides;
        offsets.reserve(subviewOp.getMixedOffsets().size());
        strides.reserve(subviewOp.getMixedStrides().size());
        for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
          auto staticOffset = getConstantIntValue(offset);
          if (!staticOffset)
            return failure();
          offsets.push_back(*staticOffset);
        }
        for (OpFoldResult stride : subviewOp.getMixedStrides()) {
          auto staticStride = getConstantIntValue(stride);
          if (!staticStride)
            return failure();
          strides.push_back(*staticStride);
        }
        // Every user of the subview must be a memref.copy from a constant
        // global into the subview.
        for (Operation* subviewUser : subviewOp->getUsers()) {
          if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
            if (copyOp.getTarget() != subviewOp.getResult())
              return failure();
            auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
            if (failed(denseAttr))
              return failure();
            copies.push_back({*denseAttr, offsets, strides, copyOp});
            continue;
          }
          return failure();
        }
        continue;
      }
      // Pure consumers of the buffer; they do not affect its contents.
      if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
        continue;
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        pendingAliases.push_back(castOp.getResult());
        continue;
      }
      return failure();
    }
  }
  // Without a fill the buffer would be partially undefined.
  if (!fillValue)
    return failure();
  SmallVector<Attribute> resultValues(numElements, fillValue);
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
  // Replay copies in program order so later copies overwrite earlier ones.
  llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
    return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
  });
  for (const ConstantSubviewCopy& copy : copies) {
    auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
        || sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
      return failure();
    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
    for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
      // Map each source element to its destination position, per dimension:
      //   dest[d] = offset[d] + src[d] * stride[d]
      SmallVector<int64_t> sourceIndices =
          delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
      SmallVector<int64_t> resultIndices;
      resultIndices.reserve(sourceIndices.size());
      for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
        resultIndices.push_back(offset + sourceIndex * stride);
      int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
      resultValues[resultLinearIndex] = value;
    }
  }
  return DenseElementsAttr::get(resultTensorType, resultValues);
}
/// Folds a pim.transpose of a constant memref.global by materializing a new
/// global that holds the pre-transposed data.
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
    auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto outType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
    if (!outType || !outType.hasStaticShape())
      return failure();
    // The transposed data must come straight from a constant global.
    auto dataGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
    if (!dataGetGlobal)
      return failure();
    auto dataGlobal = lookupGlobalForGetGlobal(moduleOp, dataGetGlobal);
    if (!dataGlobal || !dataGlobal.getConstant() || !dataGlobal.getInitialValue())
      return failure();
    auto initAttr = dyn_cast<DenseElementsAttr>(*dataGlobal.getInitialValue());
    if (!initAttr)
      return failure();
    // Collect the static permutation from the op's attribute.
    SmallVector<int64_t> permutation;
    permutation.reserve(transposeOp.getPerms().size());
    for (IntegerAttr permAttr : transposeOp.getPerms().getAsRange<IntegerAttr>())
      permutation.push_back(permAttr.getInt());
    FailureOr<DenseElementsAttr> foldedAttr = transposeDenseElements(initAttr, permutation);
    if (failed(foldedAttr))
      return failure();
    // The folded data must match the declared result shape exactly.
    auto foldedShape = cast<RankedTensorType>(foldedAttr->getType()).getShape();
    if (!llvm::equal(foldedShape, outType.getShape()))
      return failure();
    auto foldedGlobal = createFoldedGlobal(moduleOp,
                                           transposeOp.getLoc(),
                                           outType,
                                           *foldedAttr,
                                           dataGlobal.getName().str() + "__folded_transpose",
                                           dataGlobal.getAlignmentAttr());
    rewriter.setInsertionPoint(transposeOp);
    auto foldedGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), outType, foldedGlobal.getName());
    // When every consumer is a pim.core op, tag the data as always-resident
    // weight storage.
    const bool feedsOnlyCores =
        !transposeOp->getUsers().empty()
        && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
    if (feedsOnlyCores) {
      markWeightAlways(foldedGlobal);
      markWeightAlways(foldedGetGlobal);
    }
    rewriter.replaceOp(transposeOp, foldedGetGlobal.getResult());
    return success();
  }
};
/// Replaces a memref.alloc whose contents are provably constant (as computed
/// by foldConstantAlloc) with a constant memref.global + memref.get_global,
/// erasing the now-dead initializer ops.
///
/// Fix: the original implementation created the new global and get_global
/// *before* validating the alloc's users and could then `return failure()`,
/// mutating the module on a failure path (leaking dead globals) in violation
/// of the rewrite-pattern contract. All user classification now happens
/// before any IR is created; the redundant second user-type check is gone.
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
    auto moduleOp = allocOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
    if (failed(foldedAttr))
      return failure();
    // Classify users first so every failure path is taken before new IR is
    // created.
    SmallVector<Operation*> opsToErase;        // initializers made dead by the fold
    SmallVector<memref::CastOp> castsToReplace;
    for (Operation* user : allocOp->getUsers()) {
      if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
        opsToErase.push_back(user);
        continue;
      }
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        castsToReplace.push_back(castOp);
        continue;
      }
      // Any user other than the kinds above or a core launch blocks the fold.
      if (!isa<pim::PimCoreOp>(user))
        return failure();
    }
    // The folded data may be tagged as always-resident weights only if every
    // live (non-erased) user ultimately is a pim.core op.
    const bool allLiveUsersAreCoreOps = llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
      return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
    });
    auto allocType = cast<MemRefType>(allocOp.getType());
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");
    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCoreOps) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }
    // Redirect all remaining uses to the new get_global, preserving the users
    // that are about to be rewritten or erased below.
    llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
    for (memref::CastOp castOp : castsToReplace)
      preservedUsers.insert(castOp);
    rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);
    // Re-create each cast on top of the folded global so its users follow.
    for (memref::CastOp castOp : castsToReplace) {
      rewriter.setInsertionPoint(castOp);
      Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
      rewriter.replaceOp(castOp, replacementCast);
      if (allLiveUsersAreCoreOps)
        markWeightAlways(replacementCast.getDefiningOp());
    }
    // Erase the dead initializers; a subview's users (the copies) go first.
    for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
        for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
          rewriter.eraseOp(subviewUser);
      if (op->use_empty())
        rewriter.eraseOp(op);
    }
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
/// Module pass that greedily applies canonicalization plus the PIM-specific
/// constant-folding patterns, then dumps the resulting module.
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)

  StringRef getArgument() const override { return "pim-constant-folding-pass"; }
  StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }

  LogicalResult initialize(MLIRContext* context) override {
    // Gather canonicalization patterns from every loaded dialect and every
    // registered op, then append the PIM folding patterns, and freeze the set
    // once so runOnOperation can reuse it.
    RewritePatternSet collectedPatterns(context);
    for (auto* dialect : context->getLoadedDialects())
      dialect->getCanonicalizationPatterns(collectedPatterns);
    for (RegisteredOperationName opName : context->getRegisteredOperations())
      opName.getCanonicalizationPatterns(collectedPatterns, context);
    collectedPatterns.add<FoldConstantTransposePattern,
                          FoldConstantAllocPattern,
                          FoldConstantCoreMapPattern,
                          RewriteCoreSubviewCopyPattern>(context);
    patterns = std::make_shared<FrozenRewritePatternSet>(std::move(collectedPatterns));
    return success();
  }

  void runOnOperation() override {
    GreedyRewriteConfig config;
    config.enableFolding();
    if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
      signalPassFailure();
      return;
    }
    // Debug dump of the module after folding.
    dumpModule(getOperation(), "pim2_folded");
  }

  // Frozen pattern set shared across pass instances cloned by the pass manager.
  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
/// Factory for the pim-constant-folding-pass, used at pass registration time.
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,175 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Returns true for ops that perform no host-side runtime computation: they
/// only produce constants, storage, views/metadata, or a spatial channel.
static bool isAddressOnlyHostOp(Operation* op) {
  if (isa<arith::ConstantOp, spatial::SpatChannelNewOp>(op))
    return true;
  return isa<memref::AllocOp,
             memref::GetGlobalOp,
             memref::SubViewOp,
             memref::CastOp,
             memref::CollapseShapeOp,
             memref::ExpandShapeOp>(op);
}
/// Walks view-like ops (subview, cast, collapse/expand shape) back to the
/// underlying storage and reports whether that storage is host addressable:
/// a func argument, a memref.alloc, or a memref.get_global.
static bool isHostAddressableValue(Value value) {
  Value current = value;
  while (true) {
    if (auto blockArg = dyn_cast<BlockArgument>(current))
      return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
    Operation* producer = current.getDefiningOp();
    if (!producer)
      return false;
    if (isa<memref::AllocOp, memref::GetGlobalOp>(producer))
      return true;
    // Look through view-like ops to their source; anything else is opaque.
    if (auto subviewOp = dyn_cast<memref::SubViewOp>(producer))
      current = subviewOp.getSource();
    else if (auto castOp = dyn_cast<memref::CastOp>(producer))
      current = castOp.getSource();
    else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(producer))
      current = collapseOp.getSrc();
    else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(producer))
      current = expandOp.getSrc();
    else
      return false;
  }
}
/// Verifies that after PIM bufferization no host-side runtime work remains:
/// each op left on the host must be a pim.core launch, a return of
/// host-addressable storage, or a pure address/metadata op.
///
/// Fix: the original only walked `funcOp.getBody().front()`, so ops in any
/// non-entry block of a multi-block function silently escaped verification;
/// all blocks are now visited.
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)

  StringRef getArgument() const override { return "verify-pim-host-pass"; }
  StringRef getDescription() const override {
    return "Verify that no runtime host-side code remains in bufferized PIM IR";
  }

  PimHostVerificationPass() {}
  PimHostVerificationPass(const PimHostVerificationPass& pass) {}

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    // Collect every violation before failing so all diagnostics get emitted.
    bool hasFailure = false;
    for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
      if (funcOp.isExternal())
        continue;
      // Walk every block of the function, not just the entry block.
      for (Block& block : funcOp.getBody()) {
        for (Operation& op : block.getOperations()) {
          if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
            if (failed(verifyCoreWeights(moduleOp, coreOp)))
              hasFailure = true;
            continue;
          }
          if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
            if (failed(verifyReturnOp(returnOp)))
              hasFailure = true;
            continue;
          }
          if (!isAddressOnlyHostOp(&op)) {
            op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
                           "fold it to constants or lower it into pim.core");
            hasFailure = true;
            continue;
          }
          if (failed(verifyAddressOnlyHostOp(&op)))
            hasFailure = true;
        }
      }
    }
    if (hasFailure)
      signalPassFailure();
  }

private:
  /// Each pim.core weight must be a memref.get_global of a constant,
  /// initialized global so JSON codegen can embed its data.
  static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
    bool hasFailure = false;
    for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
      auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
      if (!getGlobalOp) {
        coreOp.emitOpError() << "weight #" << weightIndex
                             << " must be materialized as memref.get_global before JSON codegen";
        hasFailure = true;
        continue;
      }
      auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
      if (!globalOp) {
        coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
        hasFailure = true;
        continue;
      }
      if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
        coreOp.emitOpError() << "weight #" << weightIndex
                             << " must come from a constant memref.global with an initial value";
        hasFailure = true;
      }
    }
    return success(!hasFailure);
  }

  /// Returned values must trace back (through view-like ops only) to
  /// host-addressable storage: func arguments, allocs, or globals.
  static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
    bool hasFailure = false;
    for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
      if (!isHostAddressableValue(operand)) {
        returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
        hasFailure = true;
      }
    }
    return success(!hasFailure);
  }

  /// View-like address ops are legal only when their source is itself host
  /// addressable; the remaining address-only ops need no further checks.
  static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
    if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
      return verifyAddressOnlySource(op, subviewOp.getSource());
    if (auto castOp = dyn_cast<memref::CastOp>(op))
      return verifyAddressOnlySource(op, castOp.getSource());
    if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
      return verifyAddressOnlySource(op, collapseOp.getSrc());
    if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
      return verifyAddressOnlySource(op, expandOp.getSrc());
    return success();
  }

  /// Emits a diagnostic on `op` when `source` cannot be resolved to
  /// host-addressable storage.
  static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
    if (isHostAddressableValue(source))
      return success();
    op->emitOpError("depends on a value that still requires host-side execution");
    return failure();
  }
};
} // namespace
/// Factory for the verify-pim-host-pass, used at pass registration time.
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
} // namespace onnx_mlir

View File

@@ -3,23 +3,26 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
#include <string>
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
std::unique_ptr<Pass> createONNXToSpatialPass(); std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
std::unique_ptr<Pass> createSpatialToGraphvizPass(); std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
std::unique_ptr<Pass> createSpatialToPIMPass(); std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<Pass> createBufferizePimPass(); std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<Pass> createEmitPimJsonPass(); std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<Pass> createMessagePass(std::string message); std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
std::unique_ptr<Pass> createCountInstructionPass(); std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
std::unique_ptr<mlir::Pass> createCountInstructionPass();
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,6 +4,7 @@
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -12,8 +13,8 @@
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -40,28 +41,27 @@ PimAccelerator::PimAccelerator()
acceleratorTargets.push_back(this); acceleratorTargets.push_back(this);
} }
PimAccelerator::~PimAccelerator() { delete instance; }
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; } uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }
void PimAccelerator::addPasses(OwningOpRef<ModuleOp>& module, void PimAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
PassManager& pm, mlir::PassManager& pm,
EmissionTargetType& emissionTarget, EmissionTargetType& emissionTarget,
std::string outputNameNoExt) const { std::string outputNameNoExt) const {
LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
addPassesPim(module, pm, emissionTarget, outputNameNoExt); addPassesPim(module, pm, emissionTarget, outputNameNoExt);
} }
void PimAccelerator::registerDialects(DialectRegistry& registry) const { void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");
registry.insert<tensor::TensorDialect>(); registry.insert<mlir::tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>(); registry.insert<mlir::tosa::TosaDialect>();
registry.insert<bufferization::BufferizationDialect>(); registry.insert<mlir::bufferization::BufferizationDialect>();
registry.insert<pim::PimDialect>(); registry.insert<pim::PimDialect>();
registry.insert<spatial::SpatialDialect>(); registry.insert<spatial::SpatialDialect>();
tensor::registerBufferizableOpInterfaceExternalModels(registry); mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry); mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerOpBufferizationInterfaces(registry); pim::registerOpBufferizationInterfaces(registry);
@@ -71,8 +71,10 @@ void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
registerPass(createONNXToSpatialPass); registerPass(createONNXToSpatialPass);
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPIMPass); registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass); registerPass(createBufferizePimPass);
registerPass(createPimConstantFoldingPass);
registerPass(createPimHostVerificationPass);
registerPass(createEmitPimJsonPass); registerPass(createEmitPimJsonPass);
} }
@@ -81,26 +83,26 @@ void PimAccelerator::configurePasses() const {
// TODO: This does nothing for now. // TODO: This does nothing for now.
} }
MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const { mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const {
// Do not convert tensor types to memref types. // Do not convert tensor types to memref types.
return nullptr; return nullptr;
} }
void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const { void PimAccelerator::conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const {
target.addLegalDialect<pim::PimDialect>(); target.addLegalDialect<pim::PimDialect>();
} }
void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns, void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
TypeConverter& typeConverter, mlir::TypeConverter& typeConverter,
MLIRContext* ctx) const { mlir::MLIRContext* ctx) const {
// TODO: Add patterns for conversion // TODO: Add patterns for conversion
} }
void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {} void PimAccelerator::conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const {}
void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns, void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
LLVMTypeConverter& typeConverter, mlir::LLVMTypeConverter& typeConverter,
MLIRContext* ctx) const { mlir::MLIRContext* ctx) const {
// We should not need this, since we offload it all to PIM. // We should not need this, since we offload it all to PIM.
} }

View File

@@ -18,8 +18,6 @@ public:
PimAccelerator(PimAccelerator&) = delete; PimAccelerator(PimAccelerator&) = delete;
void operator=(const PimAccelerator&) = delete; void operator=(const PimAccelerator&) = delete;
~PimAccelerator();
/// Creates an instance on the first invocation. Subsequent invocations /// Creates an instance on the first invocation. Subsequent invocations
/// return the existing instance. /// return the existing instance.
static PimAccelerator* getInstance(); static PimAccelerator* getInstance();

View File

@@ -102,46 +102,13 @@ def gen_c(inputs, outputs, entry, so_name):
if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}} if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}}
""")) """))
# Output printing + optional per-output CSV dump # Optional per-output CSV dump
out_blocks=[]
csv_write_blocks=[] csv_write_blocks=[]
for oi,name,et,shape in outputs: for oi,name,et,shape in outputs:
if et not in DTYPES: if et not in DTYPES:
raise ValueError(f"Unsupported dtype for output '{name}': {et}") raise ValueError(f"Unsupported dtype for output '{name}': {et}")
cty, pfmt, _ = DTYPES[et] cty, pfmt, _ = DTYPES[et]
safe = esc(name) safe = esc(name)
out_blocks.append(textwrap.dedent(f"""
// ---- Output {oi}: "{safe}" ----
{{
OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi});
int64_t rank = omTensorGetRank(t);
int64_t const *shape = omTensorGetShape(t);
long long numel = 1; for (int64_t k=0;k<rank;k++) numel *= shape[k];
{cty} *p = ({cty}*)omTensorGetDataPtr(t);
printf("Output {oi} ('{safe}'): shape=[");
for (int64_t k=0;k<rank;k++) printf("%ld%s",(long)shape[k], (k+1<rank)?",":"");
printf("]\\n");
if (rank == 2) {{
int64_t R = shape[0], C = shape[1];
for (int64_t r=0; r<R; ++r) {{
for (int64_t c=0; c<C; ++c) {{
long long idx = r*C + c;
printf("{pfmt}%s", p[idx], (c+1<C)?", ":"");
}}
printf("\\n");
}}
}} else {{
// Flattened vector with indices
for (long long i=0;i<numel;i++) {{
printf("[%lld]={pfmt}%s", i, p[i], (i+1<numel)?", ":"\\n");
}}
}}
}}
"""))
# Per-output CSV writer into --save-csv-dir
csv_write_blocks.append(textwrap.dedent(f""" csv_write_blocks.append(textwrap.dedent(f"""
if (save_csv_dir) {{ if (save_csv_dir) {{
// Build "DIR/output{oi}_<sanitized name>.csv" // Build "DIR/output{oi}_<sanitized name>.csv"
@@ -227,9 +194,6 @@ int main(int argc, char **argv) {{
OMTensorList *out_list = {entry}(in_list); OMTensorList *out_list = {entry}(in_list);
if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}} if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}}
// ---- Print full outputs ----
{"".join(out_blocks)}
// ---- Optional per-output CSV dump ---- // ---- Optional per-output CSV dump ----
{"".join(csv_write_blocks)} {"".join(csv_write_blocks)}
@@ -240,7 +204,7 @@ int main(int argc, char **argv) {{
}} }}
""" """
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None): def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None, verbose=True):
ins, outs = onnx_io(network_onnx) ins, outs = onnx_io(network_onnx)
out_c = out or "runner.c" out_c = out or "runner.c"
so_abs = os.path.abspath(network_so) so_abs = os.path.abspath(network_so)
@@ -260,6 +224,7 @@ set_target_properties(model_so PROPERTIES IMPORTED_LOCATION {esc(so_abs)})
target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so) target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so)
""" """
pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake) pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake)
if verbose:
print(f"[OK] Wrote {out_c}") print(f"[OK] Wrote {out_c}")
print("[OK] Wrote CMakeLists.txt") print("[OK] Wrote CMakeLists.txt")

Binary file not shown.

Binary file not shown.

View File

@@ -1,12 +1,16 @@
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from colorama import Fore, Style from colorama import Fore, Style
from subprocess_utils import run_command_with_reporter
def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count): def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
crossbar_size, crossbar_count, cwd=None, reporter=None):
# Define the arguments, with the possibility to set crossbar size and count # Define the arguments, with the possibility to set crossbar size and count
args = [ args = [
network_path, network_path,
"-o",
output_base,
"--maccel=PIM", "--maccel=PIM",
"--EmitPimCodegen", "--EmitPimCodegen",
# "--use-experimental-conv-impl=true", # "--use-experimental-conv-impl=true",
@@ -14,16 +18,15 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, cro
f"--crossbar-count={crossbar_count}", f"--crossbar-count={crossbar_count}",
] ]
# Run the executable with the arguments
try: try:
result = subprocess.run( run_command_with_reporter(
[str(raptor_onnx_path)] + [str(arg) for arg in args], [str(raptor_onnx_path)] + [str(arg) for arg in args],
check=True, cwd=cwd,
capture_output=True, reporter=reporter,
text=True,
) )
print(result.stdout + Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL) if reporter is None:
except subprocess.CalledProcessError as e: print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
print(Fore.RED + "Error executing ONNX-MLIR:") except subprocess.CalledProcessError:
print(e.stderr + Style.RESET_ALL) if reporter is None:
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
raise raise

View File

@@ -0,0 +1,76 @@
import errno
import os
import pty
import selectors
import subprocess
# Maximum number of trailing output bytes retained for CalledProcessError.output.
MAX_ERROR_OUTPUT_BYTES = 8192
def _read_chunk(fd, treat_eio_as_eof=False):
try:
return os.read(fd, 4096)
except OSError as exc:
if treat_eio_as_eof and exc.errno == errno.EIO:
return b""
raise
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
    """Stream child output from ``fd`` to stdout while keeping ``reporter`` drawn.

    Retains the last MAX_ERROR_OUTPUT_BYTES of output and, if ``process``
    exits non-zero, raises CalledProcessError carrying that tail as `output`.

    NOTE(review): relies on the reporter's private _clear()/_render() hooks to
    redraw a progress line around each chunk -- confirm against the reporter
    implementation.
    """
    selector = selectors.DefaultSelector()
    recent_output = bytearray()
    try:
        selector.register(fd, selectors.EVENT_READ)
        # The selector map becomes empty once the fd is unregistered at EOF,
        # which ends the loop.
        while selector.get_map():
            for key, _ in selector.select():
                data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof)
                if not data:
                    # EOF (or EIO from a closed pty): stop watching, close fd.
                    selector.unregister(key.fileobj)
                    os.close(key.fileobj)
                    continue
                # Erase the progress line, pass the chunk straight through to
                # our stdout (fd 1), then redraw the progress line.
                reporter._clear()
                os.write(1, data)
                reporter._render()
                recent_output.extend(data)
                # Keep only the most recent bytes for error reporting.
                if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
                    del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
    finally:
        selector.close()
    return_code = process.wait()
    if return_code != 0:
        raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
def run_command_with_reporter(cmd, cwd=None, reporter=None):
    """Run ``cmd``, raising CalledProcessError on a non-zero exit.

    Without a reporter the command simply inherits this process's stdio.
    With a reporter, the child's combined stdout/stderr is streamed through
    ``_stream_output`` so the reporter's progress line stays drawn; a pty is
    preferred (keeps the child's line-buffered/colored output) with a pipe
    fallback when no pty is available.
    """
    if reporter is None:
        subprocess.run(cmd, cwd=cwd, check=True)
        return
    try:
        master_fd, slave_fd = pty.openpty()
    except OSError:
        # No pty available (e.g. restricted environment): fall back to a pipe.
        process = subprocess.Popen(
            cmd,
            cwd=cwd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        assert process.stdout is not None
        _stream_output(process.stdout.fileno(), process, reporter)
        return
    try:
        try:
            process = subprocess.Popen(
                cmd,
                cwd=cwd,
                stdout=slave_fd,
                stderr=slave_fd,
            )
        finally:
            # The child holds its own copy of the slave end; ours must be
            # closed so EOF becomes observable on the master end.
            os.close(slave_fd)
    except BaseException:
        # Fix: the original leaked master_fd when Popen itself raised
        # (only slave_fd was closed). On success _stream_output closes it.
        os.close(master_fd)
        raise
    _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)

View File

@@ -1,22 +1,58 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import shlex
import signal
import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from colorama import Style, Fore from colorama import Style, Fore
from validate_one import validate_network from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
def format_command(cmd):
if isinstance(cmd, (list, tuple)):
return shlex.join(str(arg) for arg in cmd)
return str(cmd)
def format_return_status(returncode):
if returncode < 0:
signal_num = -returncode
try:
signal_name = signal.Signals(signal_num).name
except ValueError:
signal_name = "UNKNOWN"
return f"Program terminated by signal {signal_name} ({signal_num})."
return f"Program exited with code {returncode}."
def print_validation_error(reporter, rel, exc):
reporter.suspend()
print(Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL,
file=sys.stderr, flush=True)
if isinstance(exc, subprocess.CalledProcessError):
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
print("Retry command:", file=sys.stderr, flush=True)
print(format_command(exc.cmd), file=sys.stderr, flush=True)
else:
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
print("=" * 72, file=sys.stderr, flush=True)
reporter.resume()
def main(): def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.") ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", required=True, help="Path to the Raptor compiler binary.") ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
ap.add_argument("--onnx-include-dir", required=True, help="Path to OnnxMlirRuntime include directory.") ap.add_argument("--onnx-include-dir", help="Path to OnnxMlirRuntime include directory.")
ap.add_argument("--operations-dir", default=None, help="Root of the operations tree (default: operations).") ap.add_argument("--operations-dir", default=None, help="Root of the operations tree (default: operations).")
ap.add_argument("--simulator-dir", default=None, ap.add_argument("--simulator-dir", default=None,
help="Path to pim-simulator crate root (default: auto-detected relative to script).") help="Path to pim-simulator crate root (default: auto-detected relative to script).")
ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.") ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
ap.add_argument("--crossbar-size", type=int, default=64) ap.add_argument("--crossbar-size", type=int, default=64)
ap.add_argument("--crossbar-count", type=int, default=8) ap.add_argument("--crossbar-count", type=int, default=8)
ap.add_argument("--clean", action="store_true",
help="Remove generated validation artifacts under each model workspace and exit.")
a = ap.parse_args() a = ap.parse_args()
script_dir = Path(__file__).parent.resolve() script_dir = Path(__file__).parent.resolve()
@@ -34,32 +70,67 @@ def main():
print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL) print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
sys.exit(1) sys.exit(1)
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate.\n" + Style.RESET_ALL) if a.clean:
removed_count = 0
for onnx_path in onnx_files:
removed_count += len(clean_workspace_artifacts(onnx_path.parent, onnx_path.stem))
print(Style.BRIGHT + f"Removed {removed_count} generated artifact path(s)." + Style.RESET_ALL)
sys.exit(0)
missing_args = []
if not a.raptor_path:
missing_args.append("--raptor-path")
if not a.onnx_include_dir:
missing_args.append("--onnx-include-dir")
if missing_args:
ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
print(f"Operations root: {operations_dir}")
print("=" * 72)
results = {} # relative_path -> passed results = {} # relative_path -> passed
for onnx_path in onnx_files: reporter = ProgressReporter(len(onnx_files))
for index, onnx_path in enumerate(onnx_files, start=1):
rel = onnx_path.relative_to(operations_dir) rel = onnx_path.relative_to(operations_dir)
header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}" try:
print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL)
passed = validate_network( passed = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir, onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold, threshold=a.threshold,
reporter=reporter,
model_index=index,
model_total=len(onnx_files),
) )
results[str(rel)] = passed results[str(rel)] = passed
except subprocess.CalledProcessError as exc:
results[str(rel)] = False
print_validation_error(reporter, rel, exc)
except Exception as exc:
results[str(rel)] = False
print_validation_error(reporter, rel, exc)
reporter.finish()
# Summary # Summary
n_passed = sum(results.values()) n_passed = sum(1 for passed in results.values() if passed)
n_total = len(results) n_total = len(results)
print("\n" + Style.BRIGHT + "=" * 60) status_width = len("Result")
print(" Summary") path_width = max(len("Operation"), *(len(rel) for rel in results))
print("=" * 60 + Style.RESET_ALL) separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
print(separator)
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
print(separator)
for rel, passed in results.items(): for rel, passed in results.items():
status = Fore.GREEN + "PASS" if passed else Fore.RED + "FAIL" plain_status = "PASS" if passed else "FAIL"
print(f" {rel}: {status}" + Style.RESET_ALL) status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
print(Style.BRIGHT + f"\n {n_passed}/{n_total} passed." + Style.RESET_ALL) Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
print(f"| {rel.ljust(path_width)} | {status} |")
print(separator)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
sys.exit(0 if n_passed == n_total else 1) sys.exit(0 if n_passed == n_total else 1)

View File

@@ -2,32 +2,148 @@ import argparse
import json import json
import numpy as np import numpy as np
import subprocess import subprocess
import shutil
import sys
from pathlib import Path from pathlib import Path
from colorama import Style, Fore from colorama import Style, Fore
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
from raptor import compile_with_raptor from raptor import compile_with_raptor
from gen_network_runner import gen_network_runner from gen_network_runner import gen_network_runner
from subprocess_utils import run_command_with_reporter
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir): STAGE_COUNT = 6
subprocess.run([raptor_path, network_onnx_path, "--EmitONNXIR"], check=True) GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
subprocess.run([raptor_path, network_onnx_path], check=True)
parent = network_onnx_path.parent
class ProgressReporter:
    """Single-line terminal progress bar for multi-model validation runs.

    The bar occupies the current terminal line and is redrawn in place with
    ANSI escape sequences. Regular log output is interleaved by clearing the
    bar, printing the message, and re-rendering the bar afterwards.
    """

    def __init__(self, total_models, stages_per_model=STAGE_COUNT):
        self.total_models = total_models
        self.stages_per_model = stages_per_model
        # One step per (model, stage) pair; never 0 so the fill ratio is safe.
        self.total_steps = max(1, total_models * stages_per_model)
        self.completed_steps = 0
        self.current_label = ""
        self.enabled = True
        # Fallback of 100x20 when the terminal size cannot be detected
        # (e.g. output redirected to a file or running under CI).
        self.columns = shutil.get_terminal_size((100, 20)).columns
        self.suspended = False

    def _clear(self):
        # \033[2K erases the whole line; \r returns the cursor to column 0.
        if self.enabled:
            sys.stdout.write("\033[2K\r")

    def _render(self):
        if not self.enabled or self.suspended:
            return
        bar_width = 24
        filled = int(bar_width * self.completed_steps / self.total_steps)
        prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
        if len(prefix_text) > self.columns:
            # Terminal too narrow for the bar: fall back to a bare counter.
            prefix_text = f"{self.completed_steps}/{self.total_steps}"
        label = f" {self.current_label}" if self.current_label else ""
        available_label_width = max(0, self.columns - len(prefix_text))
        label = label[:available_label_width]
        if prefix_text.startswith("["):
            # Colorize: green for the completed portion, cyan for the rest.
            bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
            prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
        else:
            prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
        # Erase the previous render before writing: a shorter line (e.g. a
        # shrinking label) would otherwise leave stale characters behind it.
        sys.stdout.write("\033[2K\r" + prefix + label + Style.RESET_ALL)
        sys.stdout.flush()

    def log(self, message="", color=None):
        """Print a message above the bar, then redraw the bar."""
        if self.enabled:
            self._clear()
        if color:
            print(color + message + Style.RESET_ALL)
        else:
            print(message)
        self._render()

    def set_stage(self, model_index, model_total, model_name, stage_name):
        """Update the bar label to the current model/stage and redraw."""
        self.current_label = f"[{model_index}/{model_total}] {model_name} · {stage_name}"
        self._render()

    def advance(self):
        """Mark one step complete (clamped to total) and redraw."""
        self.completed_steps = min(self.total_steps, self.completed_steps + 1)
        self._render()

    def suspend(self):
        """Hide the bar so external output (e.g. subprocesses) prints cleanly."""
        self.suspended = True
        self._clear()
        sys.stdout.flush()

    def resume(self):
        """Allow rendering again after suspend(); next render redraws the bar."""
        self.suspended = False

    def finish(self):
        """Permanently clear the bar at the end of the run."""
        if self.enabled:
            self.suspended = True
            self._clear()
            sys.stdout.flush()
def run_command(cmd, cwd=None, reporter=None):
    """Run a subprocess via the shared reporter-aware helper."""
    run_command_with_reporter(cmd, reporter=reporter, cwd=cwd)
def clean_workspace_artifacts(workspace_dir, model_stem):
    """Delete generated validation artifacts from one model workspace.

    Removes the known generated directories and the per-model compiler
    byproducts (.onnx.mlir, .so, .tmp); returns the list of paths that
    were actually removed.
    """
    root = Path(workspace_dir)
    removed = []

    def _delete(target):
        # Check symlink/file before dir so a symlink to a directory is
        # unlinked rather than recursively deleted through.
        if target.is_symlink() or target.is_file():
            target.unlink(missing_ok=True)
            removed.append(target)
        elif target.is_dir():
            shutil.rmtree(target)
            removed.append(target)

    candidates = [root / name for name in GENERATED_DIR_NAMES]
    candidates += [root / f"{model_stem}{ext}" for ext in (".onnx.mlir", ".so", ".tmp")]
    for candidate in candidates:
        _delete(candidate)
    return removed
def print_stage(reporter, model_index, model_total, model_name, title):
    """Log a colored stage banner and point the progress label at it."""
    color_by_stage = {
        "Compile ONNX": Fore.BLUE,
        "Build Runner": Fore.MAGENTA,
        "Generate Inputs": Fore.YELLOW,
        "Run Reference": Fore.GREEN,
        "Compile PIM": Fore.CYAN,
        "Run Simulator": Fore.MAGENTA,
        "Compare Outputs": Fore.YELLOW,
    }
    # Unknown stage titles fall back to a neutral white banner.
    banner_color = color_by_stage.get(title, Fore.WHITE)
    reporter.log(Style.BRIGHT + banner_color + f"[{title}]" + Style.RESET_ALL)
    reporter.set_stage(model_index, model_total, model_name, title)
def print_info(reporter, message):
    """Log an indented informational line beneath the stage banner."""
    reporter.log(" {}".format(message))
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
stem = network_onnx_path.stem stem = network_onnx_path.stem
so_path = parent / f"{stem}.so" onnx_ir_base = raptor_dir / stem
mlir_path = parent / f"{stem}.onnx.mlir" runner_base = runner_dir / stem
tmp_path = parent / f"{stem}.tmp" run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"], reporter=reporter)
moved_so = runner_dir / so_path.name run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter)
moved_mlir = raptor_dir / mlir_path.name network_so_path = runner_base.with_suffix(".so")
so_path.rename(moved_so) network_mlir_path = onnx_ir_base.with_suffix(".onnx.mlir")
mlir_path.rename(moved_mlir) onnx_ir_base.with_suffix(".tmp").unlink(missing_ok=True)
tmp_path.unlink(missing_ok=True) return network_so_path, network_mlir_path
return moved_so, moved_mlir
def build_onnx_runner(source_dir, build_dir): def build_onnx_runner(source_dir, build_dir, reporter=None):
subprocess.run(["cmake", source_dir], cwd=build_dir, check=True) run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter)
subprocess.run(["cmake", "--build", ".", "-j"], cwd=build_dir, check=True) run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter)
return build_dir / "runner" return build_dir / "runner"
@@ -41,11 +157,12 @@ def build_dump_ranges(config_path, outputs_descriptor):
return ",".join(ranges) return ",".join(ranges)
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges): def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
subprocess.run( run_command(
["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", ["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir, check=True cwd=simulator_dir,
reporter=reporter,
) )
@@ -64,66 +181,122 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor):
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3): def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3):
all_passed = True all_passed = True
rows = []
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor): for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
csv_name = f"output{oi}_{name}.csv" csv_name = f"output{oi}_{name}.csv"
runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape) runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64)))) max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
passed = max_diff <= threshold passed = max_diff <= threshold
status = Fore.GREEN + "[PASS]" if passed else Fore.RED + "[FAIL]" rows.append((name, f"{max_diff:.6e}", passed))
print(f" {name}: max diff = {max_diff:.6e} {status}" + Style.RESET_ALL)
if not passed: if not passed:
all_passed = False all_passed = False
name_width = max(len("Output"), *(len(name) for name, _, _ in rows))
diff_width = max(len("Max diff"), *(len(diff) for _, diff, _ in rows))
result_width = len("Result")
separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+"
print(separator)
print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |")
print(separator)
for name, diff_text, passed in rows:
status_text = ("PASS" if passed else "FAIL").ljust(result_width)
status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL
print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |")
print(separator)
return all_passed return all_passed
def validate_network(network_onnx_path, raptor_path, onnx_include_dir, def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3): simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3,
reporter=None, model_index=1, model_total=1):
network_onnx_path = Path(network_onnx_path).resolve() network_onnx_path = Path(network_onnx_path).resolve()
raptor_path = Path(raptor_path).resolve() raptor_path = Path(raptor_path).resolve()
onnx_include_dir = Path(onnx_include_dir).resolve() onnx_include_dir = Path(onnx_include_dir).resolve()
simulator_dir = Path(simulator_dir).resolve() simulator_dir = Path(simulator_dir).resolve()
owns_reporter = reporter is None
reporter = reporter or ProgressReporter(model_total)
workspace_dir = network_onnx_path.parent workspace_dir = network_onnx_path.parent
clean_workspace_artifacts(workspace_dir, network_onnx_path.stem)
raptor_dir = workspace_dir / "raptor" raptor_dir = workspace_dir / "raptor"
runner_dir = workspace_dir / "runner" runner_dir = workspace_dir / "runner"
runner_build_dir = runner_dir / "build" runner_build_dir = runner_dir / "build"
Path.mkdir(raptor_dir, exist_ok=True) Path.mkdir(raptor_dir, exist_ok=True)
Path.mkdir(runner_build_dir, parents=True, exist_ok=True) Path.mkdir(runner_build_dir, parents=True, exist_ok=True)
print(Style.BRIGHT + "\nCompiling the onnx network:" + Style.RESET_ALL) reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
network_so_path, network_mlir_path = compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir) f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
failed_with_exception = False
print(Style.BRIGHT + "\nGenerating and building the runner:" + Style.RESET_ALL) try:
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c") print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
runner_path = build_onnx_runner(runner_dir, runner_build_dir) network_so_path, network_mlir_path = compile_onnx_network(
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter)
print_info(reporter, f"MLIR saved to {network_mlir_path}")
print_info(reporter, f"Shared library saved to {network_so_path}")
reporter.advance()
print(Style.BRIGHT + "\nGenerating random inputs:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
print_info(reporter, f"Runner built at {runner_path}")
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Generate Inputs")
inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path) inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path)
inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor) inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor)
flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs") flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs")
print_info(reporter, f"Saved {len(inputs_list)} input file(s) to {workspace_dir / 'inputs'}")
reporter.advance()
print(Style.BRIGHT + "\nRunning inference with the runner:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Reference")
out_dir = workspace_dir / "outputs" out_dir = workspace_dir / "outputs"
Path.mkdir(out_dir, exist_ok=True) Path.mkdir(out_dir, exist_ok=True)
run_cmd = [runner_path, *flags] run_cmd = [runner_path, *flags]
run_cmd += ["--save-csv-dir", f"{out_dir}"] run_cmd += ["--save-csv-dir", f"{out_dir}"]
subprocess.run(run_cmd, cwd=runner_build_dir, check=True) run_command(run_cmd, cwd=runner_build_dir, reporter=reporter)
print_info(reporter, f"Reference outputs saved to {out_dir}")
reporter.advance()
print(Style.BRIGHT + "\nCompiling for PIM with Raptor:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
compile_with_raptor(network_mlir_path, raptor_path, crossbar_size, crossbar_count) compile_with_raptor(
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
crossbar_size, crossbar_count,
cwd=raptor_dir, reporter=reporter)
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
reporter.advance()
print(Style.BRIGHT + "\nRunning PIM simulation:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Simulator")
pim_dir = raptor_dir / "pim" pim_dir = raptor_dir / "pim"
write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list) write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list)
simulation_dir = workspace_dir / "simulation" simulation_dir = workspace_dir / "simulation"
Path.mkdir(simulation_dir, exist_ok=True) Path.mkdir(simulation_dir, exist_ok=True)
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor) dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
output_bin_path = simulation_dir / "out.bin" output_bin_path = simulation_dir / "out.bin"
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges) run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter)
print_info(reporter, f"Simulator output saved to {output_bin_path}")
reporter.advance()
print(Style.BRIGHT + "\nValidating the results:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs")
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor) sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor)
return validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold) reporter.suspend()
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
reporter.resume()
reporter.advance()
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
except Exception:
failed_with_exception = True
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
reporter.suspend()
raise
finally:
if not failed_with_exception:
reporter.log("=" * 72)
if owns_reporter:
reporter.finish()
if __name__ == '__main__': if __name__ == '__main__':