mnn/interpreter.rs

//! The interpreter module provides the `Interpreter` struct which is used to load and run models.
use crate::tensor::list::TensorList;
use std::{ffi::CStr, path::Path, sync::Arc};

use crate::{
    AsTensorShape, Device, RawTensor, Ref, RefMut, ScheduleConfig, Tensor, TensorType, prelude::*,
};
use mnn_sys::HalideType;

pub(crate) type TensorCallbackT = Box<dyn Fn(&[RawTensor], OperatorInfo) -> bool>;

#[repr(transparent)]
pub(crate) struct TensorCallback {
    inner: Arc<TensorCallbackT>,
}

impl Default for TensorCallback {
    fn default() -> Self {
        Self {
            inner: Arc::new(Box::new(|_, _| true)),
        }
    }
}

impl TensorCallback {
    pub(crate) fn from_ptr(f: *mut libc::c_void) -> Self {
        debug_assert!(!f.is_null());
        unsafe {
            Self {
                inner: Arc::from_raw(f.cast()),
            }
        }
    }

    pub(crate) fn into_ptr(self) -> *mut libc::c_void {
        Arc::into_raw(self.inner) as *mut libc::c_void
    }

    #[cfg(test)]
    pub(crate) fn identity() -> impl Fn(&[RawTensor], OperatorInfo) -> bool {
        |_, _| true
    }
}

impl<F> From<F> for TensorCallback
where
    F: Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
{
    fn from(f: F) -> Self {
        Self {
            inner: Arc::new(Box::new(f)),
        }
    }
}

impl<T> From<Option<T>> for TensorCallback
where
    T: Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
{
    fn from(f: Option<T>) -> Self {
        match f {
            Some(f) => Self {
                inner: Arc::new(Box::new(f)),
            },
            None => Self::default(),
        }
    }
}

impl core::ops::Deref for TensorCallback {
    type Target = TensorCallbackT;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

/// The session mode to be used.
///
/// Most variants are untested and are documented 1:1 from the C++ codebase.
/// The only two tested variants are
/// - `Debug`
/// - `Release`
#[derive(Debug, Copy, Clone)]
#[cfg_attr(windows, repr(i32))]
#[cfg_attr(unix, repr(u32))]
pub enum SessionMode {
    #[doc = "About callbacks: the default is Session_Debug. runSessionWithCallBack is allowed and can get internal op info"]
    Debug = mnn_sys::SessionMode::Session_Debug,
    #[doc = "runSessionWithCallBack is not valid and can't get any info of ops in the session"]
    Release = mnn_sys::SessionMode::Session_Release,
    #[doc = "About input tensors: the default is Session_Input_Inside. The input tensor is allocated by the session; set input data after the session is resized"]
    InputInside = mnn_sys::SessionMode::Session_Input_Inside,
    #[doc = "The input tensor is allocated by the user; set input data before the session is resized"]
    InputUser = mnn_sys::SessionMode::Session_Input_User,
    #[doc = "The output tensor depends on the session and can't be used separately"]
    OutputInside = mnn_sys::SessionMode::Session_Output_Inside,
    #[doc = "The output tensor can be separated from the session"]
    OutputUser = mnn_sys::SessionMode::Session_Output_User,
    #[doc = "Resize the session directly when it is created (the default)"]
    ResizeDirect = mnn_sys::SessionMode::Session_Resize_Direct,
    #[doc = "Defer resizing the session until it is explicitly resized"]
    ResizeDefer = mnn_sys::SessionMode::Session_Resize_Defer,
    #[doc = "The Execution's forward type is fixed by the user"]
    BackendFix = mnn_sys::SessionMode::Session_Backend_Fix,
    #[doc = "The Execution's forward type is determined automatically"]
    BackendAuto = mnn_sys::SessionMode::Session_Backend_Auto,
    #[doc = "Recycle static memory in resizeSession"]
    MemoryCollect = mnn_sys::SessionMode::Session_Memory_Collect,
    #[doc = "Cache static memory across resizeSession calls"]
    MemoryCache = mnn_sys::SessionMode::Session_Memory_Cache,
    #[doc = "Disable the codegen function"]
    CodegenDisable = mnn_sys::SessionMode::Session_Codegen_Disable,
    #[doc = "Enable the codegen function"]
    CodegenEnable = mnn_sys::SessionMode::Session_Codegen_Enable,
    #[doc = "Dynamic resize optimization"]
    ResizeCheck = mnn_sys::SessionMode::Session_Resize_Check,
    #[doc = "Dynamic resize optimization"]
    ResizeFix = mnn_sys::SessionMode::Session_Resize_Fix,
}

#[cfg(windows)]
type SessionModeType = i32;
#[cfg(unix)]
type SessionModeType = u32;

impl SessionMode {
    fn to_mnn_sys(self) -> SessionModeType {
        self as SessionModeType
    }
}

/// Net data holder. Multiple sessions can share the same net.
#[repr(transparent)]
#[derive(Debug)]
pub struct Interpreter {
    pub(crate) inner: *mut mnn_sys::Interpreter,
    pub(crate) __marker: PhantomData<()>,
}

unsafe impl Send for Interpreter {}

impl Drop for Interpreter {
    fn drop(&mut self) {
        unsafe { mnn_sys::Interpreter_destroy(self.inner) }
    }
}

impl Interpreter {
    /// Create a net/interpreter from a file.
    ///
    /// `path`: the file path of the model
    ///
    /// return: the created net/interpreter
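    ///
    /// A minimal, untested sketch of loading a model from disk (the path below is a
    /// placeholder):
    /// ```rust,ignore
    /// use mnn::Interpreter;
    /// // "model.mnn" is a hypothetical model file path.
    /// let interpreter = Interpreter::from_file("model.mnn").expect("failed to load model");
    /// ```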
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        ensure!(path.exists(), ErrorKind::IOError; path.to_string_lossy().to_string(), "File not found");
        let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
        let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?;
        let interpreter = unsafe { mnn_sys::Interpreter_createFromFile(c_path.as_ptr()) };
        ensure!(!interpreter.is_null(), ErrorKind::InterpreterError; "Failed to create interpreter", "Interpreter_createFromFile returned null");
        Ok(Self {
            inner: interpreter,
            __marker: PhantomData,
        })
    }

    /// Create a net/interpreter from a buffer.
    ///
    /// `bytes`: the buffer of the model
    ///
    /// return: the created net/interpreter
    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
        let bytes = bytes.as_ref();
        let size = bytes.len();
        let interpreter =
            unsafe { mnn_sys::Interpreter_createFromBuffer(bytes.as_ptr().cast(), size) };
        ensure!(!interpreter.is_null(), ErrorKind::InterpreterError; "Failed to create interpreter", "Interpreter_createFromBuffer returned null");
        Ok(Self {
            inner: interpreter,
            __marker: PhantomData,
        })
    }

    /// Set the session mode.
    ///
    /// `mode`: the session mode
    ///
    /// **Warning:**
    /// This must be called before creating a session!
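    ///
    /// A minimal, untested sketch (the model path is a placeholder); note that the mode is set
    /// before `create_session` is called:
    /// ```rust,ignore
    /// use mnn::{Interpreter, ScheduleConfig, SessionMode};
    /// let mut interpreter = Interpreter::from_file("model.mnn").expect("failed to load model");
    /// interpreter.set_session_mode(SessionMode::Debug);
    /// let session = interpreter.create_session(ScheduleConfig::new()).expect("failed to create session");
    /// ```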
    pub fn set_session_mode(&mut self, mode: SessionMode) {
        unsafe { mnn_sys::Interpreter_setSessionMode(self.inner, mode.to_mnn_sys()) }
    }

    /// Call this function to get tensors ready.
    ///
    /// Output tensor buffers (host or deviceId) should be retrieved after resizing any input
    /// tensor.
    ///
    /// `session`: the session to be prepared
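    ///
    /// A minimal, untested sketch of the resize flow (model path, tensor name, and shape are
    /// placeholders), assuming a plain `[i32; 4]` is accepted as a tensor shape:
    /// ```rust,ignore
    /// let mut net = mnn::Interpreter::from_file("model.mnn").expect("load");
    /// let mut session = net.create_session(mnn::ScheduleConfig::new()).expect("session");
    /// // Give the (possibly dynamic) input a concrete shape, then re-plan the buffers.
    /// let mut input = unsafe { net.input_unresized::<f32>(&session, "image").expect("input") };
    /// net.resize_tensor(&mut input, [1, 3, 224, 224]);
    /// drop(input); // release the borrow on the session before resizing it
    /// net.resize_session(&mut session);
    /// ```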
    pub fn resize_session(&self, session: &mut crate::Session) {
        unsafe { mnn_sys::Interpreter_resizeSession(self.inner, session.inner) }
    }

    /// Resize the session and reallocate the buffer.
    ///
    /// `session`: the session to be prepared.
    ///
    /// # Note
    /// This calls the underlying resize with `needRealloc` set to 1, which forces a reallocation.
    pub fn resize_session_reallocate(&self, session: &mut crate::Session) {
        unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) }
    }

    /// Resize the tensor using the given shape.
    pub fn resize_tensor<T: TensorType>(&self, tensor: &mut Tensor<T>, dims: impl AsTensorShape) {
        let dims = dims.as_tensor_shape();
        let dims_len = dims.size;
        unsafe {
            mnn_sys::Interpreter_resizeTensor(
                self.inner,
                tensor.tensor,
                dims.shape.as_ptr(),
                dims_len,
            )
        }
    }

    /// Resize the tensor by NCHW dimensions:
    /// - N -> batch
    /// - C -> channel
    /// - H -> height
    /// - W -> width
    pub fn resize_tensor_by_nchw<T: TensorType>(
        &self,
        tensor: &mut Tensor<T>,
        batch: u16,
        channel: u16,
        height: u16,
        width: u16,
    ) {
        unsafe {
            mnn_sys::Interpreter_resizeTensorByNCHW(
                self.inner,
                tensor.tensor,
                batch.into(),
                channel.into(),
                height.into(),
                width.into(),
            )
        }
    }

    /// Create a session with a session config. The session will be managed in the net/interpreter.
    ///
    /// `schedule` : the config of the session
    ///
    /// return: the created session
    pub fn create_session(
        &mut self,
        schedule: crate::ScheduleConfig,
    ) -> Result<crate::session::Session> {
        profile!("Creating session"; {
            let session = unsafe { mnn_sys::Interpreter_createSession(self.inner, schedule.inner) };
            assert!(!session.is_null());
            Ok(crate::session::Session {
                inner: session,
                net: self.inner,
                __session_internals: crate::SessionInternals::Single(schedule),
                __marker: PhantomData,
            })
        })
    }

    /// Release the model file buffer.
    /// # Safety
    /// This function is marked unsafe since it's not clear what the safety guarantees are right
    /// now. In a simple test it caused a segfault, so it's marked unsafe.
    pub unsafe fn release_model(&mut self) {
        unsafe { mnn_sys::Interpreter_releaseModel(self.inner) }
    }

    /// Create a multi-path session with schedule configs. The created session will be managed in the net/interpreter.
    ///
    /// `schedule` : the configs of the session
    ///
    /// return: the created session
    pub fn create_multipath_session(
        &mut self,
        schedule: impl IntoIterator<Item = ScheduleConfig>,
    ) -> Result<crate::session::Session> {
        profile!("Creating multipath session"; {
            let schedules: crate::ScheduleConfigs = schedule.into_iter().collect();
            let sc: &[_] = schedules.inner.as_ref();
            let session = unsafe { mnn_sys::Interpreter_createMultiPathSession(self.inner, sc.as_ptr(), sc.len()) };
            assert!(!session.is_null());
            Ok(crate::session::Session {
                inner: session,
                net: self.inner,
                __session_internals: crate::SessionInternals::MultiSession(schedules),
                __marker: PhantomData,
            })
        })
    }

    /// Print all input and output tensors info.
    pub fn model_print_io(path: impl AsRef<Path>) -> Result<()> {
        let path = path.as_ref();
        crate::ensure!(path.exists(), ErrorKind::IOError);
        let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
        let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?;
        unsafe { mnn_sys::modelPrintIO(c_path.as_ptr()) }
        Ok(())
    }

    /// Get all input tensors of the session.
    ///
    /// `session`: the session to get the input tensors from
    ///
    /// return: list of input tensors
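    ///
    /// A minimal, untested sketch of iterating over all inputs, given a `net: Interpreter`
    /// and a `session` created from it (the accessor names on the list items follow the rest
    /// of this crate and may differ):
    /// ```rust,ignore
    /// for tensor_info in net.inputs(&session).iter() {
    ///     let raw = tensor_info.raw_tensor();
    ///     // Inspect or fill `raw` here.
    /// }
    /// ```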
    pub fn inputs<'i>(&self, session: &'i crate::Session) -> TensorList<'i> {
        let inputs = unsafe { mnn_sys::Interpreter_getSessionInputAll(self.inner, session.inner) };
        TensorList::from_ptr(inputs)
    }

    /// Get the input tensor of the session by name.
    ///
    /// `session`: the session to get input tensor from
    ///
    /// `name`: the name of the input tensor
    ///
    /// return: the input tensor
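    ///
    /// A minimal, untested sketch, given a `net: Interpreter` and a `session` (the tensor name
    /// is a placeholder); it assumes the input already has a fully known, non-dynamic shape:
    /// ```rust,ignore
    /// let input = net.input::<f32>(&session, "image").expect("input tensor");
    /// println!("input shape: {:?}", input.shape());
    /// ```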
    pub fn input<'s, H: HalideType>(
        &self,
        session: &'s crate::Session,
        name: impl AsRef<str>,
    ) -> Result<Tensor<RefMut<'s, Device<H>>>> {
        let name = name.as_ref();
        let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
        let input = unsafe {
            mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr())
        };
        ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
        let tensor = unsafe { Tensor::from_ptr(input) };
        let shape = tensor.shape();
        ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError);
        ensure!(
            tensor.is_type_of::<H>(),
            ErrorKind::HalideTypeMismatch {
                got: std::any::type_name::<H>(),
            };
            format!("Input tensor \"{name}\" is not of type {}", std::any::type_name::<H>())
        );
        Ok(tensor)
    }

    /// Get the raw input tensor of a session by name.
    pub fn raw_input<'s>(
        &self,
        session: &'s crate::Session,
        name: impl AsRef<str>,
    ) -> Result<RawTensor<'s>> {
        let name = name.as_ref();
        let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
        let input = unsafe {
            mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr())
        };
        ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
        Ok(RawTensor::from_ptr(input))
    }

    /// # Safety
    /// **Warning:** We still don't know the safety guarantees of this function, so it's marked unsafe.
    pub unsafe fn input_unresized<'s, H: HalideType>(
        &self,
        session: &'s crate::Session,
        name: impl AsRef<str>,
    ) -> Result<Tensor<RefMut<'s, Device<H>>>> {
        let name = name.as_ref();
        let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
        let input = unsafe {
            mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr())
        };
        ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
        let tensor = unsafe { Tensor::from_ptr(input) };
        ensure!(
            tensor.is_type_of::<H>(),
            ErrorKind::HalideTypeMismatch {
                got: std::any::type_name::<H>(),
            }
        );
        Ok(tensor)
    }

    /// # Safety
    /// Very **unsafe** since it checks neither the type nor the shape of the tensor.
    ///
    /// **Panics** if the name is not ascii.
    ///
    /// **Undefined behavior** if the tensor is not of type `H`.
    pub unsafe fn input_unchecked<'s, H: HalideType>(
        &self,
        session: &'s crate::Session,
        name: impl AsRef<str>,
    ) -> Tensor<RefMut<'s, Device<H>>> {
        let name = name.as_ref();
        let c_name = std::ffi::CString::new(name).expect("Input tensor name is not ascii");
        unsafe {
            let input =
                mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr());
            Tensor::from_ptr(input)
        }
    }

    /// Get the output tensor of a session by name.
    ///
    /// `session` : the session to get the output tensor from
    ///
    /// `name` : the name of the output tensor
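    ///
    /// A minimal, untested sketch (the tensor name is a placeholder); reading the data back
    /// typically goes through a host copy of the device tensor, which is provided elsewhere
    /// in this crate:
    /// ```rust,ignore
    /// net.run_session(&session).expect("inference failed");
    /// let output = net.output::<f32>(&session, "logits").expect("output tensor");
    /// println!("output shape: {:?}", output.shape());
    /// ```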
    pub fn output<'s, H: HalideType>(
        &self,
        session: &'s crate::Session,
        name: impl AsRef<str>,
    ) -> Result<Tensor<Ref<'s, Device<H>>>> {
        let name = name.as_ref();
        let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
        let output = unsafe {
            mnn_sys::Interpreter_getSessionOutput(self.inner, session.inner, c_name.as_ptr())
        };
        ensure!(!output.is_null(), ErrorKind::IOError; format!("Output tensor \"{name}\" not found"));
        let tensor = unsafe { Tensor::from_ptr(output) };
        let shape = tensor.shape();
        ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError);
        ensure!(
            tensor.is_type_of::<H>(),
            ErrorKind::HalideTypeMismatch {
                got: std::any::type_name::<H>(),
            }
        );
        Ok(tensor)
    }

    /// Get the raw output tensor of a session by name.
    pub fn raw_output<'s>(
        &self,
        session: &'s crate::Session,
        name: impl AsRef<str>,
    ) -> Result<RawTensor<'s>> {
        let name = name.as_ref();
        let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
        let output = unsafe {
            mnn_sys::Interpreter_getSessionOutput(self.inner, session.inner, c_name.as_ptr())
        };
        ensure!(!output.is_null(), ErrorKind::IOError; format!("Output tensor \"{name}\" not found"));
        Ok(RawTensor::from_ptr(output))
    }

    /// Run a session.
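    ///
    /// A minimal, untested sketch of a full run (the model path is a placeholder); input data
    /// is assumed to have been written to the input tensors beforehand:
    /// ```rust,ignore
    /// let mut net = mnn::Interpreter::from_file("model.mnn").expect("load");
    /// let session = net.create_session(mnn::ScheduleConfig::new()).expect("session");
    /// // ... fill the input tensors here ...
    /// net.run_session(&session).expect("inference failed");
    /// net.wait(&session); // block until all outputs are ready
    /// ```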
    pub fn run_session(&mut self, session: &crate::session::Session) -> Result<()> {
        profile!("Running session"; {
            let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) };
            ensure!(
                ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR,
                ErrorKind::InternalError(ret)
            );
            Ok(())
        })
    }

    /// Run a session with callbacks.
    ///
    /// `session` : the session to run
    ///
    /// `before` : a callback before each op. Return true to run the op; return false to skip it.
    ///
    /// `end` : a callback after each op. Return true to continue running; return false to interrupt the session.
    ///
    /// `sync` : whether to synchronously wait for the execution to finish.
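    ///
    /// A minimal, untested sketch that logs every operator as it executes, given a
    /// `net: Interpreter` and a `session` created from it:
    /// ```rust,ignore
    /// net.run_session_with_callback(
    ///     &session,
    ///     |_inputs, op| {
    ///         println!("running {:?} ({:?})", op.name(), op.type_name());
    ///         true // run this op
    ///     },
    ///     |_outputs, _op| true, // keep going after each op
    ///     true,
    /// ).expect("inference failed");
    /// ```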
    pub fn run_session_with_callback(
        &mut self,
        session: &crate::session::Session,
        before: impl Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
        end: impl Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
        sync: bool,
    ) -> Result<()> {
        let sync = sync as libc::c_int;
        let before = TensorCallback::from(before).into_ptr();
        let end = TensorCallback::from(end).into_ptr();
        let ret = unsafe {
            mnn_sys::Interpreter_runSessionWithCallBackInfo(
                self.inner,
                session.inner,
                before,
                end,
                sync,
            )
        };
        ensure!(
            ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR,
            ErrorKind::InternalError(ret)
        );
        Ok(())
    }

    /// Get all output tensors of a session.
    pub fn outputs<'o>(&self, session: &'o crate::session::Session) -> TensorList<'o> {
        let outputs =
            unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) };
        TensorList::from_ptr(outputs)
    }

    /// If the cache file exists, try to load the cache from it.
    /// After createSession, try to save the cache back to the file.
    ///
    /// `path` : the file path used to save or load the cache
    ///
    /// `key_size` : the size of the key
    ///
    /// # Note
    /// This API should be called before creating a session.
    ///
    /// The key is deprecated and kept for future use!
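    ///
    /// A minimal, untested sketch (paths and key size are placeholders); `set_cache_file` runs
    /// before the session is created:
    /// ```rust,ignore
    /// let mut net = mnn::Interpreter::from_file("model.mnn").expect("load");
    /// net.set_cache_file("model.cache", 128).expect("set cache file");
    /// let mut session = net.create_session(mnn::ScheduleConfig::new()).expect("session");
    /// net.update_cache_file(&mut session).expect("update cache");
    /// ```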
    pub fn set_cache_file(&mut self, path: impl AsRef<Path>, key_size: usize) -> Result<()> {
        let path = path.as_ref();
        let path = dunce::simplified(path);
        let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
        let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?;
        unsafe { mnn_sys::Interpreter_setCacheFile(self.inner, c_path.as_ptr(), key_size) }
        Ok(())
    }

    /// Update the cache file.
    pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> {
        MNNError::from_error_code(unsafe {
            mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner)
        });
        Ok(())
    }

    /// Wait for all output tensors to be ready after computation.
    pub fn wait(&self, session: &crate::session::Session) {
        self.outputs(session).iter().for_each(|tinfo| {
            tinfo
                .raw_tensor()
                .wait(mnn_sys::MapType::MAP_TENSOR_READ, true);
        });
    }

    /// Get the memory usage of a session in MB.
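    ///
    /// A minimal, untested sketch that queries both memory usage and FLOPs for a session,
    /// given a `net: Interpreter` and a `session` created from it:
    /// ```rust,ignore
    /// let mem_mb = net.memory(&session).expect("memory info");
    /// let mflops = net.flops(&session).expect("flops info");
    /// println!("session uses {mem_mb} MB and needs {mflops} MFLOPs");
    /// ```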
    pub fn memory(&self, session: &crate::session::Session) -> Result<f32> {
        let mut memory = 0f32;
        let memory_ptr = &mut memory as *mut f32;
        let ret = unsafe {
            mnn_sys::Interpreter_getSessionInfo(
                self.inner,
                session.inner,
                mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_MEMORY as _,
                memory_ptr.cast(),
            )
        };
        ensure!(
            ret == 1,
            ErrorKind::InterpreterError;
            "Failed to get memory usage"
        );
        Ok(memory)
    }

    /// Get the number of floating point operations needed by a session, in millions (MFLOPs).
    pub fn flops(&self, session: &crate::Session) -> Result<f32> {
        let mut flop = 0.0f32;
        let flop_ptr = &mut flop as *mut f32;
        let ret = unsafe {
            mnn_sys::Interpreter_getSessionInfo(
                self.inner,
                session.inner,
                mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_FLOPS as _,
                flop_ptr.cast::<libc::c_void>(),
            )
        };
        ensure!(
            ret == 1,
            ErrorKind::InterpreterError;
            "Failed to get flops"
        );
        Ok(flop)
    }

    /// Get the resize status of a session.
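    ///
    /// A minimal, untested sketch that inspects whether the session still needs work before it
    /// can run, given a `net: Interpreter` and a `session` created from it:
    /// ```rust,ignore
    /// match net.resize_status(&session).expect("resize status") {
    ///     mnn::ResizeStatus::None => println!("ready to run"),
    ///     mnn::ResizeStatus::NeedMalloc => println!("buffers still need to be allocated"),
    ///     mnn::ResizeStatus::NeedResize => println!("session still needs a resize"),
    /// }
    /// ```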
    pub fn resize_status(&self, session: &crate::Session) -> Result<ResizeStatus> {
        let mut resize_status = 0i32;
        let ptr = &mut resize_status as *mut i32;
        let ret = unsafe {
            mnn_sys::Interpreter_getSessionInfo(
                self.inner,
                session.inner,
                mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_RESIZE_STATUS as _,
                ptr.cast(),
            )
        };
        ensure!(
            ret == 1,
            ErrorKind::InterpreterError;
            "Failed to get resize status"
        );
        match resize_status {
            0 => Ok(ResizeStatus::None),
            1 => Ok(ResizeStatus::NeedMalloc),
            2 => Ok(ResizeStatus::NeedResize),
            _ => Err(error!(ErrorKind::InterpreterError)),
        }
    }
}

/// The status of the resize operation
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(C)]
pub enum ResizeStatus {
    /// No resize needed
    None = 0,
    /// Need to malloc memory
    NeedMalloc = 1,
    /// Need to resize memory
    NeedResize = 2,
}

#[unsafe(no_mangle)]
extern "C" fn rust_closure_callback_runner_op(
    f: *mut libc::c_void,
    tensors: *const *mut mnn_sys::Tensor,
    tensor_count: usize,
    op: *mut libc::c_void,
) -> libc::c_int {
    let tensors = unsafe { std::slice::from_raw_parts(tensors.cast(), tensor_count) };
    let f: TensorCallback = TensorCallback::from_ptr(f);
    let op = OperatorInfo {
        inner: op.cast(),
        __marker: PhantomData,
    };
    let ret = f(tensors, op) as libc::c_int;
    // The callback is still owned by the caller through the raw pointer; forget our temporary
    // `TensorCallback` so its `Arc` is not dropped here.
    core::mem::forget(f);
    ret
}

/// A struct that holds information about an operator
#[repr(transparent)]
pub struct OperatorInfo<'op> {
    pub(crate) inner: *mut libc::c_void,
    pub(crate) __marker: PhantomData<&'op ()>,
}

impl core::fmt::Debug for OperatorInfo<'_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("OperatorInfo")
            .field("name", &self.name())
            .field("type", &self.type_name())
            .field("flops", &self.flops())
            .finish()
    }
}

impl OperatorInfo<'_> {
    /// Get the name of the operator
    pub fn name(&self) -> &CStr {
        unsafe { CStr::from_ptr(mnn_sys::OperatorInfo_name(self.inner)) }
    }

    /// Get the type of the operator
    pub fn type_name(&self) -> &CStr {
        unsafe { CStr::from_ptr(mnn_sys::OperatorInfo_type(self.inner)) }
    }

    /// Get the number of flops of the operator
    pub fn flops(&self) -> f32 {
        unsafe { mnn_sys::OperatorInfo_flops(self.inner) }
    }
}

#[test]
#[ignore = "This test doesn't work in CI"]
fn test_run_session_with_callback_info_api() {
    let file = Path::new("tests/assets/realesr.mnn")
        .canonicalize()
        .unwrap();
    let mut interpreter = Interpreter::from_file(&file).unwrap();
    let session = interpreter.create_session(ScheduleConfig::new()).unwrap();
    interpreter
        .run_session_with_callback(
            &session,
            TensorCallback::identity(),
            TensorCallback::identity(),
            true,
        )
        .unwrap();
}

#[test]
#[ignore = "This test doesn't work in CI"]
fn check_whether_sync_actually_works() {
    let file = Path::new("tests/assets/realesr.mnn")
        .canonicalize()
        .unwrap();
    let mut interpreter = Interpreter::from_file(&file).unwrap();
    let session = interpreter.create_session(ScheduleConfig::new()).unwrap();
    let time = std::time::Instant::now();
    interpreter
        .run_session_with_callback(
            &session,
            TensorCallback::identity(),
            TensorCallback::identity(),
            false,
        )
        .unwrap();
    let time = time.elapsed();
    let time2 = std::time::Instant::now();
    interpreter
        .run_session_with_callback(
            &session,
            TensorCallback::identity(),
            TensorCallback::identity(),
            true,
        )
        .unwrap();
    let time2 = time2.elapsed();
    assert!((time - time2) > std::time::Duration::from_millis(50));
}

#[test]
#[ignore = "Fails on CI"]
fn try_to_drop_interpreter_before_session() {
    let file = Path::new("tests/assets/realesr.mnn")
        .canonicalize()
        .unwrap();
    let mut interpreter = Interpreter::from_file(&file).unwrap();
    let session = interpreter.create_session(ScheduleConfig::new()).unwrap();
    drop(interpreter);
    drop(session);
}