1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
//! A fluent interface for the analyser.
//!
//! This interface provides proxies for the different properties of tensors.
//! This allows inference rules to be stated in a clear, declarative fashion
//! inside the `rules` method of each operator.
//!
//! Take these rules for instance:
//! ```text
//! solver.equals(inputs.len(), 2);
//! solver.equals(inputs[0].datum_type, outputs[0].datum_type);
//! ```
//! Here, `inputs.len()`, `inputs[0].datum_type` and `outputs[0].datum_type`
//! don't hold the actual length and datum types; instead, they act as
//! declarative placeholders for these values.

/// Builds a `Vec` holding `$crate::infer::rules::expr::IntoExp::bex(x)` for
/// each comma-separated argument `x`, in order. A trailing comma is accepted.
#[macro_export]
macro_rules! wrap {
    ($($x:expr),*) => ({
        vec![$( $crate::infer::rules::expr::IntoExp::bex($x) ),*]
    });

    // Trailing-comma form: strip the comma and re-dispatch. The recursive call
    // goes through `$crate::` so it still resolves when the caller invoked the
    // macro by path (e.g. `some_crate::wrap![a, b,]`) without importing
    // `wrap!` by name.
    ($($x:expr,)*) => ($crate::wrap![$($x),*]);
}

use crate::infer::*;

mod cache;
pub mod expr;
mod path;
mod proxies;
mod solver;

pub use self::proxies::*;
pub use self::solver::Solver;

/// Result type returned by [`InferenceRulesOp::rules`]: success carries no
/// value; failure carries a `TractResult` error.
pub type InferenceResult = TractResult<()>;

/// An operator whose fact inference is expressed as declarative rules.
///
/// Implementors register constraints relating input and output proxies in
/// [`InferenceRulesOp::rules`]. Every type implementing this trait and `Op`
/// also gets an `InferenceOp` implementation through the blanket impl in
/// this module.
pub trait InferenceRulesOp {
    /// Registers the inference rules of the operator.
    ///
    /// `inputs` and `outputs` are placeholder proxies for the operator's input
    /// and output tensors; rules stated on them are added to `solver`.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        solver: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult;

    /// Returns this operator as a `&dyn Op`.
    fn as_op(&self) -> &dyn Op;

    /// Returns this operator as a `&mut dyn Op`.
    fn as_op_mut(&mut self) -> &mut dyn Op;

    /// Converts `node` (from `source`) into operations in `target` and returns
    /// the resulting outlets.
    ///
    /// NOTE(review): `mapping` presumably maps outlets of `source` to their
    /// counterparts in `target` — confirm against callers.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with
    /// "Node {node} can not be typed"; operators that support conversion must
    /// override it. Unused parameters are `_`-prefixed so the
    /// unused-variable lint stays active for the rest of the default body.
    fn to_typed(
        &self,
        _source: &InferenceModel,
        node: &InferenceNode,
        _target: &mut TypedModel,
        _mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        bail!("Node {} can not be typed", node)
    }

    /// Number of outputs the operator produces. Defaults to `1`.
    fn nboutputs(&self) -> TractResult<usize> {
        Ok(1)
    }

    /// Optionally produces a patch of `model` for `node`.
    ///
    /// The default implementation returns `Ok(None)`, i.e. no patch.
    fn incorporate(
        &self,
        _model: &InferenceModel,
        _node: &InferenceNode,
    ) -> TractResult<Option<InferenceModelPatch>> {
        Ok(None)
    }
}

/// Blanket bridge: every [`InferenceRulesOp`] that is also an `Op` is an
/// `InferenceOp`, with fact inference driven by the rules it declares.
///
/// `InferenceOp` and `InferenceRulesOp` share several method names
/// (`nboutputs`, `as_op`, `as_op_mut`, `to_typed`, `incorporate`). The
/// delegating methods below therefore call `InferenceRulesOp` with
/// fully-qualified syntax. A bare `self.nboutputs()` here reads like
/// self-recursion and only resolves to the intended method because rustc
/// prefers methods from the type parameter's bounds. It would also become
/// ambiguous if `Op` ever gained a method with the same name.
impl<O: InferenceRulesOp + Op> InferenceOp for O {
    /// Runs the operator's declarative rules through a fresh `Solver`.
    ///
    /// The proxy for input `ix` is built from the index pair `(0, ix)`, and
    /// the proxy for output `ix` from `(1, ix)`. Returns the solver's refined
    /// input and output facts; `observed` facts are returned unchanged
    /// (cloned).
    fn infer_facts(
        &mut self,
        inputs: TVec<&InferenceFact>,
        outputs: TVec<&InferenceFact>,
        observed: TVec<&InferenceFact>,
    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
        let inputs_proxy: TVec<TensorProxy> =
            (0..inputs.len()).map(|ix| TensorProxy::new(tvec!(0, ix as isize).into())).collect();
        let outputs_proxy: TVec<TensorProxy> =
            (0..outputs.len()).map(|ix| TensorProxy::new(tvec!(1, ix as isize).into())).collect();

        trace!("Building rules for {:?}", self);
        let mut solver = Solver::default();
        self.rules(&mut solver, &inputs_proxy, &outputs_proxy)?;
        trace!("Applying rules for {:?}", self);
        let (input, output) = solver.infer_facts((inputs, outputs))?;
        trace!("Solver done");
        Ok((input, output, observed.into_iter().cloned().collect()))
    }

    /// Delegates to [`InferenceRulesOp::nboutputs`].
    fn nboutputs(&self) -> TractResult<usize> {
        InferenceRulesOp::nboutputs(self)
    }

    /// Rule-based operators observe no extra outlets: always returns an empty list.
    fn observe_outlets(
        &self,
        _model: &InferenceModel,
        _node: &InferenceNode,
    ) -> TractResult<Vec<OutletId>> {
        Ok(vec![])
    }

    /// Delegates to [`InferenceRulesOp::as_op`].
    fn as_op(&self) -> &dyn Op {
        InferenceRulesOp::as_op(self)
    }

    /// Delegates to [`InferenceRulesOp::as_op_mut`].
    fn as_op_mut(&mut self) -> &mut dyn Op {
        InferenceRulesOp::as_op_mut(self)
    }

    /// Delegates to [`InferenceRulesOp::to_typed`].
    fn to_typed(
        &self,
        source: &InferenceModel,
        node: &InferenceNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        InferenceRulesOp::to_typed(self, source, node, target, mapping)
    }

    /// Delegates to [`InferenceRulesOp::incorporate`].
    fn incorporate(
        &self,
        model: &InferenceModel,
        node: &InferenceNode,
    ) -> TractResult<Option<InferenceModelPatch>> {
        InferenceRulesOp::incorporate(self, model, node)
    }
}