add softmax support to pim-simulator
This commit is contained in:
@@ -52,6 +52,7 @@ static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
add_name_simd!(hash, vrelu);
|
||||
add_name_simd!(hash, vtanh);
|
||||
add_name_simd!(hash, vsigm);
|
||||
add_name_simd!(hash, vsoftmax);
|
||||
add_name!(hash, vmv);
|
||||
add_name!(hash, vrsu);
|
||||
add_name!(hash, vrsl);
|
||||
@@ -177,6 +178,7 @@ static SIMD: LazyLock<HashMap<usize, HashMap<(usize, usize), InstructionType>>>
|
||||
add_simd_to_map!(storage, vrelu);
|
||||
add_simd_to_map!(storage, vtanh);
|
||||
add_simd_to_map!(storage, vsigm);
|
||||
add_simd_to_map!(storage, vsoftmax);
|
||||
add_simd_to_map!(storage, mvmul);
|
||||
storage
|
||||
});
|
||||
@@ -626,6 +628,46 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
/// Placeholder entry for the `vsoftmax` instruction.
///
/// This function exists only so that `vsoftmax` can be used as a name/key;
/// the actual work is done by the generic `vsoftmax_impl::<F, T>`.
///
/// # Panics
///
/// Always panics: calling the placeholder directly is a programming error.
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
    // The arguments are intentionally unused; discard them explicitly.
    let _ = (cores, data);
    panic!("You are calling a placeholder, the real call is the generic version");
}
|
||||
|
||||
pub(super) fn vsoftmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vsoftmax::<F,T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
let r1_val = core.register(r1);
|
||||
let rd_val = core.register(rd);
|
||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||
let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
|
||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||
let load1 = loads[0];
|
||||
ensure!(!load1.is_empty(), "vsoftmax does not support empty vectors");
|
||||
let max_val = load1
|
||||
.iter()
|
||||
.copied()
|
||||
.reduce(|a, b| if a > b { a } else { b })
|
||||
.unwrap();
|
||||
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
|
||||
let sum = exp_values
|
||||
.iter()
|
||||
.copied()
|
||||
.reduce(|a, b| a + b)
|
||||
.unwrap();
|
||||
ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive");
|
||||
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
|
||||
let res_up: Cow<[T]> = res.as_slice().up();
|
||||
core.execute_store(rd_val, res_up.as_ref());
|
||||
TRACER.lock().unwrap().post_vsoftmax::<F,T>(cores, data);
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
/// Entry for the `vmv` instruction.
///
/// # Panics
///
/// Not implemented yet: always panics via `todo!()`.
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
    // Nothing is done with the arguments until the instruction is implemented.
    let _ = (cores, data);
    todo!()
}
|
||||
|
||||
@@ -40,6 +40,7 @@ static SIMD: LazyLock<HashMap<String, FunctorType>> = LazyLock::new(|| {
|
||||
add_to_json_map!(storage, vrelu);
|
||||
add_to_json_map!(storage, vtanh);
|
||||
add_to_json_map!(storage, vsigm);
|
||||
add_to_json_map!(storage, vsoftmax);
|
||||
add_to_json_map!(storage, vmv);
|
||||
add_to_json_map!(storage, vrsu);
|
||||
add_to_json_map!(storage, vrsl);
|
||||
@@ -417,6 +418,27 @@ fn json_to_vsigm(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn json_to_vsoftmax(
|
||||
inst_builder: &mut InstructionsBuilder,
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vsoftmax", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_r1(rs1)
|
||||
.set_imm_len(len)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
inst_builder.make_inst(vsoftmax, inst_data_builder.build());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn json_to_vmv(
|
||||
inst_builder: &mut InstructionsBuilder,
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
|
||||
@@ -67,6 +67,22 @@ impl HasSigm for f64 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Types that provide the exponential function `e^self`.
pub trait HasExp {
    /// Returns `e` raised to the power of `self`.
    fn exp(self) -> Self;
}

impl HasExp for f32 {
    fn exp(self) -> Self {
        // Explicitly call the inherent `f32::exp`, not this trait method.
        f32::exp(self)
    }
}

impl HasExp for f64 {
    fn exp(self) -> Self {
        // Explicitly call the inherent `f64::exp`, not this trait method.
        f64::exp(self)
    }
}
|
||||
|
||||
|
||||
|
||||
pub trait TryToUsize: TryInto<usize, Error = Self::TryError>
|
||||
@@ -112,6 +128,7 @@ pub trait UpcastDestTraits<T>:
|
||||
+ PartialOrd<T>
|
||||
+ HasTanh
|
||||
+ HasSigm
|
||||
+ HasExp
|
||||
+ FromUsize
|
||||
{
|
||||
}
|
||||
|
||||
@@ -248,6 +248,22 @@ impl Trace {
|
||||
{
|
||||
}
|
||||
|
||||
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
}
|
||||
|
||||
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
/////Communication/synchronization Instructions/////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -956,6 +956,35 @@ impl Trace {
|
||||
// Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let file: &mut File = self
|
||||
.out_files
|
||||
.get_mut(core_indx as usize)
|
||||
.expect("File at index not found");
|
||||
writeln!(file, "\t\tVSOFTMAX\t\t");
|
||||
}
|
||||
|
||||
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let file: &mut File = self
|
||||
.out_files
|
||||
.get_mut(core_indx as usize)
|
||||
.expect("File at index not found");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
/////Communication/synchronization Instructions/////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user