1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
use super::TypedPass;
use crate::internal::*;
use crate::model::*;
use crate::TractResult;
use crate::ops::change_axes::*;
/// Stateless optimizer pass marker.
///
/// The type carries no data; all behaviour lives in its `TypedPass`
/// implementation, which searches a model for axis changes suggested by its
/// operators and turns the first applicable one into a patch.
#[derive(Debug, Clone)]
pub struct ChangeAxes;
impl TypedPass for ChangeAxes {
    /// Nothing to reset: the pass keeps no state between runs.
    fn reset(&mut self) -> TractResult<()> {
        Ok(())
    }

    /// Looks for the first operator-suggested axis change that yields a patch.
    ///
    /// Nodes are visited in the model's evaluation order. For each node, every
    /// axis change suggested by its operator is handed to `change_axes`; the
    /// first call that produces a patch ends the search and that patch is
    /// returned. `Ok(None)` is returned when no suggestion produces a patch.
    ///
    /// # Errors
    ///
    /// Propagates failures from fetching the model's outlets or evaluation
    /// order, from querying an operator's suggestions, and from `change_axes`
    /// (the latter wrapped with a context message naming the change and node).
    fn next(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>> {
        // Model outputs first, then model inputs, in a single list.
        // NOTE(review): presumably these are the outlets `change_axes` must
        // treat as the model's external boundary — confirm against its docs.
        let interfaces: Vec<_> = model
            .output_outlets()?
            .iter()
            .chain(model.input_outlets()?.iter())
            .copied()
            .collect();

        for node_id in model.eval_order()? {
            let node = model.node(node_id);
            for (io, axis_op) in node.op.suggested_axis_changes()? {
                let change = AxisChange { outlet: io.as_outlet(node), op: axis_op };
                let outcome = change_axes(model, &change, &interfaces, &[])
                    .with_context(|| format!("Making patch for {:?} from {}", change, node))?;
                // Only the patch is needed here; the second tuple element is ignored.
                if let Some((patch, _)) = outcome {
                    return Ok(Some(patch));
                }
            }
        }

        Ok(None)
    }
}