1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
use super::TypedPass;
use crate::internal::*;
use crate::model::*;
use crate::TractResult;

use crate::ops::change_axes::*;

#[derive(Clone, Debug)]
pub struct ChangeAxes;

impl TypedPass for ChangeAxes {
    fn reset(&mut self) -> TractResult<()> {
        Ok(())
    }
    fn next(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>> {
        let mut interfaces = model.output_outlets()?.to_vec();
        interfaces.extend(model.input_outlets()?.iter());
        for n in model.eval_order()? {
            for suggestion in model.node(n).op.suggested_axis_changes()? {
                let outlet = suggestion.0.as_outlet(&model.node(n));
                let change = AxisChange { outlet, op: suggestion.1 };
                if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
                    .with_context(|| {
                        format!("Making patch for {:?} from {}", change, model.node(n))
                    })?
                {
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }
}