1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
use crate::internal::*;
use std::collections::HashSet;
use std::fmt::Debug;
use tract_itertools::Itertools;

pub mod change_axes;
mod op_optim;
mod prop_const;
mod push_split_down;

use self::change_axes::ChangeAxes;
use self::prop_const::PropConst;
use self::push_split_down::PushSplitDown;
use op_optim::OpOptim;

/// One rewriting strategy driven by [`Optimizer`].
///
/// For each run of a pass, the optimizer first calls [`TypedPass::reset`], then keeps
/// calling [`TypedPass::next`] with the current model, applying each returned patch
/// (unless the optimizer decides to skip it), and stops when `next` yields `None`.
pub trait TypedPass: Send + Sync + Debug + dyn_clone::DynClone {
    /// Proposes the next patch for `model`, or `None` when this pass has nothing left to offer.
    fn next(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>>;

    /// Called before each run of the pass; presumably clears per-run iteration state
    /// (TODO confirm against the implementors).
    fn reset(&mut self) -> TractResult<()>;
}

// Makes `Box<dyn TypedPass>` implement `Clone`; the optimizer clones its pass list
// before every sweep so it can drive the passes mutably from `&self`.
dyn_clone::clone_trait_object!(TypedPass);

/// Drives an ordered list of [`TypedPass`]es over a [`TypedModel`] until a whole
/// sweep over all passes applies no patch (a fixed point), or until the optional
/// patch budget set by [`Optimizer::stopping_at`] is exhausted.
pub struct Optimizer {
    /// Passes run, in order, on every sweep. They are cloned before each sweep
    /// because running a pass mutates it.
    passes: Vec<Box<dyn TypedPass>>,
    /// Upper bound on the total number of patches applied, if any.
    steps: Option<usize>,
}

impl Optimizer {
    /// Builds an optimizer running `passes` with no patch budget.
    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
        Optimizer { passes, steps: None }
    }

    /// Caps the total number of patches [`Optimizer::optimize`] will apply at `steps`.
    pub fn stopping_at(self, steps: usize) -> Optimizer {
        Optimizer { steps: Some(steps), ..self }
    }

    /// Decluttering pipeline: `TypedOp::declutter` rewrites, then `PropConst`,
    /// `PushSplitDown` and `ChangeAxes`.
    pub fn declutter() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("declutter", TypedOp::declutter, 0)),
            Box::new(PropConst),
            Box::new(PushSplitDown),
            Box::new(ChangeAxes),
        ])
    }

    /// Code-generation pipeline: `TypedOp::codegen` rewrites, then declutter rewrites,
    /// `PropConst`, `PushSplitDown` and `TypedOp::fuse` rewrites.
    /// Unlike [`Optimizer::declutter`], it does not run `ChangeAxes`.
    pub fn codegen() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("codegen", TypedOp::codegen, 0)),
            Box::new(OpOptim("declutter", TypedOp::declutter, 0)),
            Box::new(PropConst),
            Box::new(PushSplitDown),
            Box::new(OpOptim("fuse", TypedOp::fuse, 0)),
        ])
    }

    /// Optimizes a compacted copy of `model` and returns the result.
    ///
    /// Sweeps over all passes repeatedly (see [`Optimizer::run_all_passes`]) until a
    /// sweep applies no patch. `model` itself is left untouched.
    ///
    /// # Errors
    /// Propagates any error from compaction, from a pass, or from applying a patch.
    pub fn optimize(&self, model: &TypedModel) -> TractResult<TypedModel> {
        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
        {
            model.check_consistent_facts()?;
        }
        // `dont_apply_twice` watchdog keys already applied; shared by every sweep and
        // every pass so such a patch is applied at most once per optimization.
        let mut seen = HashSet::new();
        let mut model = model.compact()?;
        let mut counter = 0;
        for i in 0.. {
            let (new_counter, new_model) = self.run_all_passes(i, counter, model, &mut seen)?;
            if new_counter == counter {
                // A full sweep applied nothing: fixed point (or patch budget exhausted).
                return Ok(new_model);
            }
            counter = new_counter;
            // A single compaction is enough; compacting the fresh result again is wasted work.
            model = new_model.compact()?;
        }
        unreachable!()
    }

    /// Runs every pass once, in order, each until it stops changing the model
    /// (see [`Optimizer::run_one_pass_outer`]), compacting the model after each pass.
    ///
    /// `i` is the sweep index, used only to label patch contexts; `counter` is the
    /// number of patches applied so far. Returns the updated counter and model.
    pub fn run_all_passes(
        &self,
        i: usize,
        mut counter: usize,
        mut model: TypedModel,
        seen: &mut HashSet<String>,
    ) -> TractResult<(usize, TypedModel)> {
        // Running a pass mutates it, so drive fresh copies and keep `self` immutable.
        let mut passes = self.passes.clone();
        for p in passes.iter_mut() {
            let (new_counter, new_model) =
                self.run_one_pass_outer(i, p.as_mut(), counter, model, seen)?;
            counter = new_counter;
            model = new_model.compact()?;
        }
        Ok((counter, model))
    }

    /// Re-runs pass `p` (via [`Optimizer::run_one_pass_inner`]) until a run applies
    /// no patch, compacting the model between runs.
    ///
    /// Returns the updated patch counter and model.
    pub fn run_one_pass_outer(
        &self,
        i: usize,
        p: &mut dyn TypedPass,
        mut counter: usize,
        mut model: TypedModel,
        seen: &mut HashSet<String>,
    ) -> TractResult<(usize, TypedModel)> {
        loop {
            let (new_counter, new_model) = self.run_one_pass_inner(i, p, counter, model, seen)?;
            if new_counter == counter {
                return Ok((new_counter, new_model));
            }
            counter = new_counter;
            model = new_model.compact()?;
        }
    }

    /// Resets pass `p`, then applies every patch it yields until it yields `None`.
    ///
    /// A patch whose `dont_apply_twice` watchdog key is already in `seen` is skipped;
    /// otherwise the key is recorded. Returns early once the patch budget is reached.
    /// Returns the updated patch counter and model.
    pub fn run_one_pass_inner(
        &self,
        i: usize,
        p: &mut dyn TypedPass,
        mut counter: usize,
        mut model: TypedModel,
        seen: &mut HashSet<String>,
    ) -> TractResult<(usize, TypedModel)> {
        p.reset()?;
        loop {
            // Check the budget *before* asking for a patch. A patch we would not apply
            // is wasted work, and any error it raised would be spurious.
            if matches!(self.steps, Some(steps) if counter >= steps) {
                return Ok((counter, model));
            }
            let mut patch = match p.next(&model)? {
                Some(patch) => patch,
                None => break,
            };
            patch.push_context(format!("{:?}/{}", p, i));
            #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
            {
                patch.model.check_consistent_facts()?;
                model.check_consistent_facts()?;
                patch.model.invariants()?;
                model.invariants()?;
            }
            if let Some(watchdog) = patch.dont_apply_twice.take() {
                if seen.contains(&watchdog) {
                    debug!("Loop detected: {} seen before", watchdog);
                    continue;
                }
                seen.insert(watchdog);
            }
            debug!("applying patch #{}: {}", counter, patch.context.iter().rev().join(" >> "),);
            patch.apply(&mut model)?;
            counter += 1;
        }
        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
        {
            model.check_edges()?;
            // Pass-neutral wording: this also runs for the codegen pipeline.
            model
                .check_consistent_facts()
                .with_context(|| format!("after optimizer pass {:?}", p))?;
        }
        Ok((counter, model))
    }
}