1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use crate::internal::*;

use crate::ops::cnn::{PaddingSpec, Patch, PatchSpec};
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape, SymDataShape};

/// Static description of a pooling-style window: tensor layout, kernel
/// extent, padding policy and optional per-axis strides and dilations.
#[derive(Debug, Clone, new, Default, Hash, PartialEq)]
pub struct PoolSpec {
    /// Layout used to interpret input shapes and to build output shapes.
    pub data_format: DataFormat,
    /// Kernel extent; its length is what `rank()` reports.
    pub kernel_shape: TVec<usize>,
    /// Padding policy, used to compute the output dimensions.
    pub padding: PaddingSpec,
    /// Per-axis dilations; `None` is treated as 1 on every axis.
    pub dilations: Option<TVec<usize>>,
    /// Per-axis strides; `None` is treated as 1 on every axis.
    pub strides: Option<TVec<usize>>,
    /// When set, replaces the input channel count in computed output shapes.
    pub output_channel_override: Option<usize>,
}

impl PoolSpec {
    pub fn info(&self) -> Vec<String> {
        vec![
            format!("Data format: {:?}", self.data_format),
            format!(
                "Kernel shape:{:?} (strides:{:?}, padding:{:?}, dilations:{:?})",
                self.kernel_shape, self.strides, self.padding, self.dilations,
            ),
        ]
    }

    pub fn rank(&self) -> usize {
        self.kernel_shape.len()
    }

    pub fn dilation(&self, geo_axis: usize) -> usize {
        self.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1)
    }

    pub fn dilations(&self) -> Cow<[usize]> {
        self.dilations
            .as_deref()
            .map_or_else(|| vec![1; self.kernel_shape.len()].into(), |d| d.into())
    }

    pub fn stride(&self, geo_axis: usize) -> usize {
        self.strides.as_ref().map(|s| s[geo_axis]).unwrap_or(1)
    }

    pub fn strides(&self) -> Cow<[usize]> {
        self.strides
            .as_deref()
            .map_or_else(|| vec![1; self.kernel_shape.len()].into(), |d| d.into())
    }

    pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TractResult<BaseDataShape<D, TVec<D>>> {
        let ishape: BaseDataShape<D, TVec<D>> = self.data_format.shape(input.into())?;
        let computed = self.padding.compute(
            ishape.hw_dims(),
            &*self.kernel_shape,
            &self.dilations(),
            &self.strides(),
        );
        let spatial_dims = computed.into_iter().map(|d| d.convoluted).collect::<TVec<D>>();
        let oshape = self.data_format.from_n_c_hw(
            ishape.n().cloned().unwrap_or(1.into()),
            self.output_channel_override.map(|i| i.into()).unwrap_or(ishape.c().clone()),
            spatial_dims,
        )?;
        Ok(oshape)
    }

    pub fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let oshape = self.output_shape(&inputs[0].shape)?;
        Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, oshape.shape)))
    }

    pub fn dispose_n_axis(&self) -> PoolSpec {
        PoolSpec { data_format: self.data_format.dispose_n_axis(), ..self.clone() }
    }

    pub fn compute_geo(&self, input_full_shape: &[TDim]) -> TractResult<PoolGeometry> {
        let output_shape = self.output_shape(input_full_shape)?;
        let input_shape: SymDataShape = self.data_format.shape(input_full_shape.into())?;
        Ok(PoolGeometry::Symbolic(SymbolicPoolGeometry {
            pool_spec: self.clone(),
            input_shape,
            output_shape,
        }))
    }
}

/// Pooling geometry wrapped in `GeometryBound`, pairing the symbolic form
/// with its concrete counterpart. `PoolSpec::compute_geo` produces the
/// `Symbolic` variant.
pub type PoolGeometry = super::GeometryBound<SymbolicPoolGeometry, ConcretePoolGeometry>;

/// Pooling geometry expressed with symbolic (`TDim`-based) shapes, as built
/// by `PoolSpec::compute_geo`.
#[derive(Debug, Clone, Hash, PartialEq)]
pub struct SymbolicPoolGeometry {
    /// The pooling spec this geometry was derived from.
    pub pool_spec: PoolSpec,
    /// Input shape interpreted through the spec's data format.
    pub input_shape: SymDataShape,
    /// Output shape computed by the spec for that input.
    pub output_shape: SymDataShape,
}

/// Pooling geometry resolved against a concrete (`usize`) input shape.
#[derive(Debug, Clone, Hash, PartialEq)]
pub struct ConcretePoolGeometry {
    /// Concrete input shape interpreted through the spec's data format.
    pub input_shape: DataShape,
    /// Patch built from the spec's kernel, padding, strides and dilations.
    pub patch: Patch,
    /// Concrete output shape, assembled from the patch's output shape.
    pub output_shape: DataShape,
}

impl super::ResolveTo<ConcretePoolGeometry> for SymbolicPoolGeometry {
    type Param = [usize];
    fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcretePoolGeometry> {
        let input_shape = self.pool_spec.data_format.shape(input_full_shape.into())?;
        let output_inner_stride = match self.pool_spec.data_format {
            DataFormat::NCHW | DataFormat::CHW => 1,
            DataFormat::NHWC | DataFormat::HWC => {
                self.pool_spec.output_channel_override.clone().unwrap_or(*input_shape.c())
            }
        };
        let mut spec = PatchSpec::for_full_shape(self.pool_spec.data_format, input_full_shape)?
            .with_output_inner_stride(output_inner_stride)
            .with_kernel_shape(self.pool_spec.kernel_shape.clone())
            .with_padding(self.pool_spec.padding.clone());
        if let Some(strides) = self.pool_spec.strides.clone() {
            spec = spec.with_strides(strides);
        }
        if let Some(dilations) = self.pool_spec.dilations.clone() {
            spec = spec.with_dilations(dilations);
        }
        let patch = spec.into_patch();
        let output_shape = input_shape.fmt.from_n_c_hw(
            *input_shape.n().unwrap_or(&1),
            self.pool_spec.output_channel_override.unwrap_or(*input_shape.c()),
            &*patch.output_shape,
        )?;
        Ok(ConcretePoolGeometry { input_shape, patch, output_shape })
    }
}