use crate::internal::*;
use tract_ndarray::prelude::*;

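/// Gathers slices of `data` at the coordinates listed along the last axis of
/// `indices` (ONNX GatherND semantics). The first `batch_dims` axes of `data`
/// and `indices` are matched pairwise and iterated over.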
#[derive(Debug, Clone, new, Hash)]
pub struct GatherNd {
    pub batch_dims: usize,
}

impl_dyn_hash!(GatherNd);

impl GatherNd {
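    /// Output shape: `indices_shape[..-1] ++ data_shape[batch_dims + n..]`,
    /// where `n` is the size of the last axis of `indices` (the number of
    /// coordinates locating each gathered slice).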
    fn compute_shape<D: DimLike>(
        &self,
        data_shape: &[D],
        indices_shape: &[D],
    ) -> TractResult<TVec<D>> {
        let mut shape: TVec<D> = indices_shape.into();
        let n = shape.pop().unwrap().to_usize()?;
        shape.extend(data_shape[n + self.batch_dims..].iter().cloned());
        Ok(shape)
    }

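    /// Gathers every slice designated by `indices` from `data` into `output`.
    ///
    /// Safety: `output` and `data` must both hold datum type `T`, and
    /// `output` must have the shape computed by `compute_shape` (it may be
    /// uninitialized, as it is fully overwritten).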
    unsafe fn eval_t<T: Datum>(
        &self,
        output: &mut Tensor,
        data: &Tensor,
        indices: &ArrayViewD<i32>,
    ) {
        let batch_dims = self.batch_dims;
        assert_eq!(output.shape()[..batch_dims], data.shape()[..batch_dims]);
        assert_eq!(output.shape()[..batch_dims], indices.shape()[..batch_dims]);
        let batch_size = data.shape().iter().take(batch_dims).product();
        let n = indices.shape()[indices.ndim() - 1];

        // Flatten indices to [batch, remaining, n]: one axis for the batch,
        // one for everything between the batch and the final axis, and the
        // final axis carrying the n coordinates of each gathered slice.
        let remaining = indices.shape().iter().skip(batch_dims).rev().skip(1).product();
        let indices_shape_op = tvec!(batch_size, remaining, n);
        let reshaped_indices: ArrayViewD<i32> =
            indices.view().into_shape(&*indices_shape_op).unwrap();

        // Flatten data likewise, collapsing the batch dimensions into a
        // single leading axis.
        let mut data_shape_op: TVec<usize> =
            data.shape().iter().skip(batch_dims).copied().collect();
        data_shape_op.insert(0, batch_size);
        let reshaped_data =
            data.to_array_view_unchecked::<T>().into_shape(&*data_shape_op).unwrap();

        // Flatten output to [batch * remaining, ...]: one row per gathered
        // slice, followed by the trailing data dimensions each slice keeps.
        let mut output_shape_op: TVec<usize> =
            data.shape().iter().skip(n + batch_dims).copied().collect();
        output_shape_op.insert(0, batch_size * remaining);
        let mut output =
            output.to_array_view_mut_unchecked::<T>().into_shape(&*output_shape_op).unwrap();

        for b in 0..batch_size {
            // Select the current batch element in both data and indices.
            let mut batch = reshaped_data.view();
            batch.index_axis_inplace(Axis(0), b);
            let mut batch_coords = reshaped_indices.view();
            batch_coords.index_axis_inplace(Axis(0), b);

            for ix in 0..remaining {
                // The n coordinates locating this gathered slice.
                let mut coords = batch_coords.view();
                coords.index_axis_inplace(Axis(0), ix);

                // Peel one axis off the data view per coordinate.
                let mut slice = batch.view();
                for x in coords {
                    slice.index_axis_inplace(Axis(0), *x as usize);
                }

                // Copy the selected slice into its slot in the output.
                let mut o = output.view_mut();
                o.index_axis_inplace(Axis(0), b * remaining + ix);
                o.assign(&slice);
            }
        }
    }
}

impl Op for GatherNd {
    fn name(&self) -> Cow<str> {
        "GatherNd".into()
    }

    op_core!();
    op_as_typed_op!();
}

impl EvalOp for GatherNd {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let (data, indices) = args_2!(inputs);
        let shape = self.compute_shape(&data.shape(), &indices.shape())?;
        let indices = indices.cast_to::<i32>()?;
        let indices = indices.to_array_view::<i32>()?;
        unsafe {
            let mut output = Tensor::uninitialized_dt(data.datum_type(), &*shape)?;
            dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
                self,
                &mut output,
                &data,
                &indices
            ));
            Ok(tvec!(output.into_arc_tensor()))
        }
    }
}

impl TypedOp for GatherNd {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let shape = self.compute_shape(&inputs[0].shape.to_tvec(), &inputs[1].shape.to_tvec())?;
        Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &shape)))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(indices) = &model.outlet_fact(node.inputs[1])?.konst {
            // With no batch dimensions and a single constant row of indices,
            // the op gathers exactly one slice: rewrite it as one Slice per
            // coordinate, drop the now-trivial axes, then restore the leading
            // axis of size one coming from indices' first dimension.
            if self.batch_dims == 0 && indices.rank() == 2 && indices.shape()[0] == 1 {
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                for (axis, &i) in indices.cast_to::<i32>()?.as_slice::<i32>()?.iter().enumerate() {
                    wire = patch.wire_node(
                        format!("{}-slice-axis-{}", node.name, axis),
                        crate::ops::array::Slice::new(axis, i as usize, (i + 1) as usize),
                        &[wire],
                    )?[0];
                }
                for i in (0..indices.shape()[1]).rev() {
                    wire = patch.wire_node(
                        format!("{}-remove_axis_{}", node.name, i),
                        crate::ops::change_axes::AxisOp::Rm(i),
                        &[wire],
                    )?[0];
                }
                wire = patch.wire_node(
                    format!("{}-add_axis", node.name),
                    crate::ops::change_axes::AxisOp::Add(0),
                    &[wire],
                )?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
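
// A couple of minimal sanity checks sketching the op's behavior. These are
// illustrative additions based on the shape rule above and the classic ONNX
// GatherND row-gather example; they are not part of the upstream test suite.
#[cfg(test)]
mod tests {
    use super::*;

    // With batch_dims = 0 and n = 1, each index row selects one row of data:
    // output shape is indices_shape[..-1] ++ data_shape[1..] = [2, 2].
    #[test]
    fn compute_shape_row_gather() -> TractResult<()> {
        let op = GatherNd::new(0);
        let shape = op.compute_shape(&[2usize, 2], &[2usize, 1])?;
        assert_eq!(shape, tvec!(2usize, 2));
        Ok(())
    }

    // data = [[0, 1], [2, 3]] with indices = [[1], [0]] gathers row 1, then
    // row 0.
    #[test]
    fn eval_row_gather() -> TractResult<()> {
        let op = GatherNd::new(0);
        let data = tensor2(&[[0i32, 1], [2, 3]]);
        let indices = tensor2(&[[1i32], [0]]);
        let output = op.eval(tvec!(data.into_arc_tensor(), indices.into_arc_tensor()))?;
        assert_eq!(*output[0], tensor2(&[[2i32, 3], [0, 1]]));
        Ok(())
    }
}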