1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
use std::{io::Write, path::PathBuf};
use anyhow::{Context, Error};
use hotg_runecoral::{
mimetype, AccelerationBackend, InferenceContext, TensorDescriptor,
};
use strum::VariantNames;
use crate::Format;
#[derive(Debug, Clone, PartialEq, structopt::StructOpt)]
pub struct ModelInfo {
#[structopt(
help = "The TensorFlow Lite model to inspect",
parse(from_os_str)
)]
file: PathBuf,
#[structopt(
short,
long,
help = "The format to print output in",
default_value = "text",
possible_values = Format::VARIANTS,
parse(try_from_str)
)]
format: Format,
}
impl ModelInfo {
pub fn execute(self) -> Result<(), Error> {
let raw = std::fs::read(&self.file).with_context(|| {
format!("Unable to read \"{}\"", &self.file.display())
})?;
let ctx = InferenceContext::create_context(
mimetype(),
&raw,
AccelerationBackend::NONE,
)
.context("Unable to an inference context")?;
match self.format {
Format::Text => print_info(&ctx),
Format::Json => {
let mut stdout = std::io::stdout();
serde_json::to_writer_pretty(
stdout.lock(),
&ModelDescription {
inputs: ctx
.inputs()
.map(|x| TensorInfo::from(&x))
.collect(),
outputs: ctx
.outputs()
.map(|x| TensorInfo::from(&x))
.collect(),
ops: ctx.opcount() as usize,
},
)
.context("Unable to print to stdout")?;
writeln!(stdout)?;
},
}
Ok(())
}
}
fn print_info(ctx: &InferenceContext) {
println!("Ops: {}", ctx.opcount());
println!("Inputs:");
for input in ctx.inputs() {
println!("\t{}", input);
}
println!("Outputs:");
for output in ctx.outputs() {
println!("\t{}", output);
}
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
struct ModelDescription {
inputs: Vec<TensorInfo>,
outputs: Vec<TensorInfo>,
ops: usize,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
struct TensorInfo {
name: String,
element_kind: String,
dims: Vec<usize>,
}
impl From<&TensorDescriptor<'_>> for TensorInfo {
fn from(t: &TensorDescriptor<'_>) -> TensorInfo {
TensorInfo {
name: t.name.to_str().unwrap().to_string(),
element_kind: t.element_type.to_string(),
dims: t.shape.iter().map(|&x| x as usize).collect(),
}
}
}