1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use crate::internal::*;
/// Per-session runtime state of a source node.
///
/// Wraps the id of the node it was created for (see `TypedSource::state`); that id is the key
/// used to fetch the matching tensor from `SessionState::inputs` at evaluation time.
#[derive(Debug, Clone, new)]
pub struct SourceState(pub usize);

impl OpState for SourceState {
    /// Emits the session input registered under this node's id as the single output.
    ///
    /// The operator and the wired inputs are ignored: a source's value is supplied through the
    /// session rather than computed from upstream tensors.
    fn eval(
        &mut self,
        session: &mut SessionState,
        _op: &dyn Op,
        _inputs: TVec<Arc<Tensor>>,
    ) -> TractResult<TVec<Arc<Tensor>>> {
        let node_id = &self.0;
        // NOTE(review): map indexing presumably panics when no input was set for this id;
        // consider turning that into a `TractResult` error.
        let tensor = Arc::clone(&session.inputs[node_id]);
        Ok(tvec![tensor])
    }
}
/// Source (model input) operator.
///
/// It is wired with no inputs of its own. Its output fact is the stored `fact`, and at evaluation
/// time its value is taken from the session inputs through [`SourceState`].
#[derive(Clone, Debug, Hash, new)]
pub struct TypedSource {
    /// Fact (datum type, shape, ...) advertised on this source's single output.
    pub fact: TypedFact,
}

impl_dyn_hash!(TypedSource);
impl Op for TypedSource {
    /// Display name of the operator: always `"Source"`.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Source")
    }

    op_core_lir_mir!();
    op_as_typed_op!();
}
impl EvalOp for TypedSource {
    /// Reports `false`: the output is read from the session, not computed from the op's inputs.
    fn is_stateless(&self) -> bool {
        false
    }

    /// Creates a [`SourceState`] bound to `node_id` so evaluation can look up this node's input.
    fn state(
        &self,
        _session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        let state = SourceState(node_id);
        Ok(Some(Box::new(state)))
    }
}
impl TypedOp for TypedSource {
    /// A source produces exactly one output whose fact is the stored `fact`; `_inputs` is unused.
    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let out = self.fact.clone();
        Ok(tvec![out])
    }

    /// Applies `change` to a copy of the declared shape and returns a consequence that replaces
    /// this node with a new source carrying the updated fact.
    ///
    /// Returns `Some` whenever `AxisOp::change_shape` succeeds; its errors are propagated.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let mut reshaped = self.fact.clone();
        // NOTE(review): the meaning of the `false` flag is defined by `AxisOp::change_shape`; confirm.
        change.change_shape(&mut reshaped.shape, false)?;
        let replacement = TypedSource::new(reshaped);
        let consequence =
            AxisChangeConsequence::new(model, node, Some(Box::new(replacement)), change);
        Ok(Some(consequence))
    }

    /// Rewires this source into `target` with each dimension of its shape evaluated against
    /// `values`.
    ///
    /// The new fact is built from the datum type and the evaluated shape only. The node keeps its
    /// name and is wired with no inputs; `_source` and `_mapping` are unused.
    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        _mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let concrete: TVec<_> = self.fact.shape.iter().map(|dim| dim.eval(values)).collect();
        let fact = TypedFact::dt_shape(self.fact.datum_type, &*concrete);
        target.wire_node(&node.name, Self { fact }, &[])
    }

    as_op!();
}