Compare commits

...

10 Commits

Author SHA1 Message Date
NiccoloN
584ca0b3c2 fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 1m7s
Validate Operations / build-mlir-cache (push) Failing after 2m11s
Validate Operations / validate (push) Has been skipped
2026-03-09 14:30:41 +01:00
NiccoloN
1348bb1c97 generic gemm now works :) 2026-03-06 18:23:27 +01:00
NiccoloN
825188cc89 Merge remote-tracking branch 'origin/main' 2026-03-06 15:44:30 +01:00
NiccoloN
7202a4317d add free disk space step to CI 2026-03-06 15:44:23 +01:00
ilgeco
d4efa64b96 pim-simulator auto-format 2026-03-04 20:00:01 +01:00
ilgeco
fef26cee9a pim-simulator unwrap on failed json parsing 2026-03-04 19:59:16 +01:00
ilgeco
29febb2bfd pim-simulator dump instruction 2026-03-04 19:57:24 +01:00
ilgeco
f24a60bfcd pim-simulato trace end of lmv 2026-03-04 19:56:59 +01:00
ilgeco
91ef6d9bc3 pim-simulator dump inst function 2026-03-04 19:56:30 +01:00
NiccoloN
8ee1e5ece8 implement mem copy codgen (lmv)
add more gemv/gemm tests
refactor
2026-03-04 18:04:48 +01:00
34 changed files with 404 additions and 118 deletions

View File

