1use std::ffi::CStr;
2mod tracing;
3
4pub mod cpp {
5 #![allow(non_upper_case_globals)]
6 #![allow(non_camel_case_types)]
7 #![allow(non_snake_case)]
8 include!(concat!(env!("OUT_DIR"), "/mnn_cpp.rs"));
9}
10mod sys {
11 #![allow(non_upper_case_globals)]
12 #![allow(non_camel_case_types)]
13 #![allow(non_snake_case)]
14 #![allow(clippy::manual_c_str_literals)]
15 #![allow(clippy::suspicious_doc_comments)]
16 include!(concat!(env!("OUT_DIR"), "/mnn_c.rs"));
17}
18pub use sys::*;
19impl DimensionType {
20 pub const NHWC: Self = Self::TENSORFLOW;
21 pub const NCHW: Self = Self::CAFFE;
22 pub const NC4HW4: Self = Self::CAFFE_C4;
23}
24impl halide_type_t {
25 unsafe fn new(code: halide_type_code_t, bits: u8, lanes: u16) -> Self {
26 Self { code, bits, lanes }
27 }
28}
29
30pub fn halide_type_of<T: HalideType>() -> halide_type_t {
31 T::halide_type_of()
32}
33
34pub trait HalideType: seal::Sealed {
35 fn halide_type_of() -> halide_type_t;
36}
/// Private module holding the `Sealed` marker trait.
///
/// The module is not `pub`, so `Sealed` can only be named from inside this
/// crate.
mod seal {
    /// Method-less marker trait; `HalideType` requires it as a supertrait.
    pub trait Sealed {}
}
40
/// Implements `seal::Sealed` and `HalideType` for every listed type.
///
/// Input is a comma-separated list of `Type => expression` pairs with no
/// trailing comma. For each pair the generated `HalideType::halide_type_of`
/// evaluates `expression` and returns its value.
macro_rules! halide_types {
    ($($rust_ty:ty => $descriptor:expr),*) => {
        $(
            impl seal::Sealed for $rust_ty {}

            impl HalideType for $rust_ty {
                fn halide_type_of() -> halide_type_t {
                    // The expression is evaluated inside `unsafe` so that it
                    // may call an `unsafe fn` directly.
                    unsafe { $descriptor }
                }
            }
        )*
    };
}
55
56halide_types! {
57 f32 => halide_type_t::new(halide_type_code_t::halide_type_float, 32, 1),
58 f64 => halide_type_t::new(halide_type_code_t::halide_type_float, 64, 1),
59 bool => halide_type_t::new(halide_type_code_t::halide_type_uint, 1, 1),
60 u8 => halide_type_t::new(halide_type_code_t::halide_type_uint, 8,1),
61 u16 => halide_type_t::new(halide_type_code_t::halide_type_uint, 16,1),
62 u32 => halide_type_t::new(halide_type_code_t::halide_type_uint, 32,1),
63 u64 => halide_type_t::new(halide_type_code_t::halide_type_uint, 64,1),
64 i8 => halide_type_t::new(halide_type_code_t::halide_type_int, 8,1),
65 i16 => halide_type_t::new(halide_type_code_t::halide_type_int, 16,1),
66 i32 => halide_type_t::new(halide_type_code_t::halide_type_int, 32,1),
67 i64 => halide_type_t::new(halide_type_code_t::halide_type_int, 64,1)
68}
69
70impl Drop for CString {
71 fn drop(&mut self) {
72 unsafe { destroyCString(self.as_ptr_mut()) }
73 }
74}
75
76impl CString {
77 pub fn as_ptr(&self) -> *const CString {
78 core::ptr::addr_of!(*self)
79 }
80
81 pub fn as_ptr_mut(&mut self) -> *mut CString {
82 core::ptr::addr_of_mut!(*self)
83 }
84 pub unsafe fn to_cstr(&self) -> &CStr {
87 unsafe { std::ffi::CStr::from_ptr(self.data) }
88 }
89}
90
91impl AsRef<[i32]> for TensorShape {
92 fn as_ref(&self) -> &[i32] {
93 &self.shape[..self.size]
94 }
95}
96
97impl halide_type_code_t {
98 pub unsafe fn from_u32(code: u32) -> Self {
102 unsafe { std::mem::transmute(code) }
103 }
104}