1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
//! Bindings for exposing the functionality of `librunecoral`
//!
//! # Example
//!
//! ```rust,no_run
//! # fn load_model() -> &'static [u8] { todo!() }
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! use hotg_runecoral::{Tensor, TensorMut, InferenceContext, AccelerationBackend};
//!
//! // load the model
//! let model: &[u8] = load_model();
//!
//! // set aside some arrays for our inputs and outputs
//! let input = [0.0_f32];
//! let mut output = [0.0_f32];
//!
//! // And create tensors which point to them
//! let input_tensor = Tensor::from_slice(&input, &[1]);
//! let output_tensor = TensorMut::from_slice(&mut output, &[1]);
//!
//! // load our inference backend
//! let mut ctx = InferenceContext::create_context(
//!     hotg_runecoral::mimetype(),
//!     model,
//!     AccelerationBackend::NONE,
//! )?;
//!
//! // Now we can run inference
//! ctx.infer(&[input_tensor], &mut [output_tensor])?;
//!
//! // and look at the results
//! println!("{:?} => {:?}", input, output);
//! # Ok(())
//! # }
//! ```

#![deny(
    elided_lifetimes_in_paths,
    missing_debug_implementations,
    unreachable_pub,
    unused_crate_dependencies
)]

mod context;
pub mod ffi;
mod tensors;

pub use crate::{
    context::{AccelerationBackend, InferenceContext, LoadError},
    tensors::{ElementType, Tensor, TensorDescriptor, TensorElement, TensorMut},
};

use std::ffi::{CStr, NulError};

/// Errors that can occur when using this crate.
///
/// Each variant wraps a lower-level error. The `#[from]` attributes generate
/// `From` impls, so `?` converts those errors into [`Error`] automatically, and
/// the wrapped error is reported through [`std::error::Error::source`] rather
/// than being repeated in the `Display` message.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum Error {
    /// A string could not be converted into a C string because it contained
    /// an interior NUL byte (see [`NulError`]).
    #[error("Invalid string")]
    InvalidString(#[from] NulError),
    /// The model could not be loaded; the underlying cause is the wrapped
    /// [`LoadError`].
    #[error("Unable to load the model")]
    Load(#[from] LoadError),
}

/// The mimetype used by this crate to represent TensorFlow Lite models.
///
/// This is the string exported by `librunecoral` as
/// `RUNE_CORAL_MIME_TYPE__TFLITE`, converted to a Rust `&'static str`.
///
/// # Panics
///
/// Panics if the string exported by `librunecoral` is not valid UTF-8. That
/// would indicate a broken or mismatched native library, not a caller error.
pub fn mimetype() -> &'static str {
    // SAFETY: `RUNE_CORAL_MIME_TYPE__TFLITE` is assumed to point to a
    // NUL-terminated string constant compiled into `librunecoral`. Such a
    // constant is non-null, lives for the whole program and is never mutated,
    // which is what `CStr::from_ptr` requires and what justifies the `'static`
    // lifetime here. NOTE(review): re-check this if the FFI definition changes.
    let raw: &'static CStr = unsafe { CStr::from_ptr(ffi::RUNE_CORAL_MIME_TYPE__TFLITE) };

    raw.to_str()
        .expect("librunecoral's TensorFlow Lite mimetype should be valid UTF-8")
}