@@ -12,6 +12,29 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Free disk space
if: runner.os == 'Linux'
run: |
df -h
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y '^llvm-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y '^mongodb-.*'
sudo apt-get remove -y '^mysql-.*'
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
sudo apt-get autoremove -y
sudo apt-get clean
df -h
sudo rm -rf /usr/local/lib/android || true
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
sudo docker system prune --all --volumes --force
sudo apt-get clean
sudo rm -rf /var/lib/apt/lists/*
df -h
- name: Cache MLIR build - name: Cache MLIR build
id: cache-mlir id: cache-mlir
uses: actions/cache@v4 uses: actions/cache@v4

View File

@@ -29,6 +29,29 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Free disk space
if: runner.os == 'Linux'
run: |
df -h
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y '^llvm-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y '^mongodb-.*'
sudo apt-get remove -y '^mysql-.*'
sudo apt-get remove -y azure-cli google-cloud-cli google-chrome-stable firefox powershell mono-devel libgl1-mesa-dri
sudo apt-get autoremove -y
sudo apt-get clean
df -h
sudo rm -rf /usr/local/lib/android || true
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true
sudo rm -rf /opt/hostedtoolcache/CodeQL || true
sudo docker system prune --all --volumes --force
sudo apt-get clean
sudo rm -rf /var/lib/apt/lists/*
df -h
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:

View File

@@ -700,6 +700,7 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
let local_memory = core.load::<u8>(r1_val, imm_len)?; let local_memory = core.load::<u8>(r1_val, imm_len)?;
let tmp = local_memory[0].to_vec(); let tmp = local_memory[0].to_vec();
core.execute_store(rd_val, tmp.as_slice()); core.execute_store(rd_val, tmp.as_slice());
TRACER.lock().unwrap().post_lmv(cores, data);
Ok(InstructionStatus::Completed) Ok(InstructionStatus::Completed)
} }

View File

@@ -46,6 +46,10 @@ impl Instruction {
.with_context(|| format!("Error in core: {}", self.data.core_indx() - 1)) .with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
.unwrap() .unwrap()
} }
pub(crate) fn dump(&self) {
eprintln!("\t{}", functor_to_name(self.functor as usize));
}
} }
pub type Instructions = Vec<Instruction>; pub type Instructions = Vec<Instruction>;

View File

@@ -71,18 +71,27 @@ pub fn json_to_instruction(
inst_builder, inst_builder,
inst_data_builder, inst_data_builder,
json, json,
); )
.unwrap();
} }
macro_rules! json_str { macro_rules! json_str {
($json:ident , $value:literal) => { ($json:ident , $value:literal) => {
$json.get($value).context(concat![$value, " field not present"])?.as_str().context(concat![$value, " field not str"])? $json
.get($value)
.context(concat![$value, " field not present"])?
.as_str()
.context(concat![$value, " field not str"])?
}; };
} }
macro_rules! json_i64 { macro_rules! json_i64 {
($json:ident , $value:literal) => { ($json:ident , $value:literal) => {
$json.get($value).context(concat![$value, " field not present"])?.as_i64().context(concat![$value, " field not i64"])? $json
.get($value)
.context(concat![$value, " field not present"])?
.as_i64()
.context(concat![$value, " field not i64"])?
}; };
} }

View File

@@ -1,7 +1,7 @@
#![allow(unused)] #![allow(unused)]
use crate::{ use crate::{
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
}; };
pub mod cpu; pub mod cpu;
pub mod instruction_set; pub mod instruction_set;
@@ -131,6 +131,16 @@ impl Executable {
pub fn cpu_mut(&mut self) -> &mut CPU { pub fn cpu_mut(&mut self) -> &mut CPU {
&mut self.cpu &mut self.cpu
} }
pub fn dump(&self) {
let core_instructions = &self.core_instructions;
for (i, core_instruction) in core_instructions.iter().enumerate() {
eprintln!("INST OF CORE {}:", i);
for inst in &core_instruction.instructions {
inst.dump();
}
}
}
} }
fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) { fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) {

View File

@@ -1123,7 +1123,7 @@ impl Trace {
if prefix == "Pre" { if prefix == "Pre" {
writeln!( writeln!(
file, file,
"Inst: lvm {} {} {} {{ {} {} }}", "Inst: lmv {} {} {} {{ {} {} }}",
rd, r1, imm_len, offset_select, offset_value rd, r1, imm_len, offset_select, offset_value
); );
} else { } else {
@@ -1141,13 +1141,15 @@ impl Trace {
); );
let r1_val = add_offset_r1(r1_val, offset_select, offset_value); let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let rd_val = add_offset_rd(rd_val, offset_select, offset_value); let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
let core_memory = core.load::<u8>(r1_val, imm_len).unwrap(); let core_memory = core
let global_memory = host.load::<u8>(rd_val, imm_len).unwrap(); .reserve_load(r1_val, imm_len).unwrap()
.reserve_load(rd_val, imm_len).unwrap()
.execute_load::<u8>().unwrap();
writeln!(file, "{} Memory:", prefix); writeln!(file, "{} Memory:", prefix);
writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,); writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,);
pretty_print::print_slice::<_,f32>(file, core_memory[0], 30); pretty_print::print_slice::<_,f32>(file, core_memory[0], 30);
writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,); writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,);
pretty_print::print_slice::<_,f32>(file, global_memory[0], 30); pretty_print::print_slice::<_,f32>(file, core_memory[1], 30);
if prefix == "Post" { if prefix == "Post" {
writeln!(file, "\n###############################################\n"); writeln!(file, "\n###############################################\n");

View File

@@ -10,14 +10,13 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT}) set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT}) set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
add_subdirectory(Dialect) add_subdirectory(Common)
add_subdirectory(Compiler) add_subdirectory(Compiler)
add_subdirectory(Conversion) add_subdirectory(Conversion)
add_subdirectory(Common) add_subdirectory(Dialect)
add_onnx_mlir_library(OMPIMAccel add_onnx_mlir_library(OMPIMAccel
PimAccelerator.cpp PimAccelerator.cpp
Transforms/PimBufferizationPass.cpp
Pass/CountInstructionPass.cpp Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp Pass/MessagePass.cpp

View File

@@ -21,9 +21,11 @@
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
MemEntry* PimMemory::gatherMemEntry(Value value) { MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape()); assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8; size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
@@ -31,7 +33,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) {
return &memEntries.emplace_back(memEntry, value).first; return &memEntries.emplace_back(memEntry, value).first;
} }
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) { void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress; memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size; firstAvailableAddress += memEntry.size;
// Alignment // Alignment
@@ -59,7 +61,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
} }
}); });
for (Value arg : funcOp.getArguments()) for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg); gatherMemEntry(arg);
allocateCore(funcOp); allocateCore(funcOp);
@@ -73,7 +75,7 @@ void PimMemory::allocateCore(Operation* op) {
allocateMemoryForValue(value, memEntry); allocateMemoryForValue(value, memEntry);
} }
MemEntry PimMemory::getMemEntry(Value value) const { MemEntry PimMemory::getMemEntry(mlir::Value value) const {
auto iter = globalMemEntriesMap.find(value); auto iter = globalMemEntriesMap.find(value);
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end()); assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
return iter->second; return iter->second;
@@ -83,7 +85,8 @@ PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second; return deviceMem.try_emplace(id, memEntriesMap).first->second;
} }
size_t PimAcceleratorMemory::getValueAddress(Value value) const { size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
size_t offset = 0;
while (true) { while (true) {
auto definingOp = value.getDefiningOp(); auto definingOp = value.getDefiningOp();
if (!definingOp) if (!definingOp)
@@ -101,12 +104,18 @@ size_t PimAcceleratorMemory::getValueAddress(Value value) const {
auto subviewSizes = subviewDefiningOp.getStaticSizes(); auto subviewSizes = subviewDefiningOp.getStaticSizes();
auto subviewStrides = subviewDefiningOp.getStaticStrides(); auto subviewStrides = subviewDefiningOp.getStaticStrides();
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides)); assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
size_t localOffset = subviewOffsets[i];
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
localOffset *= subviewSizes[j];
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
}
value = source; value = source;
} }
else else
break; break;
} }
return memEntriesMap.at(value).address; return memEntriesMap.at(value).address + offset;
} }
json::Object PimCodeGen::createEmptyOffset() { json::Object PimCodeGen::createEmptyOffset() {
@@ -144,15 +153,20 @@ void PimCodeGen::setupRdRs1Rs2(
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset); genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
} }
void PimCodeGen::emitMemCopyOp( void PimCodeGen::emitMemCopyOp(StringRef opName,
StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const { size_t rdAddr,
size_t rdOffset,
size_t rs1Addr,
size_t rs1Offset,
size_t size,
StringRef sizeFieldName) const {
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset); setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
json::Object json; json::Object json;
json["op"] = opName; json["op"] = opName;
json["rd"] = 0; json["rd"] = 0;
json["rs1"] = 1; json["rs1"] = 1;
json["size"] = size; json[sizeFieldName] = size;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
@@ -200,6 +214,16 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
storeOp.getSize()); storeOp.getSize());
} }
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
emitMemCopyOp("lmv",
memory.getValueAddress(lmvOp.getDst()),
lmvOp.getDstOffset(),
memory.getValueAddress(lmvOp.getSrc()),
lmvOp.getSrcOffset(),
lmvOp.getSize(),
"len");
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const { void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
emitCommunicationOp( emitCommunicationOp(
"recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize()); "recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize());
@@ -343,7 +367,6 @@ std::string getMemorySizeAsString(size_t size) {
/// Write global constant data into a binary memory image at their allocated addresses. /// Write global constant data into a binary memory image at their allocated addresses.
static OnnxMlirCompilerErrorCodes static OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str(); auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode; std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None); raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
@@ -400,6 +423,12 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenLoadOp(loadOp); coreCodeGen.codeGenLoadOp(loadOp);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp); coreCodeGen.codeGenStoreOp(storeOp);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
coreCodeGen.codeGenLmvOp(lmvOp);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true); coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op)) else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
@@ -412,10 +441,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVMaxOp(vmaxOp); coreCodeGen.codeGenVMaxOp(vmaxOp);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op)) else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp); coreCodeGen.codeGenVReluOp(vreluOp);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp);
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) { else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
// TODO: Implement somehow? // TODO: Implement somehow?
op.emitWarning("Operation is not yet supported in code generation"); op.emitWarning("Operation is not yet supported in code generation");
@@ -539,7 +564,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
json::Array outputsAddresses; json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>()) for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (Value output : returnOp.getOperands()) for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output)); outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses); configJson["outputs_addresses"] = std::move(outputsAddresses);

View File

@@ -9,47 +9,41 @@
namespace onnx_mlir { namespace onnx_mlir {
using namespace llvm;
using namespace mlir;
using Value = mlir::Value;
using Type = mlir::Type;
using FunctionType = mlir::FunctionType;
struct MemEntry { struct MemEntry {
size_t address; size_t address;
size_t size; size_t size;
}; };
class PimMemory { class PimMemory {
SmallVector<std::pair<MemEntry, Value>, 32> memEntries; llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap; llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0; size_t startAddress = 0;
size_t minAlignment = 4; size_t minAlignment = 4;
size_t firstAvailableAddress = 0; size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(Value value); MemEntry* gatherMemEntry(mlir::Value value);
void allocateMemoryForValue(Value value, MemEntry& memEntry); void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
public: public:
PimMemory(SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap) PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {} : globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp); void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(Operation* op); void allocateCore(mlir::Operation* op);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; } size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(Value value) const; MemEntry getMemEntry(mlir::Value value) const;
}; };
class PimAcceleratorMemory { class PimAcceleratorMemory {
public: public:
SmallDenseMap<Value, MemEntry, 32> memEntriesMap; llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
PimMemory hostMem; PimMemory hostMem;
private: private:
SmallDenseMap<size_t, PimMemory> deviceMem; llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
public: public:
PimAcceleratorMemory() PimAcceleratorMemory()
@@ -57,15 +51,15 @@ public:
PimMemory getOrCreateDeviceMem(size_t id); PimMemory getOrCreateDeviceMem(size_t id);
size_t getValueAddress(Value value) const; size_t getValueAddress(mlir::Value value) const;
}; };
class PimCodeGen { class PimCodeGen {
PimAcceleratorMemory& memory; PimAcceleratorMemory& memory;
raw_fd_ostream& coreFileStream; llvm::raw_fd_ostream& coreFileStream;
static json::Object createEmptyOffset(); static llvm::json::Object createEmptyOffset();
void emitInstruction(json::Object instruction) const; void emitInstruction(llvm::json::Object instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const; void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const; void setupRd(size_t rdAddress, size_t rdOffset) const;
@@ -73,17 +67,23 @@ class PimCodeGen {
void setupRdRs1Rs2( void setupRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const; size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
void void emitMemCopyOp(mlir::StringRef opName,
emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const; size_t rdAddr,
void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const; size_t rdOffset,
size_t rs1Addr,
size_t rs1Offset,
size_t size,
mlir::StringRef sizeFieldName = "size") const;
void emitCommunicationOp(mlir::StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const; void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public: public:
PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson) PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {} : memory(memory), coreFileStream(coreJson) {}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const; void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const; void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
void codeGenSendOp(pim::PimSendOp sendOp) const; void codeGenSendOp(pim::PimSendOp sendOp) const;
@@ -97,6 +97,6 @@ public:
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
}; };
OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName); OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -25,9 +25,96 @@ namespace onnx_mlir {
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> { struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
ONNXGemmOpTile(MLIRContext* ctx) GemmToManyGemv(MLIRContext* ctx)
: OpConversionPattern(ctx) {} : OpConversionPattern(ctx, 2) {}
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location loc = gemmOp.getLoc();
Value a = adaptor.getA();
Value b = adaptor.getB();
Value c = adaptor.getC();
assert("A should have been transposed already" && !adaptor.getTransA());
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
const int64_t numOutRows = aType.getDimSize(0);
// Only decompose when there are multiple rows to split
if (numOutRows <= 1)
return failure();
RankedTensorType cType = nullptr;
bool cHasNumOutRows = false;
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> gemvOps;
gemvOps.reserve(numOutRows);
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c;
if (hasC) {
if (cHasNumOutRows) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult();
}
else
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
}
auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
outRowType,
aSlice,
b,
cSlice,
gemmOp.getAlphaAttr(),
gemmOp.getBetaAttr(),
gemmOp.getTransAAttr(),
gemmOp.getTransBAttr());
gemvOps.push_back(gemvOp.getY());
}
auto concatComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
auto* concatBlock = new Block();
for (auto gemvOp : gemvOps)
concatBlock->addArgument(gemvOp.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
LogicalResult LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
@@ -50,12 +137,16 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp()); bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) { if (hasC) {
cType = cast<RankedTensorType>(c.getType()); cType = cast<RankedTensorType>(c.getType());
assert("Only support 2 tensor for C" && cType.getRank() == 2); assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
} }
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape()))
// Not a gemv
return failure();
if (transA) { if (transA) {
auto aShape = aType.getShape(); auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
@@ -169,9 +260,20 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
outHSlices.push_back(reduceComputeOp.getResult(0)); outHSlices.push_back(reduceComputeOp.getResult(0));
} }
rewriter.setInsertionPoint(gemmOp); auto concatComputeOp =
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices); rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
rewriter.replaceOp(gemmOp, concatOp);
auto* concatBlock = new Block();
for (auto outHSlice : outHSlices)
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();
} }
@@ -310,8 +412,9 @@ private:
} }
}; };
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXGemmOpTile>(ctx); patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -24,8 +24,6 @@ namespace onnx_mlir {
namespace spatial { namespace spatial {
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
@@ -71,7 +69,7 @@ void ONNXToSpatialPass::runOnOperation() {
else { else {
populateTilingConvOpPattern(patterns, ctx); populateTilingConvOpPattern(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx); populatePoolingTilingPattern(patterns, ctx);
populateTilingGemmOpPattern(patterns, ctx); populateOnnxGemmOpPatterns(patterns, ctx);
} }
populateONNXConcatToTensorConcatPattern(patterns, ctx); populateONNXConcatToTensorConcatPattern(patterns, ctx);

View File

@@ -5,7 +5,7 @@ namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

View File

@@ -19,10 +19,9 @@
#include "SpatialToPIMPass.hpp" #include "SpatialToPIMPass.hpp"
using namespace mlir; using namespace mlir;
using namespace onnx_mlir;
namespace onnx_mlir { using namespace pim;
using namespace spat_to_pim;
namespace pim {
void SpatialToPIMPass::runOnOperation() { void SpatialToPIMPass::runOnOperation() {
coreId = 1; coreId = 1;
@@ -258,9 +257,16 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
outputTensors.reserve(returnOp->getNumOperands()); outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock()); rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) { for (auto returnValue : returnOp->getOperands()) {
auto newOutputTensor = Operation* returnValueDefiningOp = returnValue.getDefiningOp();
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType())); if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
outputTensors.push_back(newOutputTensor); assert(!returnValueDefiningOp->hasAttr("weightAlways"));
outputTensors.push_back(returnValue);
}
else {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
}
} }
} }
@@ -409,6 +415,7 @@ void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
if (!computeUser) { if (!computeUser) {
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner()); auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
if (!reshapeOp) { if (!reshapeOp) {
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
resultUse.getOwner()->dump(); resultUse.getOwner()->dump();
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp"); llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
} }
@@ -479,7 +486,3 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData()); rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
} }
} }
} // namespace pim
} // namespace onnx_mlir

View File

@@ -10,7 +10,7 @@
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace spat_to_pim {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc" #include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
@@ -53,8 +53,8 @@ private:
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter); void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
}; };
} // namespace pim } // namespace spat_to_pim
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<pim::SpatialToPIMPass>(); } std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<spat_to_pim::SpatialToPIMPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,10 +1,11 @@
add_onnx_mlir_dialect(Pim pim) add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td) add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization)
add_onnx_mlir_library(PimOps add_onnx_mlir_library(PimOps
PimOps.hpp
PimOps.cpp PimOps.cpp
Transforms/PimBufferizableOpInterface.cpp
DEPENDS DEPENDS
OMPimIncGen OMPimIncGen

View File

@@ -14,20 +14,13 @@ def PimDialect : Dialect {
let cppNamespace = "::onnx_mlir::pim"; let cppNamespace = "::onnx_mlir::pim";
} }
// Base class for Pim dialect operations. This operation inherits from the
// base `Op` class in OpBase.td, and provides:
// * The parent dialect of the operation.
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class PimOp<string mnemonic, list<Trait> traits = []> : class PimOp<string mnemonic, list<Trait> traits = []> :
Op<PimDialect, mnemonic, traits>; Op<PimDialect, mnemonic, traits>;
def PimTensor : def PimTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
//===----------------------------------------------------------------------===// // Communication
// Communication operations
//===----------------------------------------------------------------------===//
def PimSendOp: PimOp<"send", []> { def PimSendOp: PimOp<"send", []> {
let arguments = (ins let arguments = (ins
@@ -63,9 +56,7 @@ def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
}]; }];
} }
//===----------------------------------------------------------------------===// // Core
// Core operations
//===----------------------------------------------------------------------===//
def PimCoreOp: PimOp<"core", [SingleBlock]> { def PimCoreOp: PimOp<"core", [SingleBlock]> {
@@ -81,9 +72,7 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
}]; }];
} }
//===----------------------------------------------------------------------===// // Memory
// Memory Operations
//===----------------------------------------------------------------------===//
def PimConstantOp: PimOp<"constant", []> { def PimConstantOp: PimOp<"constant", []> {
let description = [{ let description = [{
@@ -157,9 +146,36 @@ def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
}]; }];
} }
//===----------------------------------------------------------------------===// def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
// Core.Compute operations let description = [{
//===----------------------------------------------------------------------===// Copy a memory region from and to the same memory
}];
let arguments = (ins
PimTensor: $dst,
PimTensor: $src,
I32Attr: $dstOffset,
I32Attr: $srcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $dstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
}];
}
// Computation
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{ let description = [{

View File

@@ -0,0 +1,22 @@
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen)
add_onnx_mlir_library(OMPimBufferization
PimBufferizationPass.hpp
PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp
Common.hpp
Common.cpp
DEPENDS
PimBufferizationIncGen
LINK_LIBS PUBLIC
OMPIMCommon
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -0,0 +1,9 @@
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
return builder.getI32IntegerAttr(sizeInBytes);
}

View File

@@ -0,0 +1,13 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -1,11 +1,10 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
using namespace mlir; using namespace mlir;
using namespace bufferization; using namespace bufferization;
@@ -173,7 +172,7 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
} }
}; };
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx); PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);

View File

@@ -9,7 +9,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry); void registerOpBufferizationInterfaces(DialectRegistry& registry);
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -0,0 +1,19 @@
#ifndef PIM_BUFFERIZATION
#define PIM_BUFFERIZATION
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
#endif // OP_BASE
def memrefCopyToPimMemCopyOp : Pat<
(CopyOp $src, $dst),
(PimMemCopyOp $dst, $src,
ConstantAttr<I32Attr, "0">,
ConstantAttr<I32Attr, "0">,
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION

View File

@@ -5,23 +5,18 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_os_ostream.h"
#include "Common/PIMCommon.hpp" #include "Common/PIMCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "PimBufferizationPass.hpp" #include "PimBufferizationPass.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
using namespace onnx_mlir;
namespace onnx_mlir { using namespace pim;
namespace pim {
void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation(); auto moduleOp = getOperation();
// Do One-Shot-Bufferization // One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options; bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true; options.allowUnknownOps = true;
bufferization::BufferizationState state; bufferization::BufferizationState state;
@@ -30,7 +25,19 @@ void PimBufferizationPass::runOnOperation() {
signalPassFailure(); signalPassFailure();
} }
// Remove toTensor operations MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
// Remove toTensor operations: leave memrefs instead
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) { moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer()); toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
toTensorOp.erase(); toTensorOp.erase();
@@ -63,8 +70,8 @@ void PimBufferizationPass::runOnOperation() {
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext(); MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }) bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
&& !getGlobalOp->getUsers().empty(); && all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
if (isAlwaysWeight) { if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
assert("Weights must be constants" && globalMemrefOp.getConstant()); assert("Weights must be constants" && globalMemrefOp.getConstant());
@@ -73,7 +80,3 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
} }
}); });
} }
} // namespace pim
} // namespace onnx_mlir

View File

@@ -2,6 +2,8 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "Dialect/PIM/PimOps.hpp"
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
@@ -9,6 +11,8 @@ namespace onnx_mlir {
namespace pim { namespace pim {
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> { struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; } StringRef getArgument() const override { return "bufferize-pim"; }

View File

@@ -13,7 +13,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -64,7 +64,7 @@ void PimAccelerator::registerDialects(DialectRegistry& registry) const {
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerBufferizableOpInterfaceExternalModels(registry); pim::registerOpBufferizationInterfaces(registry);
} }
void PimAccelerator::registerPasses(int optLevel) const { void PimAccelerator::registerPasses(int optLevel) const {

Binary file not shown.

Binary file not shown.