1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/// Builds a `Vec` by calling `IntoExp::bex` on each comma-separated
/// expression, in order.
///
/// An optional trailing comma is accepted, and an empty invocation yields an
/// empty `Vec`. The trailing comma is handled in the same arm rather than by
/// recursing through a bare `wrap!`: an unqualified recursive call does not
/// resolve when this exported macro is invoked by path from another crate
/// without being imported first.
#[macro_export]
macro_rules! wrap {
    ($($x:expr),* $(,)?) => {
        vec![$( $crate::infer::rules::expr::IntoExp::bex($x) ),*]
    };
}
use crate::infer::*;
mod cache;
pub mod expr;
mod path;
mod proxies;
mod solver;
pub use self::proxies::*;
pub use self::solver::Solver;
/// Result type returned by `InferenceRulesOp::rules`: `Ok(())` when the rules
/// were declared successfully, a `TractResult` error otherwise.
pub type InferenceResult = TractResult<()>;
/// Interface for operators that describe their type/shape inference as a set
/// of rules registered on a `Solver`.
///
/// Any type implementing this trait together with `Op` receives an
/// `InferenceOp` implementation through the blanket impl in this module.
pub trait InferenceRulesOp {
/// Registers this operator's inference rules on `solver`.
///
/// `inputs` and `outputs` hold one `TensorProxy` per input and output
/// tensor. The blanket `InferenceOp` impl calls this with a fresh solver
/// before asking the solver to infer the facts.
///
/// The lifetime bounds require both the proxies (`'p`) and `self` (`'s`) to
/// outlive the solver's lifetime `'r`.
fn rules<'r, 'p: 'r, 's: 'r>(
&'s self,
solver: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> InferenceResult;
/// Returns this operator as a `dyn Op` trait object.
fn as_op(&self) -> &dyn Op;
/// Returns this operator as a mutable `dyn Op` trait object.
fn as_op_mut(&mut self) -> &mut dyn Op;
/// Converts `node` to its typed form.
///
/// The default implementation always fails with an error stating that the
/// node can not be typed.
///
/// NOTE(review): presumably implementors add the typed equivalent of `node`
/// to `target`, use `mapping` to translate outlets of `source` into outlets
/// of `target`, and return the new outlets — confirm against implementors.
// The default body only uses `node`; the other parameters exist for implementors.
#[allow(unused_variables)]
fn to_typed(
&self,
source: &InferenceModel,
node: &InferenceNode,
target: &mut TypedModel,
mapping: &HashMap<OutletId, OutletId>,
) -> TractResult<TVec<OutletId>> {
bail!("Node {} can not be typed", node)
}
/// Output count reported for this operator; defaults to `1`.
fn nboutputs(&self) -> TractResult<usize> {
Ok(1)
}
/// Optionally produces a patch for `model` concerning `node`.
///
/// The default implementation returns `Ok(None)`, i.e. no patch.
///
/// NOTE(review): the exact contract of the returned patch is defined by the
/// caller of `InferenceOp::incorporate` — confirm there.
// The default body ignores both parameters; they exist for implementors.
#[allow(unused_variables)]
fn incorporate(
&self,
model: &InferenceModel,
node: &InferenceNode,
) -> TractResult<Option<InferenceModelPatch>> {
Ok(None)
}
}
/// Blanket adapter: every operator that implements `InferenceRulesOp` (and
/// `Op`) gets an `InferenceOp` implementation driven by the rules solver.
impl<O: InferenceRulesOp + Op> InferenceOp for O {
    /// Infers the input and output facts by collecting this operator's rules
    /// and running the solver on them.
    ///
    /// Returns the inferred input facts, the inferred output facts, and a
    /// clone of each `observed` fact (observed facts are passed through
    /// unchanged).
    fn infer_facts(
        &mut self,
        inputs: TVec<&InferenceFact>,
        outputs: TVec<&InferenceFact>,
        observed: TVec<&InferenceFact>,
    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
        // One proxy per tensor; the proxy path is `[slot, index]`, with slot 0
        // used for inputs and slot 1 for outputs.
        let proxies_for = |slot: isize, count: usize| -> TVec<TensorProxy> {
            (0..count).map(|ix| TensorProxy::new(tvec!(slot, ix as isize).into())).collect()
        };
        let input_proxies = proxies_for(0, inputs.len());
        let output_proxies = proxies_for(1, outputs.len());

        trace!("Building rules for {:?}", self);
        // `rules` ties the proxies' lifetime to the solver's (`'p: 'r`), so the
        // solver is declared after the proxies it borrows.
        let mut solver = Solver::default();
        self.rules(&mut solver, &input_proxies, &output_proxies)?;

        trace!("Applying rules for {:?}", self);
        let (inferred_inputs, inferred_outputs) = solver.infer_facts((inputs, outputs))?;
        trace!("Solver done");

        let observed_facts: TVec<InferenceFact> =
            observed.iter().map(|&fact| fact.clone()).collect();
        Ok((inferred_inputs, inferred_outputs, observed_facts))
    }

    /// Forwards to `InferenceRulesOp::nboutputs`.
    fn nboutputs(&self) -> TractResult<usize> {
        InferenceRulesOp::nboutputs(self)
    }

    /// Rules-based operators observe no extra outlets: always returns an
    /// empty list.
    fn observe_outlets(
        &self,
        _model: &InferenceModel,
        _node: &InferenceNode,
    ) -> TractResult<Vec<OutletId>> {
        Ok(Vec::new())
    }

    /// Forwards to `InferenceRulesOp::as_op`.
    fn as_op(&self) -> &dyn Op {
        InferenceRulesOp::as_op(self)
    }

    /// Forwards to `InferenceRulesOp::as_op_mut`.
    fn as_op_mut(&mut self) -> &mut dyn Op {
        InferenceRulesOp::as_op_mut(self)
    }

    /// Forwards to `InferenceRulesOp::to_typed`.
    fn to_typed(
        &self,
        source: &InferenceModel,
        node: &InferenceNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        InferenceRulesOp::to_typed(self, source, node, target, mapping)
    }

    /// Forwards to `InferenceRulesOp::incorporate`.
    fn incorporate(
        &self,
        model: &InferenceModel,
        node: &InferenceNode,
    ) -> TractResult<Option<InferenceModelPatch>> {
        InferenceRulesOp::incorporate(self, model, node)
    }
}