Compare commits
10 Commits
143c8f960a
...
584ca0b3c2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
584ca0b3c2 | ||
|
|
1348bb1c97 | ||
|
|
825188cc89 | ||
|
|
7202a4317d | ||
|
|
d4efa64b96 | ||
|
|
fef26cee9a | ||
|
|
29febb2bfd | ||
|
|
f24a60bfcd | ||
|
|
91ef6d9bc3 | ||
|
|
8ee1e5ece8 |
23
.github/workflows/build_mlir_cache.yml
vendored
23
.github/workflows/build_mlir_cache.yml
vendored
@@ -12,6 +12,29 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
- name: Free disk space
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
df -h
|
||||||
|
sudo apt-get remove -y '^dotnet-.*'
|
||||||
|
sudo apt-get remove -y '^llvm-.*'
|
||||||
|
sudo apt-get remove -y 'php.*'
|
||||||
|
sudo apt-get remove -y '^mongodb-.*'
|
||||||
|
sudo apt-get remove -y '^mysql-.*'
|
||||||
|
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
|
||||||
|
sudo apt-get autoremove -y
|
||||||
|
sudo apt-get clean
|
||||||
|
df -h
|
||||||
|
sudo rm -rf /usr/local/lib/android || true
|
||||||
|
sudo rm -rf /usr/share/dotnet || true
|
||||||
|
sudo rm -rf /opt/ghc || true
|
||||||
|
sudo rm -rf /usr/local/.ghcup || true
|
||||||
|
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
|
||||||
|
sudo docker system prune --all --volumes --force
|
||||||
|
sudo apt-get clean
|
||||||
|
sudo rm -rf /var/lib/apt/lists/*
|
||||||
|
df -h
|
||||||
|
|
||||||
- name: Cache MLIR build
|
- name: Cache MLIR build
|
||||||
id: cache-mlir
|
id: cache-mlir
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
23
.github/workflows/validate_operations.yml
vendored
23
.github/workflows/validate_operations.yml
vendored
@@ -29,6 +29,29 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
- name: Free disk space
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
df -h
|
||||||
|
sudo apt-get remove -y '^dotnet-.*'
|
||||||
|
sudo apt-get remove -y '^llvm-.*'
|
||||||
|
sudo apt-get remove -y 'php.*'
|
||||||
|
sudo apt-get remove -y '^mongodb-.*'
|
||||||
|
sudo apt-get remove -y '^mysql-.*'
|
||||||
|
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
|
||||||
|
sudo apt-get autoremove -y
|
||||||
|
sudo apt-get clean
|
||||||
|
df -h
|
||||||
|
sudo rm -rf /usr/local/lib/android || true
|
||||||
|
sudo rm -rf /usr/share/dotnet || true
|
||||||
|
sudo rm -rf /opt/ghc || true
|
||||||
|
sudo rm -rf /usr/local/.ghcup || true
|
||||||
|
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
|
||||||
|
sudo docker system prune --all --volumes --force
|
||||||
|
sudo apt-get clean
|
||||||
|
sudo rm -rf /var/lib/apt/lists/*
|
||||||
|
df -h
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -700,6 +700,7 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
let local_memory = core.load::<u8>(r1_val, imm_len)?;
|
let local_memory = core.load::<u8>(r1_val, imm_len)?;
|
||||||
let tmp = local_memory[0].to_vec();
|
let tmp = local_memory[0].to_vec();
|
||||||
core.execute_store(rd_val, tmp.as_slice());
|
core.execute_store(rd_val, tmp.as_slice());
|
||||||
|
TRACER.lock().unwrap().post_lmv(cores, data);
|
||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ impl Instruction {
|
|||||||
.with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
|
.with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn dump(&self) {
|
||||||
|
eprintln!("\t{}", functor_to_name(self.functor as usize));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Instructions = Vec<Instruction>;
|
pub type Instructions = Vec<Instruction>;
|
||||||
|
|||||||
@@ -71,18 +71,27 @@ pub fn json_to_instruction(
|
|||||||
inst_builder,
|
inst_builder,
|
||||||
inst_data_builder,
|
inst_data_builder,
|
||||||
json,
|
json,
|
||||||
);
|
)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! json_str {
|
macro_rules! json_str {
|
||||||
($json:ident , $value:literal) => {
|
($json:ident , $value:literal) => {
|
||||||
$json.get($value).context(concat![$value, " field not present"])?.as_str().context(concat![$value, " field not str"])?
|
$json
|
||||||
|
.get($value)
|
||||||
|
.context(concat![$value, " field not present"])?
|
||||||
|
.as_str()
|
||||||
|
.context(concat![$value, " field not str"])?
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! json_i64 {
|
macro_rules! json_i64 {
|
||||||
($json:ident , $value:literal) => {
|
($json:ident , $value:literal) => {
|
||||||
$json.get($value).context(concat![$value, " field not present"])?.as_i64().context(concat![$value, " field not i64"])?
|
$json
|
||||||
|
.get($value)
|
||||||
|
.context(concat![$value, " field not present"])?
|
||||||
|
.as_i64()
|
||||||
|
.context(concat![$value, " field not i64"])?
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
|
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
|
||||||
};
|
};
|
||||||
pub mod cpu;
|
pub mod cpu;
|
||||||
pub mod instruction_set;
|
pub mod instruction_set;
|
||||||
@@ -131,6 +131,16 @@ impl Executable {
|
|||||||
pub fn cpu_mut(&mut self) -> &mut CPU {
|
pub fn cpu_mut(&mut self) -> &mut CPU {
|
||||||
&mut self.cpu
|
&mut self.cpu
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dump(&self) {
|
||||||
|
let core_instructions = &self.core_instructions;
|
||||||
|
for (i, core_instruction) in core_instructions.iter().enumerate() {
|
||||||
|
eprintln!("INST OF CORE {}:", i);
|
||||||
|
for inst in &core_instruction.instructions {
|
||||||
|
inst.dump();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) {
|
fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) {
|
||||||
|
|||||||
@@ -1123,7 +1123,7 @@ impl Trace {
|
|||||||
if prefix == "Pre" {
|
if prefix == "Pre" {
|
||||||
writeln!(
|
writeln!(
|
||||||
file,
|
file,
|
||||||
"Inst: lvm {} {} {} {{ {} {} }}",
|
"Inst: lmv {} {} {} {{ {} {} }}",
|
||||||
rd, r1, imm_len, offset_select, offset_value
|
rd, r1, imm_len, offset_select, offset_value
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
@@ -1141,13 +1141,15 @@ impl Trace {
|
|||||||
);
|
);
|
||||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||||
let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
|
let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
|
||||||
let core_memory = core.load::<u8>(r1_val, imm_len).unwrap();
|
let core_memory = core
|
||||||
let global_memory = host.load::<u8>(rd_val, imm_len).unwrap();
|
.reserve_load(r1_val, imm_len).unwrap()
|
||||||
|
.reserve_load(rd_val, imm_len).unwrap()
|
||||||
|
.execute_load::<u8>().unwrap();
|
||||||
writeln!(file, "{} Memory:", prefix);
|
writeln!(file, "{} Memory:", prefix);
|
||||||
writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,);
|
writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,);
|
||||||
pretty_print::print_slice::<_,f32>(file, core_memory[0], 30);
|
pretty_print::print_slice::<_,f32>(file, core_memory[0], 30);
|
||||||
writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,);
|
writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,);
|
||||||
pretty_print::print_slice::<_,f32>(file, global_memory[0], 30);
|
pretty_print::print_slice::<_,f32>(file, core_memory[1], 30);
|
||||||
|
|
||||||
if prefix == "Post" {
|
if prefix == "Post" {
|
||||||
writeln!(file, "\n###############################################\n");
|
writeln!(file, "\n###############################################\n");
|
||||||
|
|||||||
Submodule onnx-mlir updated: 84cedd1d69...82018d7ce5
@@ -10,14 +10,13 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
|
|||||||
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
||||||
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
||||||
|
|
||||||
add_subdirectory(Dialect)
|
add_subdirectory(Common)
|
||||||
add_subdirectory(Compiler)
|
add_subdirectory(Compiler)
|
||||||
add_subdirectory(Conversion)
|
add_subdirectory(Conversion)
|
||||||
add_subdirectory(Common)
|
add_subdirectory(Dialect)
|
||||||
|
|
||||||
add_onnx_mlir_library(OMPIMAccel
|
add_onnx_mlir_library(OMPIMAccel
|
||||||
PimAccelerator.cpp
|
PimAccelerator.cpp
|
||||||
Transforms/PimBufferizationPass.cpp
|
|
||||||
Pass/CountInstructionPass.cpp
|
Pass/CountInstructionPass.cpp
|
||||||
Pass/EmitPimJsonPass.cpp
|
Pass/EmitPimJsonPass.cpp
|
||||||
Pass/MessagePass.cpp
|
Pass/MessagePass.cpp
|
||||||
|
|||||||
@@ -21,9 +21,11 @@
|
|||||||
#include "src/Compiler/CompilerPasses.hpp"
|
#include "src/Compiler/CompilerPasses.hpp"
|
||||||
#include "src/Compiler/CompilerUtils.hpp"
|
#include "src/Compiler/CompilerUtils.hpp"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
|
||||||
MemEntry* PimMemory::gatherMemEntry(Value value) {
|
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||||
auto type = cast<ShapedType>(value.getType());
|
auto type = cast<ShapedType>(value.getType());
|
||||||
assert("Only static shape is supported" && type.hasStaticShape());
|
assert("Only static shape is supported" && type.hasStaticShape());
|
||||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
@@ -31,7 +33,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) {
|
|||||||
return &memEntries.emplace_back(memEntry, value).first;
|
return &memEntries.emplace_back(memEntry, value).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) {
|
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||||
memEntry.address = firstAvailableAddress;
|
memEntry.address = firstAvailableAddress;
|
||||||
firstAvailableAddress += memEntry.size;
|
firstAvailableAddress += memEntry.size;
|
||||||
// Alignment
|
// Alignment
|
||||||
@@ -59,7 +61,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (Value arg : funcOp.getArguments())
|
for (mlir::Value arg : funcOp.getArguments())
|
||||||
gatherMemEntry(arg);
|
gatherMemEntry(arg);
|
||||||
|
|
||||||
allocateCore(funcOp);
|
allocateCore(funcOp);
|
||||||
@@ -73,7 +75,7 @@ void PimMemory::allocateCore(Operation* op) {
|
|||||||
allocateMemoryForValue(value, memEntry);
|
allocateMemoryForValue(value, memEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
MemEntry PimMemory::getMemEntry(Value value) const {
|
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||||
auto iter = globalMemEntriesMap.find(value);
|
auto iter = globalMemEntriesMap.find(value);
|
||||||
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
|
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
|
||||||
return iter->second;
|
return iter->second;
|
||||||
@@ -83,7 +85,8 @@ PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
|||||||
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t PimAcceleratorMemory::getValueAddress(Value value) const {
|
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||||
|
size_t offset = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
auto definingOp = value.getDefiningOp();
|
auto definingOp = value.getDefiningOp();
|
||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
@@ -101,12 +104,18 @@ size_t PimAcceleratorMemory::getValueAddress(Value value) const {
|
|||||||
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
||||||
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
||||||
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
||||||
|
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
|
||||||
|
size_t localOffset = subviewOffsets[i];
|
||||||
|
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
|
||||||
|
localOffset *= subviewSizes[j];
|
||||||
|
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
|
||||||
|
}
|
||||||
value = source;
|
value = source;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
return memEntriesMap.at(value).address;
|
return memEntriesMap.at(value).address + offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
json::Object PimCodeGen::createEmptyOffset() {
|
json::Object PimCodeGen::createEmptyOffset() {
|
||||||
@@ -144,15 +153,20 @@ void PimCodeGen::setupRdRs1Rs2(
|
|||||||
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::emitMemCopyOp(
|
void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||||
StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const {
|
size_t rdAddr,
|
||||||
|
size_t rdOffset,
|
||||||
|
size_t rs1Addr,
|
||||||
|
size_t rs1Offset,
|
||||||
|
size_t size,
|
||||||
|
StringRef sizeFieldName) const {
|
||||||
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = opName;
|
json["op"] = opName;
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["size"] = size;
|
json[sizeFieldName] = size;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
@@ -200,6 +214,16 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
|
|||||||
storeOp.getSize());
|
storeOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
|
||||||
|
emitMemCopyOp("lmv",
|
||||||
|
memory.getValueAddress(lmvOp.getDst()),
|
||||||
|
lmvOp.getDstOffset(),
|
||||||
|
memory.getValueAddress(lmvOp.getSrc()),
|
||||||
|
lmvOp.getSrcOffset(),
|
||||||
|
lmvOp.getSize(),
|
||||||
|
"len");
|
||||||
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
||||||
emitCommunicationOp(
|
emitCommunicationOp(
|
||||||
"recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize());
|
"recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize());
|
||||||
@@ -343,7 +367,6 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||||
static OnnxMlirCompilerErrorCodes
|
static OnnxMlirCompilerErrorCodes
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
|
|
||||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||||
std::error_code errorCode;
|
std::error_code errorCode;
|
||||||
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
||||||
@@ -400,6 +423,12 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenLoadOp(loadOp);
|
coreCodeGen.codeGenLoadOp(loadOp);
|
||||||
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
||||||
coreCodeGen.codeGenStoreOp(storeOp);
|
coreCodeGen.codeGenStoreOp(storeOp);
|
||||||
|
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
|
||||||
|
coreCodeGen.codeGenLmvOp(lmvOp);
|
||||||
|
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
||||||
|
coreCodeGen.codeGenReceiveOp(receiveOp);
|
||||||
|
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||||
|
coreCodeGen.codeGenSendOp(sendOp);
|
||||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
|
||||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||||
@@ -412,10 +441,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenVMaxOp(vmaxOp);
|
coreCodeGen.codeGenVMaxOp(vmaxOp);
|
||||||
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
||||||
coreCodeGen.codeGenVReluOp(vreluOp);
|
coreCodeGen.codeGenVReluOp(vreluOp);
|
||||||
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
|
||||||
coreCodeGen.codeGenReceiveOp(receiveOp);
|
|
||||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
|
||||||
coreCodeGen.codeGenSendOp(sendOp);
|
|
||||||
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
|
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
|
||||||
// TODO: Implement somehow?
|
// TODO: Implement somehow?
|
||||||
op.emitWarning("Operation is not yet supported in code generation");
|
op.emitWarning("Operation is not yet supported in code generation");
|
||||||
@@ -539,7 +564,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
|||||||
|
|
||||||
json::Array outputsAddresses;
|
json::Array outputsAddresses;
|
||||||
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||||
for (Value output : returnOp.getOperands())
|
for (mlir::Value output : returnOp.getOperands())
|
||||||
outputsAddresses.push_back(memory.getValueAddress(output));
|
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||||
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||||
|
|
||||||
|
|||||||
@@ -9,47 +9,41 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
using namespace llvm;
|
|
||||||
using namespace mlir;
|
|
||||||
using Value = mlir::Value;
|
|
||||||
using Type = mlir::Type;
|
|
||||||
using FunctionType = mlir::FunctionType;
|
|
||||||
|
|
||||||
struct MemEntry {
|
struct MemEntry {
|
||||||
size_t address;
|
size_t address;
|
||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimMemory {
|
class PimMemory {
|
||||||
SmallVector<std::pair<MemEntry, Value>, 32> memEntries;
|
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||||
SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap;
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||||
|
|
||||||
size_t maxSize = 0; // 0 for unbounded memory
|
size_t maxSize = 0; // 0 for unbounded memory
|
||||||
size_t startAddress = 0;
|
size_t startAddress = 0;
|
||||||
size_t minAlignment = 4;
|
size_t minAlignment = 4;
|
||||||
size_t firstAvailableAddress = 0;
|
size_t firstAvailableAddress = 0;
|
||||||
|
|
||||||
MemEntry* gatherMemEntry(Value value);
|
MemEntry* gatherMemEntry(mlir::Value value);
|
||||||
void allocateMemoryForValue(Value value, MemEntry& memEntry);
|
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimMemory(SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap)
|
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
|
||||||
: globalMemEntriesMap(globalMemEntriesMap) {}
|
: globalMemEntriesMap(globalMemEntriesMap) {}
|
||||||
|
|
||||||
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp);
|
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
||||||
void allocateCore(Operation* op);
|
void allocateCore(mlir::Operation* op);
|
||||||
|
|
||||||
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
||||||
MemEntry getMemEntry(Value value) const;
|
MemEntry getMemEntry(mlir::Value value) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimAcceleratorMemory {
|
class PimAcceleratorMemory {
|
||||||
public:
|
public:
|
||||||
SmallDenseMap<Value, MemEntry, 32> memEntriesMap;
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
|
||||||
PimMemory hostMem;
|
PimMemory hostMem;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SmallDenseMap<size_t, PimMemory> deviceMem;
|
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimAcceleratorMemory()
|
PimAcceleratorMemory()
|
||||||
@@ -57,15 +51,15 @@ public:
|
|||||||
|
|
||||||
PimMemory getOrCreateDeviceMem(size_t id);
|
PimMemory getOrCreateDeviceMem(size_t id);
|
||||||
|
|
||||||
size_t getValueAddress(Value value) const;
|
size_t getValueAddress(mlir::Value value) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimCodeGen {
|
class PimCodeGen {
|
||||||
PimAcceleratorMemory& memory;
|
PimAcceleratorMemory& memory;
|
||||||
raw_fd_ostream& coreFileStream;
|
llvm::raw_fd_ostream& coreFileStream;
|
||||||
|
|
||||||
static json::Object createEmptyOffset();
|
static llvm::json::Object createEmptyOffset();
|
||||||
void emitInstruction(json::Object instruction) const;
|
void emitInstruction(llvm::json::Object instruction) const;
|
||||||
|
|
||||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||||
@@ -73,17 +67,23 @@ class PimCodeGen {
|
|||||||
void setupRdRs1Rs2(
|
void setupRdRs1Rs2(
|
||||||
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
|
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
|
||||||
|
|
||||||
void
|
void emitMemCopyOp(mlir::StringRef opName,
|
||||||
emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const;
|
size_t rdAddr,
|
||||||
void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
|
size_t rdOffset,
|
||||||
|
size_t rs1Addr,
|
||||||
|
size_t rs1Offset,
|
||||||
|
size_t size,
|
||||||
|
mlir::StringRef sizeFieldName = "size") const;
|
||||||
|
void emitCommunicationOp(mlir::StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
|
||||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson)
|
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
||||||
: memory(memory), coreFileStream(coreJson) {}
|
: memory(memory), coreFileStream(coreJson) {}
|
||||||
|
|
||||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
|
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
|
||||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
|
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
|
||||||
|
void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const;
|
||||||
|
|
||||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
|
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
|
||||||
void codeGenSendOp(pim::PimSendOp sendOp) const;
|
void codeGenSendOp(pim::PimSendOp sendOp) const;
|
||||||
@@ -97,6 +97,6 @@ public:
|
|||||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName);
|
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -25,9 +25,96 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||||
|
|
||||||
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||||
ONNXGemmOpTile(MLIRContext* ctx)
|
GemmToManyGemv(MLIRContext* ctx)
|
||||||
: OpConversionPattern(ctx) {}
|
: OpConversionPattern(ctx, 2) {}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Location loc = gemmOp.getLoc();
|
||||||
|
Value a = adaptor.getA();
|
||||||
|
Value b = adaptor.getB();
|
||||||
|
Value c = adaptor.getC();
|
||||||
|
|
||||||
|
assert("A should have been transposed already" && !adaptor.getTransA());
|
||||||
|
|
||||||
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
|
|
||||||
|
auto aType = cast<RankedTensorType>(a.getType());
|
||||||
|
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
|
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
||||||
|
|
||||||
|
const int64_t numOutRows = aType.getDimSize(0);
|
||||||
|
|
||||||
|
// Only decompose when there are multiple rows to split
|
||||||
|
if (numOutRows <= 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
RankedTensorType cType = nullptr;
|
||||||
|
bool cHasNumOutRows = false;
|
||||||
|
if (hasC) {
|
||||||
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
|
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||||
|
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||||
|
|
||||||
|
SmallVector<Value> gemvOps;
|
||||||
|
gemvOps.reserve(numOutRows);
|
||||||
|
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||||
|
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||||
|
|
||||||
|
Value cSlice = c;
|
||||||
|
if (hasC) {
|
||||||
|
if (cHasNumOutRows) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||||
|
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
|
||||||
|
outRowType,
|
||||||
|
aSlice,
|
||||||
|
b,
|
||||||
|
cSlice,
|
||||||
|
gemmOp.getAlphaAttr(),
|
||||||
|
gemmOp.getBetaAttr(),
|
||||||
|
gemmOp.getTransAAttr(),
|
||||||
|
gemmOp.getTransBAttr());
|
||||||
|
gemvOps.push_back(gemvOp.getY());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto concatComputeOp =
|
||||||
|
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
||||||
|
|
||||||
|
auto* concatBlock = new Block();
|
||||||
|
for (auto gemvOp : gemvOps)
|
||||||
|
concatBlock->addArgument(gemvOp.getType(), loc);
|
||||||
|
concatComputeOp.getBody().push_back(concatBlock);
|
||||||
|
rewriter.setInsertionPointToStart(concatBlock);
|
||||||
|
|
||||||
|
auto blockArgs = concatBlock->getArguments();
|
||||||
|
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
|
||||||
|
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());
|
||||||
|
|
||||||
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
||||||
|
GemvToSpatialCompute(MLIRContext* ctx)
|
||||||
|
: OpConversionPattern(ctx, 1) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||||
@@ -50,12 +137,16 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
|||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
cType = cast<RankedTensorType>(c.getType());
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
assert("Only support 2 tensor for C" && cType.getRank() == 2);
|
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||||
|
|
||||||
|
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape()))
|
||||||
|
// Not a gemv
|
||||||
|
return failure();
|
||||||
|
|
||||||
if (transA) {
|
if (transA) {
|
||||||
auto aShape = aType.getShape();
|
auto aShape = aType.getShape();
|
||||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||||
@@ -169,9 +260,20 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
|||||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(gemmOp);
|
auto concatComputeOp =
|
||||||
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices);
|
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
||||||
rewriter.replaceOp(gemmOp, concatOp);
|
|
||||||
|
auto* concatBlock = new Block();
|
||||||
|
for (auto outHSlice : outHSlices)
|
||||||
|
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
||||||
|
concatComputeOp.getBody().push_back(concatBlock);
|
||||||
|
rewriter.setInsertionPointToStart(concatBlock);
|
||||||
|
|
||||||
|
auto blockArgs = concatBlock->getArguments();
|
||||||
|
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
|
||||||
|
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
|
||||||
|
|
||||||
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,8 +412,9 @@ private:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<ONNXGemmOpTile>(ctx);
|
patterns.insert<GemmToManyGemv>(ctx);
|
||||||
|
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -24,8 +24,6 @@ namespace onnx_mlir {
|
|||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
|
|
||||||
|
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
|
|
||||||
@@ -71,7 +69,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
else {
|
else {
|
||||||
populateTilingConvOpPattern(patterns, ctx);
|
populateTilingConvOpPattern(patterns, ctx);
|
||||||
populatePoolingTilingPattern(patterns, ctx);
|
populatePoolingTilingPattern(patterns, ctx);
|
||||||
populateTilingGemmOpPattern(patterns, ctx);
|
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|||||||
@@ -19,10 +19,9 @@
|
|||||||
#include "SpatialToPIMPass.hpp"
|
#include "SpatialToPIMPass.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
using namespace onnx_mlir;
|
||||||
namespace onnx_mlir {
|
using namespace pim;
|
||||||
|
using namespace spat_to_pim;
|
||||||
namespace pim {
|
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnOperation() {
|
void SpatialToPIMPass::runOnOperation() {
|
||||||
coreId = 1;
|
coreId = 1;
|
||||||
@@ -258,9 +257,16 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
|||||||
outputTensors.reserve(returnOp->getNumOperands());
|
outputTensors.reserve(returnOp->getNumOperands());
|
||||||
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
||||||
for (auto returnValue : returnOp->getOperands()) {
|
for (auto returnValue : returnOp->getOperands()) {
|
||||||
auto newOutputTensor =
|
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
|
||||||
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
|
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||||
outputTensors.push_back(newOutputTensor);
|
assert(!returnValueDefiningOp->hasAttr("weightAlways"));
|
||||||
|
outputTensors.push_back(returnValue);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto newOutputTensor =
|
||||||
|
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
|
||||||
|
outputTensors.push_back(newOutputTensor);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,6 +415,7 @@ void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
|
|||||||
if (!computeUser) {
|
if (!computeUser) {
|
||||||
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
|
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
|
||||||
if (!reshapeOp) {
|
if (!reshapeOp) {
|
||||||
|
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
|
||||||
resultUse.getOwner()->dump();
|
resultUse.getOwner()->dump();
|
||||||
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
|
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
|
||||||
}
|
}
|
||||||
@@ -479,7 +486,3 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
|||||||
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
|
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pim
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
namespace pim {
|
namespace spat_to_pim {
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
|
||||||
|
|
||||||
@@ -53,8 +53,8 @@ private:
|
|||||||
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace pim
|
} // namespace spat_to_pim
|
||||||
|
|
||||||
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<pim::SpatialToPIMPass>(); }
|
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<spat_to_pim::SpatialToPIMPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
add_onnx_mlir_dialect(Pim pim)
|
add_onnx_mlir_dialect(Pim pim)
|
||||||
add_onnx_mlir_dialect_doc(pim Pim.td)
|
add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||||
|
|
||||||
|
add_subdirectory(Transforms/Bufferization)
|
||||||
|
|
||||||
add_onnx_mlir_library(PimOps
|
add_onnx_mlir_library(PimOps
|
||||||
|
PimOps.hpp
|
||||||
PimOps.cpp
|
PimOps.cpp
|
||||||
Transforms/PimBufferizableOpInterface.cpp
|
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
OMPimIncGen
|
OMPimIncGen
|
||||||
|
|||||||
@@ -14,20 +14,13 @@ def PimDialect : Dialect {
|
|||||||
let cppNamespace = "::onnx_mlir::pim";
|
let cppNamespace = "::onnx_mlir::pim";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Base class for Pim dialect operations. This operation inherits from the
|
|
||||||
// base `Op` class in OpBase.td, and provides:
|
|
||||||
// * The parent dialect of the operation.
|
|
||||||
// * The mnemonic for the operation, or the name without the dialect prefix.
|
|
||||||
// * A list of traits for the operation.
|
|
||||||
class PimOp<string mnemonic, list<Trait> traits = []> :
|
class PimOp<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<PimDialect, mnemonic, traits>;
|
Op<PimDialect, mnemonic, traits>;
|
||||||
|
|
||||||
def PimTensor :
|
def PimTensor :
|
||||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
// Communication
|
||||||
// Communication operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def PimSendOp: PimOp<"send", []> {
|
def PimSendOp: PimOp<"send", []> {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
@@ -63,9 +56,7 @@ def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
// Core
|
||||||
// Core operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
||||||
|
|
||||||
@@ -81,9 +72,7 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
// Memory
|
||||||
// Memory Operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def PimConstantOp: PimOp<"constant", []> {
|
def PimConstantOp: PimOp<"constant", []> {
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -157,9 +146,36 @@ def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||||
// Core.Compute operations
|
let description = [{
|
||||||
//===----------------------------------------------------------------------===//
|
Copy a memory region from and to the same memory
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor: $dst,
|
||||||
|
PimTensor: $src,
|
||||||
|
I32Attr: $dstOffset,
|
||||||
|
I32Attr: $srcOffset,
|
||||||
|
I32Attr: $size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor: $dstOut
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getDstMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computation
|
||||||
|
|
||||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|||||||
22
src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt
Normal file
22
src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
|
||||||
|
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||||
|
add_public_tablegen_target(PimBufferizationIncGen)
|
||||||
|
|
||||||
|
add_onnx_mlir_library(OMPimBufferization
|
||||||
|
PimBufferizationPass.hpp
|
||||||
|
PimBufferizationPass.cpp
|
||||||
|
OpBufferizationInterfaces.hpp
|
||||||
|
OpBufferizationInterfaces.cpp
|
||||||
|
Common.hpp
|
||||||
|
Common.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
PimBufferizationIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
OMPIMCommon
|
||||||
|
PimOps
|
||||||
|
|
||||||
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
|
${PIM_INCLUDE_PATH}
|
||||||
|
)
|
||||||
9
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp
Normal file
9
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||||
|
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||||
|
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||||
|
return builder.getI32IntegerAttr(sizeInBytes);
|
||||||
|
}
|
||||||
13
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp
Normal file
13
src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace pim {
|
||||||
|
|
||||||
|
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref);
|
||||||
|
|
||||||
|
} // namespace pim
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "OpBufferizationInterfaces.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace bufferization;
|
using namespace bufferization;
|
||||||
@@ -173,7 +172,7 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
@@ -9,7 +9,7 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace pim {
|
namespace pim {
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
void registerOpBufferizationInterfaces(DialectRegistry& registry);
|
||||||
|
|
||||||
} // namespace pim
|
} // namespace pim
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
#ifndef PIM_BUFFERIZATION
|
||||||
|
#define PIM_BUFFERIZATION
|
||||||
|
|
||||||
|
#ifndef OP_BASE
|
||||||
|
include "mlir/IR/PatternBase.td"
|
||||||
|
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
||||||
|
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def memrefCopyToPimMemCopyOp : Pat<
|
||||||
|
(CopyOp $src, $dst),
|
||||||
|
(PimMemCopyOp $dst, $src,
|
||||||
|
ConstantAttr<I32Attr, "0">,
|
||||||
|
ConstantAttr<I32Attr, "0">,
|
||||||
|
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
|
||||||
|
(returnType $dst))
|
||||||
|
>;
|
||||||
|
|
||||||
|
#endif // PIM_BUFFERIZATION
|
||||||
@@ -5,23 +5,18 @@
|
|||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include "Common/PIMCommon.hpp"
|
#include "Common/PIMCommon.hpp"
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "PimBufferizationPass.hpp"
|
#include "PimBufferizationPass.hpp"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
using namespace onnx_mlir;
|
||||||
namespace onnx_mlir {
|
using namespace pim;
|
||||||
|
|
||||||
namespace pim {
|
|
||||||
|
|
||||||
void PimBufferizationPass::runOnOperation() {
|
void PimBufferizationPass::runOnOperation() {
|
||||||
auto moduleOp = getOperation();
|
auto moduleOp = getOperation();
|
||||||
|
|
||||||
// Do One-Shot-Bufferization
|
// One-Shot-Bufferization
|
||||||
bufferization::OneShotBufferizationOptions options;
|
bufferization::OneShotBufferizationOptions options;
|
||||||
options.allowUnknownOps = true;
|
options.allowUnknownOps = true;
|
||||||
bufferization::BufferizationState state;
|
bufferization::BufferizationState state;
|
||||||
@@ -30,7 +25,19 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove toTensor operations
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
ConversionTarget target(*ctx);
|
||||||
|
target.addLegalDialect<PimDialect>();
|
||||||
|
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
|
populateWithGenerated(patterns);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove toTensor operations: leave memrefs instead
|
||||||
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
|
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
|
||||||
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
|
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
|
||||||
toTensorOp.erase();
|
toTensorOp.erase();
|
||||||
@@ -63,8 +70,8 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
MLIRContext* ctx = funcOp.getContext();
|
MLIRContext* ctx = funcOp.getContext();
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); })
|
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
||||||
&& !getGlobalOp->getUsers().empty();
|
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
||||||
if (isAlwaysWeight) {
|
if (isAlwaysWeight) {
|
||||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||||
@@ -73,7 +80,3 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pim
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include "Dialect/PIM/PimOps.hpp"
|
||||||
|
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
@@ -9,6 +11,8 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace pim {
|
namespace pim {
|
||||||
|
|
||||||
|
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||||
|
|
||||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
|
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||||
@@ -64,7 +64,7 @@ void PimAccelerator::registerDialects(DialectRegistry& registry) const {
|
|||||||
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
spatial::registerBufferizableOpInterfaceExternalModels(registry);
|
spatial::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
|
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
|
||||||
pim::registerBufferizableOpInterfaceExternalModels(registry);
|
pim::registerOpBufferizationInterfaces(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimAccelerator::registerPasses(int optLevel) const {
|
void PimAccelerator::registerPasses(int optLevel) const {
|
||||||
|
|||||||
Binary file not shown.
BIN
validation/operations/gemm/gemm.onnx
Normal file
BIN
validation/operations/gemm/gemm.onnx
Normal file
Binary file not shown.
BIN
validation/operations/gemv/constant/gemv_constant.onnx
Normal file
BIN
validation/operations/gemv/constant/gemv_constant.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user