1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use std::fmt::Debug;
use std::marker::PhantomData;
use tract_data::anyhow;

use tract_data::prelude::Tensor;

/// An operation that rewrites a slice of `T` in place.
///
/// Implementors must be `Send + Sync + Debug` and clonable through
/// `dyn_clone::DynClone`.
pub trait ElementWise<T: Copy + Debug + PartialEq + Send + Sync>:
    Send + Sync + Debug + dyn_clone::DynClone
{
    /// Processes `vec` in place, reporting failure through the returned `Result`.
    fn run(&self, vec: &mut [T]) -> anyhow::Result<()>;
}

// Lets boxed `ElementWise<T>` trait objects be cloned (see the dyn_clone crate documentation).
dyn_clone::clone_trait_object!(<T> ElementWise<T> where T: Copy);

/// Adapter that exposes a kernel type `K` through the `ElementWise` trait.
///
/// The only field is a `PhantomData`, so `K` and `T` are tracked purely at the
/// type level. A constructor is generated by the `new` derive.
#[derive(Debug, Clone, new)]
pub struct ElementWiseImpl<
    K: ElementWiseKer<T> + Clone,
    T: Copy + Debug + PartialEq + Send + Sync,
> {
    phantom: PhantomData<(K, T)>,
}

impl<K, T> ElementWise<T> for ElementWiseImpl<K, T>
where
    T: crate::Datum + Copy + Debug + PartialEq + Send + Sync,
    K: ElementWiseKer<T> + Clone,
{
    /// Applies kernel `K` in place to every element of `vec`.
    ///
    /// The slice is split in three parts:
    /// 1. a prefix running up to the first element aligned on
    ///    `K::alignment_bytes()`, staged through an aligned scratch buffer;
    /// 2. an aligned body whose length is a multiple of `K::nr()`, handed
    ///    directly to the kernel;
    /// 3. a tail shorter than `K::nr()`, staged through the scratch buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if the aligned scratch buffer cannot be allocated.
    fn run(&self, vec: &mut [T]) -> anyhow::Result<()> {
        if vec.is_empty() {
            return Ok(());
        }
        let nr = K::nr();
        let alignment = K::alignment_bytes();

        // SAFETY: the buffer starts uninitialized. Every lane that is copied
        // back to the caller is first overwritten by `copy_from_slice` below;
        // the remaining lanes are processed by the kernel but their results
        // are discarded.
        let mut tmp_buffer = unsafe { Tensor::uninitialized_aligned::<T>(&[nr], alignment)? };
        // SAFETY: `tmp_buffer` was created just above with element type `T`
        // and shape `[nr]`, so viewing it as a `[T]` slice matches its layout.
        let tmp = unsafe { tmp_buffer.as_slice_mut_unchecked::<T>() };

        // Runs the kernel on a slice of at most `nr` items by staging it in
        // the aligned scratch buffer and copying the results back.
        let mut compute_via_temp_buffer = |slice: &mut [T]| {
            let len = slice.len();
            tmp[..len].copy_from_slice(slice);
            K::run(&mut *tmp);
            slice.copy_from_slice(&tmp[..len]);
        };

        // `align_offset` counts in items of `T`, and is allowed to return
        // `usize::MAX` when it cannot compute an offset: clamp to the slice.
        let prefix_len = vec.as_ptr().align_offset(alignment).min(vec.len());
        // The prefix may be longer than the scratch buffer (alignment spanning
        // more than `nr` items, or `align_offset` giving up), so feed it to
        // the kernel in chunks of at most `nr` items.
        for chunk in vec[..prefix_len].chunks_mut(nr) {
            compute_via_temp_buffer(chunk);
        }

        // Largest multiple of `nr` that fits after the prefix.
        let aligned_len = (vec.len() - prefix_len) / nr * nr;
        if aligned_len > 0 {
            K::run(&mut vec[prefix_len..][..aligned_len]);
        }

        // Whatever remains is shorter than `nr` and fits in the scratch buffer.
        let tail_start = prefix_len + aligned_len;
        if tail_start < vec.len() {
            compute_via_temp_buffer(&mut vec[tail_start..]);
        }
        Ok(())
    }
}

/// Type-level description of an element-wise kernel over `T`.
///
/// Every item is an associated function without a receiver: a kernel is
/// selected by its type, not by a value.
pub trait ElementWiseKer<T: Copy + Debug + PartialEq + Send + Sync>:
    Send + Sync + Debug + dyn_clone::DynClone + Clone
{
    /// Name of the kernel.
    fn name() -> &'static str;
    /// Alignment, in bytes; `ElementWiseImpl::run` uses it to align the data
    /// it hands to `run`.
    fn alignment_bytes() -> usize;
    /// Alignment expressed in items (presumably the item-count counterpart of
    /// `alignment_bytes`; confirm against implementors).
    fn alignment_items() -> usize;
    /// Item count granularity: `ElementWiseImpl::run` only calls `run` with
    /// slices whose length is a multiple of this value.
    fn nr() -> usize;
    /// Processes `vec` in place.
    fn run(vec: &mut [T]);
}