mnn_sys/
lib.rs

1use std::ffi::CStr;
2mod tracing;
3
4pub mod cpp {
5    #![allow(non_upper_case_globals)]
6    #![allow(non_camel_case_types)]
7    #![allow(non_snake_case)]
8    include!(concat!(env!("OUT_DIR"), "/mnn_cpp.rs"));
9}
10mod sys {
11    #![allow(non_upper_case_globals)]
12    #![allow(non_camel_case_types)]
13    #![allow(non_snake_case)]
14    #![allow(clippy::manual_c_str_literals)]
15    #![allow(clippy::suspicious_doc_comments)]
16    include!(concat!(env!("OUT_DIR"), "/mnn_c.rs"));
17}
18pub use sys::*;
19impl DimensionType {
20    pub const NHWC: Self = Self::TENSORFLOW;
21    pub const NCHW: Self = Self::CAFFE;
22    pub const NC4HW4: Self = Self::CAFFE_C4;
23}
24impl halide_type_t {
25    unsafe fn new(code: halide_type_code_t, bits: u8, lanes: u16) -> Self {
26        Self { code, bits, lanes }
27    }
28}
29
30pub fn halide_type_of<T: HalideType>() -> halide_type_t {
31    T::halide_type_of()
32}
33
34pub trait HalideType: seal::Sealed {
35    fn halide_type_of() -> halide_type_t;
36}
37mod seal {
38    pub trait Sealed {}
39}
40
41macro_rules! halide_types {
42    ($($t:ty => $ht:expr),*) => {
43        $(
44            impl seal::Sealed for $t {}
45            impl HalideType for $t {
46                fn halide_type_of() -> halide_type_t {
47                    unsafe {
48                        $ht
49                    }
50                }
51            }
52        )*
53    };
54}
55
56halide_types! {
57    f32 =>  halide_type_t::new(halide_type_code_t::halide_type_float, 32, 1),
58    f64 =>  halide_type_t::new(halide_type_code_t::halide_type_float, 64, 1),
59    bool => halide_type_t::new(halide_type_code_t::halide_type_uint, 1, 1),
60    u8 =>   halide_type_t::new(halide_type_code_t::halide_type_uint, 8,1),
61    u16 =>  halide_type_t::new(halide_type_code_t::halide_type_uint, 16,1),
62    u32 =>  halide_type_t::new(halide_type_code_t::halide_type_uint, 32,1),
63    u64 =>  halide_type_t::new(halide_type_code_t::halide_type_uint, 64,1),
64    i8 =>   halide_type_t::new(halide_type_code_t::halide_type_int, 8,1),
65    i16 =>  halide_type_t::new(halide_type_code_t::halide_type_int, 16,1),
66    i32 =>  halide_type_t::new(halide_type_code_t::halide_type_int, 32,1),
67    i64 =>  halide_type_t::new(halide_type_code_t::halide_type_int, 64,1)
68}
69
70impl Drop for CString {
71    fn drop(&mut self) {
72        unsafe { destroyCString(self.as_ptr_mut()) }
73    }
74}
75
76impl CString {
77    pub fn as_ptr(&self) -> *const CString {
78        core::ptr::addr_of!(*self)
79    }
80
81    pub fn as_ptr_mut(&mut self) -> *mut CString {
82        core::ptr::addr_of_mut!(*self)
83    }
84    /// # Safety
85    /// This function is unsafe because it dereferences a raw pointer.
86    pub unsafe fn to_cstr(&self) -> &CStr {
87        unsafe { std::ffi::CStr::from_ptr(self.data) }
88    }
89}
90
91impl AsRef<[i32]> for TensorShape {
92    fn as_ref(&self) -> &[i32] {
93        &self.shape[..self.size]
94    }
95}
96
97impl halide_type_code_t {
98    /// # Safety
99    /// This function is unsafe because this basically truansmutes an integer to an enum.
100    /// And if the enum is not valid, it will cause undefined behavior in rust.
101    pub unsafe fn from_u32(code: u32) -> Self {
102        unsafe { std::mem::transmute(code) }
103    }
104}