1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use super::*;
use crate::internal::*;
use crate::ops::Op;
use tract_itertools::Itertools;
use std::fmt;
use std::fmt::{Debug, Display};
use std::hash::Hash;
/// A node: an operator together with the outlets wired to its inputs and
/// the outlets it exposes as outputs.
///
/// `F` is the fact type stored on each output outlet, `O` is the operator
/// type held by the node.
// NOTE(review): `serde(skip)` is only emitted under the `serialize` feature,
// but no serde derive is visible on this struct from here; confirm the
// attribute is still wanted.
#[derive(Debug, Clone, Educe)]
#[educe(Hash)]
pub struct Node<F, O>
where
    F: Fact + Hash,
    O: Hash,
{
    /// Numeric identifier of the node.
    pub id: usize,
    /// Name of the node.
    pub name: String,
    /// Outlets connected to this node's inputs.
    pub inputs: Vec<OutletId>,
    /// The operator carried by the node.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub op: O,
    /// Output outlets, each holding a fact and its list of successors.
    pub outputs: TVec<Outlet<F>>,
}
impl<F: Fact + Hash, O: Hash + std::fmt::Display> fmt::Display for Node<F, O> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "#{} \"{}\" {}", self.id, self.name, self.op)
}
}
/// Accessors for nodes whose operator can be viewed as a `dyn Op`.
impl<F, NodeOp> Node<F, NodeOp>
where
    F: Fact + Hash,
    NodeOp: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Hash,
{
    /// Borrows the node's operator as a `dyn Op` trait object.
    pub fn op(&self) -> &dyn Op {
        self.op.as_ref()
    }

    /// Tries to view the operator as the concrete type `O`.
    ///
    /// Returns `None` when the downcast fails.
    pub fn op_as<O: Op>(&self) -> Option<&O> {
        self.op().downcast_ref::<O>()
    }

    /// Mutable counterpart of [`Node::op_as`].
    ///
    /// Returns `None` when the downcast fails.
    pub fn op_as_mut<O: Op>(&mut self) -> Option<&mut O> {
        self.op.as_mut().downcast_mut::<O>()
    }

    /// Returns `true` when the operator downcasts to the concrete type `O`.
    pub fn op_is<O: Op>(&self) -> bool {
        self.op_as::<O>().is_some()
    }

    /// Returns `true` when both nodes have equal `inputs` and this node's
    /// operator reports `same_as` for the other node's operator.
    ///
    /// `id`, `name` and `outputs` do not take part in the comparison.
    pub fn same_as(&self, other: &Node<F, NodeOp>) -> bool {
        self.inputs == other.inputs && self.op().same_as(other.op())
    }
}
/// One output of a node: a fact plus the inlets recorded as its successors.
///
/// `F` is the fact type.
#[derive(Clone, Default, Educe)]
#[educe(Hash)]
pub struct Outlet<F>
where
    F: Fact + Hash,
{
    /// The fact attached to this outlet.
    pub fact: F,
    /// Inlets listed as successors of this outlet.
    pub successors: TVec<InletId>,
}
/// Debug output is the fact's `Debug` form, a space, then the `Debug` form
/// of every successor separated by single spaces.
impl<F> fmt::Debug for Outlet<F>
where
    F: Fact + Hash,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        // Render each successor on its own first, then glue them together.
        let rendered: Vec<String> =
            self.successors.iter().map(|inlet| format!("{:?}", inlet)).collect();
        let successors = rendered.join(" ");
        write!(fmt, "{:?} {}", self.fact, successors)
    }
}
/// Identifies an outlet by a node index and a slot index.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, new)]
pub struct OutletId {
    /// Index of the node.
    pub node: usize,
    /// Index of the slot on that node.
    pub slot: usize,
}
/// Debug output is `<node>/<slot>>` (note the trailing `>`).
impl fmt::Debug for OutletId {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let Self { node, slot } = self;
        write!(fmt, "{}/{}>", node, slot)
    }
}
impl From<usize> for OutletId {
fn from(node: usize) -> OutletId {
OutletId::new(node, 0)
}
}
impl From<(usize, usize)> for OutletId {
fn from(pair: (usize, usize)) -> OutletId {
OutletId::new(pair.0, pair.1)
}
}
/// Identifies an inlet by a node index and a slot index.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, new)]
pub struct InletId {
    /// Index of the node.
    pub node: usize,
    /// Index of the slot on that node.
    pub slot: usize,
}
/// Debug output is `><node>/<slot>` (note the leading `>`).
impl fmt::Debug for InletId {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let Self { node, slot } = self;
        write!(fmt, ">{}/{}", node, slot)
    }
}