use super::factoid::*;
use super::*;

/// Infers every output fact of `op` by evaluating it eagerly, which is only
/// possible once all input values are known.
pub fn infer_forward_concrete(
    op: &dyn Op,
    inputs: &[&InferenceFact],
) -> TractResult<Option<TVec<InferenceFact>>> {
    let input_values: TVec<_> =
        inputs.iter().filter_map(|t| t.value.concretize()).map(|v| v.clone().into()).collect();

    // If any input value is still unknown, the op cannot be evaluated yet.
    if input_values.len() < inputs.len() {
        debug!("Can't infer value: some inputs are still unknown.");
        return Ok(None);
    }

    // A stateless op can be evaluated right away, yielding a fully concrete
    // output fact. Note that only the last value returned by `eval` is kept,
    // so this assumes a single-output op.
    if op.is_stateless() {
        let output_value = op.eval(input_values)?.pop().unwrap();
        return Ok(Some(tvec![output_value.into()]));
    }

    Ok(None)
}
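
// A hedged usage sketch (not part of the original file): an analyser pass
// would typically call the helper above on a node whose input facts have just
// been refined, falling back to shape-level inference when it yields `None`.
// `node_inputs` is a hypothetical collection of `InferenceFact`s.
//
//     let facts: Vec<&InferenceFact> = node_inputs.iter().collect();
//     match infer_forward_concrete(op, &facts)? {
//         Some(outputs) => { /* outputs are fully concrete facts */ }
//         None => { /* fall back to infer_forward_basic below */ }
//     }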
/// Infers the output shape of an element-wise, numpy-style broadcasting
/// operator from the shapes of its inputs, aligning axes from the right.
pub fn infer_shape_broadcasting(shapes: &[&ShapeFactoid]) -> TractResult<Option<ShapeFactoid>> {
    if shapes.iter().any(|s| s.is_open()) {
        debug!("Can't infer shape for broadcasting operators when some inputs have an open shape.");
        return Ok(None);
    }

    // All shapes are closed, so every rank is known: the output rank is the
    // largest input rank.
    let bound = shapes.iter().map(|s| s.rank().concretize().unwrap()).max().unwrap() as usize;
    let mut output_shape: TVec<DimFact> = tvec![];

    // Walk the axes from last to first: `i` is the distance from the last axis.
    for i in 0..bound {
        let mut previous: Option<TDim> = None;
        let mut unknown = 0;
        for shape in shapes.iter() {
            let rank = shape.rank().concretize().unwrap() as usize;
            let shape: TVec<DimFact> = shape.dims().cloned().collect();
            // Shorter shapes are implicitly padded with 1s on the left.
            if i >= rank {
                continue;
            }
            match &shape[rank - i - 1] {
                GenericFactoid::Any => unknown += 1,
                // A dimension of 1 broadcasts against anything.
                GenericFactoid::Only(ref d) if d.is_one() => (),
                GenericFactoid::Only(ref d) => {
                    if previous.is_some() && previous.as_ref() != Some(d) {
                        bail!(
                            "Invalid shape (broadcasting): {:?} is not compatible with {:?}.",
                            d,
                            previous
                        )
                    } else {
                        previous = Some(d.clone())
                    }
                }
            };
        }
        if unknown > 1 {
            debug!("Can't infer shape (broadcasting): there are multiple unknown values at the same index.");
            return Ok(None);
        } else if unknown == 1 && previous.is_some() {
            debug!("Can't infer shape (broadcasting): there are both unknown and known values at the same index.");
            return Ok(None);
        } else if unknown == 1 && previous.is_none() {
            output_shape.push(GenericFactoid::Any);
        } else if let Some(previous) = previous {
            output_shape.push(GenericFactoid::Only(previous));
        } else {
            // Every input had a 1 (or no axis at all) at this position.
            output_shape.push(GenericFactoid::Only(1.into()));
        }
    }

    // Axes were accumulated from last to first.
    output_shape.reverse();
    Ok(Some(ShapeFactoid::closed(output_shape)))
}
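
#[cfg(test)]
mod broadcast_shape_sketch {
    use super::*;

    // A minimal sketch, not part of the original file, assuming the
    // `shapefactoid!` macro accepts concrete dims and `_` wildcards as it
    // does elsewhere in this crate.
    #[test]
    fn right_aligned_broadcast() -> TractResult<()> {
        // Missing leading axes behave like 1s: [2, 3] vs [3] -> [2, 3].
        let a = shapefactoid![2, 3];
        let b = shapefactoid![3];
        assert_eq!(infer_shape_broadcasting(&[&a, &b])?, Some(shapefactoid![2, 3]));

        // A dimension of 1 broadcasts against any known dimension.
        let c = shapefactoid![2, 1];
        assert_eq!(infer_shape_broadcasting(&[&c, &a])?, Some(shapefactoid![2, 3]));

        // An unknown dim facing a known, non-1 dim stays unresolved.
        let d = shapefactoid![2, _];
        assert_eq!(infer_shape_broadcasting(&[&d, &a])?, None);
        Ok(())
    }
}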
/// Infers as much as possible about the output of a broadcasting operator:
/// exact values when every input is concrete, otherwise the broadcast shape
/// and a datum type borrowed from the first input whose type is known.
pub fn infer_forward_basic(
    op: &dyn Op,
    inputs: Vec<&InferenceFact>,
) -> TractResult<Option<TVec<InferenceFact>>> {
    // Best case: all input values are known, so the op can be evaluated.
    if let Some(output) = infer_forward_concrete(op, &inputs)? {
        return Ok(Some(output));
    }

    // Otherwise, try to infer the datum type and the shape of the output.
    let input_shapes: Vec<_> = inputs.iter().map(|t| &t.shape).collect();
    let datum_type = inputs
        .iter()
        .find_map(|i| i.datum_type.concretize())
        .map(|t| typefact!(t))
        .unwrap_or(typefact!(_));

    let output = InferenceFact {
        datum_type,
        shape: infer_shape_broadcasting(&input_shapes)?.unwrap_or(shapefactoid![..]),
        value: valuefact!(_),
    };

    Ok(Some(tvec![output]))
}
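
// A hedged walkthrough (not in the original file): with one input known to be
// f32 of shape [2, 3] and another of unknown type and shape [3] (both values
// unknown), the concrete path above bails out early and the helper returns a
// single fact: datum type f32 (borrowed from the typed input), shape [2, 3]
// (from `infer_shape_broadcasting`), value still unknown.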
/// Returns the most specific shape among `iter`, i.e. the closed shape with
/// the most concrete dimensions. Fails if two closed shapes disagree on rank.
pub fn most_specific_shape<'a, I: IntoIterator<Item = &'a ShapeFactoid>>(
    iter: I,
) -> TractResult<Option<&'a ShapeFactoid>> {
    let mut prev_rank = None;
    let mut prev_concrete = None;
    let mut best = None;

    for shape in iter {
        // Only shapes with a known rank are candidates.
        if let Some(rank) = shape.rank().concretize() {
            if prev_rank.is_some() && prev_rank != Some(rank) {
                bail!("Rank mismatch between different shapes.");
            } else {
                prev_rank = Some(rank);
            }

            // Prefer the shape with the most concrete dimensions so far.
            let concrete = shape.dims().filter(|d| d.is_concrete()).count();
            if prev_concrete.map_or(true, |c| concrete > c) {
                prev_concrete = Some(concrete);
                best = Some(shape)
            }
        }
    }

    Ok(best)
}
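
#[cfg(test)]
mod most_specific_shape_sketch {
    use super::*;

    // A minimal sketch, not part of the original file, assuming the
    // `shapefactoid!` forms used elsewhere in this module: among shapes of
    // equal rank, the one with the most concrete dimensions should win.
    #[test]
    fn picks_the_most_concrete_shape() -> TractResult<()> {
        let shapes = [shapefactoid![2, _, 3], shapefactoid![2, 4, 3]];
        assert_eq!(most_specific_shape(shapes.iter())?, Some(&shapes[1]));
        Ok(())
    }
}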