1use crate::tensor::list::TensorList;
3use std::{ffi::CStr, path::Path, sync::Arc};
4
5use crate::{
6 AsTensorShape, Device, RawTensor, Ref, RefMut, ScheduleConfig, Tensor, TensorType, prelude::*,
7};
8use mnn_sys::HalideType;
9
10pub(crate) type TensorCallbackT = Box<dyn Fn(&[RawTensor], OperatorInfo) -> bool>;
11
12#[repr(transparent)]
13pub(crate) struct TensorCallback {
14 inner: Arc<TensorCallbackT>,
15}
16
17impl Default for TensorCallback {
18 fn default() -> Self {
19 Self {
20 inner: Arc::new(Box::new(|_, _| true)),
21 }
22 }
23}
24
25impl TensorCallback {
26 pub(crate) fn from_ptr(f: *mut libc::c_void) -> Self {
27 debug_assert!(!f.is_null());
28 unsafe {
29 Self {
30 inner: Arc::from_raw(f.cast()),
31 }
32 }
33 }
34
35 pub(crate) fn into_ptr(self) -> *mut libc::c_void {
36 Arc::into_raw(self.inner) as *mut libc::c_void
37 }
38
39 #[cfg(test)]
40 pub(crate) fn identity() -> impl Fn(&[RawTensor], OperatorInfo) -> bool {
41 |_, _| true
42 }
43}
44
45impl<F> From<F> for TensorCallback
46where
47 F: Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
48{
49 fn from(f: F) -> Self {
50 Self {
51 inner: Arc::new(Box::new(f)),
52 }
53 }
54}
55
56impl<T> From<Option<T>> for TensorCallback
57where
58 T: Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
59{
60 fn from(f: Option<T>) -> Self {
61 match f {
62 Some(f) => Self {
63 inner: Arc::new(Box::new(f)),
64 },
65 None => Self::default(),
66 }
67 }
68}
69
70impl core::ops::Deref for TensorCallback {
71 type Target = TensorCallbackT;
72
73 fn deref(&self) -> &Self::Target {
74 &self.inner
75 }
76}
77
78#[derive(Debug, Copy, Clone)]
84#[cfg_attr(windows, repr(i32))]
85#[cfg_attr(unix, repr(u32))]
86pub enum SessionMode {
87 #[doc = "About CallBack, Default Session_Debug*/\n/** runSessionWithCallBack is allowed and can get internal op info"]
88 Debug = mnn_sys::SessionMode::Session_Debug,
89 #[doc = "runSessionWithCallBack is not valid and can't get any info of op in\nsession"]
90 Release = mnn_sys::SessionMode::Session_Release,
91 #[doc = "About input tensor, Default Session_Input_Inside*/\n/** The input tensor is alloced by session, input data after session resized"]
92 InputInside = mnn_sys::SessionMode::Session_Input_Inside,
93 #[doc = "The input tensor is alloced by user, set input data before session\nresize"]
94 InputUser = mnn_sys::SessionMode::Session_Input_User,
95 #[doc = "The output tensor depends on session, and can't be separate used"]
96 OutputInside = mnn_sys::SessionMode::Session_Output_Inside,
97 #[doc = "The output tensor can be separated from session"]
98 OutputUser = mnn_sys::SessionMode::Session_Output_User,
99 #[doc = "Try Resize Session when create Session or not, default direct:"]
100 ResizeDirect = mnn_sys::SessionMode::Session_Resize_Direct,
101 #[doc = "Try Resize Session when create Session or not, default direct:"]
102 ResizeDefer = mnn_sys::SessionMode::Session_Resize_Defer,
103 #[doc = "Determine the Execution's forward type is determine by user or auto\ndetermine"]
104 BackendFix = mnn_sys::SessionMode::Session_Backend_Fix,
105 #[doc = "Determine the Execution's forward type is determine by user or auto\ndetermine"]
106 BackendAuto = mnn_sys::SessionMode::Session_Backend_Auto,
107 #[doc = "Determine static memory whether recyle in resizeSession or just cache the\nmemory"]
108 MemoryCollect = mnn_sys::SessionMode::Session_Memory_Collect,
109 #[doc = "Determine static memory whether recyle in resizeSession or just cache the\nmemory"]
110 MemoryCache = mnn_sys::SessionMode::Session_Memory_Cache,
111 #[doc = "Determine whether use codegen function"]
112 CodegenDisable = mnn_sys::SessionMode::Session_Codegen_Disable,
113 #[doc = "Determine whether use codegen function"]
114 CodegenEnable = mnn_sys::SessionMode::Session_Codegen_Enable,
115 #[doc = "Dynamic Reisze Optimization"]
116 ResizeCheck = mnn_sys::SessionMode::Session_Resize_Check,
117 #[doc = "Dynamic Reisze Optimization"]
118 ResizeFix = mnn_sys::SessionMode::Session_Resize_Fix,
119}
120
121#[cfg(windows)]
122type SessionModeType = i32;
123#[cfg(unix)]
124type SessionModeType = u32;
125
126impl SessionMode {
127 fn to_mnn_sys(self) -> SessionModeType {
128 self as SessionModeType
129 }
130}
131
132#[repr(transparent)]
134#[derive(Debug)]
135pub struct Interpreter {
136 pub(crate) inner: *mut mnn_sys::Interpreter,
137 pub(crate) __marker: PhantomData<()>,
138}
139
140unsafe impl Send for Interpreter {}
141
142impl Drop for Interpreter {
143 fn drop(&mut self) {
144 unsafe { mnn_sys::Interpreter_destroy(self.inner) }
145 }
146}
147
148impl Interpreter {
149 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
155 let path = path.as_ref();
156 ensure!(path.exists(), ErrorKind::IOError; path.to_string_lossy().to_string(), "File not found");
157 let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
158 let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?;
159 let interpreter = unsafe { mnn_sys::Interpreter_createFromFile(c_path.as_ptr()) };
160 ensure!(!interpreter.is_null(), ErrorKind::InterpreterError; "Failed to create interpreter", "Interpreter_createFromFile returned null");
161 Ok(Self {
162 inner: interpreter,
163 __marker: PhantomData,
164 })
165 }
166
167 pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
173 let bytes = bytes.as_ref();
174 let size = bytes.len();
175 let interpreter =
176 unsafe { mnn_sys::Interpreter_createFromBuffer(bytes.as_ptr().cast(), size) };
177 ensure!(!interpreter.is_null(), ErrorKind::InterpreterError; "Failed to create interpreter", "Interpreter_createFromBuffer returned null");
178 Ok(Self {
179 inner: interpreter,
180 __marker: PhantomData,
181 })
182 }
183
184 pub fn set_session_mode(&mut self, mode: SessionMode) {
191 unsafe { mnn_sys::Interpreter_setSessionMode(self.inner, mode.to_mnn_sys()) }
192 }
193
194 pub fn resize_session(&self, session: &mut crate::Session) {
200 unsafe { mnn_sys::Interpreter_resizeSession(self.inner, session.inner) }
201 }
202
203 pub fn resize_session_reallocate(&self, session: &mut crate::Session) {
210 unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) }
211 }
212
213 pub fn resize_tensor<T: TensorType>(&self, tensor: &mut Tensor<T>, dims: impl AsTensorShape) {
215 let dims = dims.as_tensor_shape();
216 let dims_len = dims.size;
217 unsafe {
218 mnn_sys::Interpreter_resizeTensor(
219 self.inner,
220 tensor.tensor,
221 dims.shape.as_ptr(),
222 dims_len,
223 )
224 }
225 }
226
227 pub fn resize_tensor_by_nchw<T: TensorType>(
233 &self,
234 tensor: &mut Tensor<T>,
235 batch: u16,
236 channel: u16,
237 height: u16,
238 width: u16,
239 ) {
240 unsafe {
241 mnn_sys::Interpreter_resizeTensorByNCHW(
242 self.inner,
243 tensor.tensor,
244 batch.into(),
245 channel.into(),
246 height.into(),
247 width.into(),
248 )
249 }
250 }
251
252 pub fn create_session(
258 &mut self,
259 schedule: crate::ScheduleConfig,
260 ) -> Result<crate::session::Session> {
261 profile!("Creating session"; {
262 let session = unsafe { mnn_sys::Interpreter_createSession(self.inner, schedule.inner) };
263 assert!(!session.is_null());
264 Ok(crate::session::Session {
265 inner: session,
266 net: self.inner,
267 __session_internals: crate::SessionInternals::Single(schedule),
268 __marker: PhantomData,
269 })
270 })
271 }
272
273 pub unsafe fn release_model(&mut self) {
278 unsafe { mnn_sys::Interpreter_releaseModel(self.inner) }
279 }
280
281 pub fn create_multipath_session(
287 &mut self,
288 schedule: impl IntoIterator<Item = ScheduleConfig>,
289 ) -> Result<crate::session::Session> {
290 profile!("Creating multipath session"; {
291 let schedules: crate::ScheduleConfigs = schedule.into_iter().collect();
292 let sc: &[_] = schedules.inner.as_ref();
293 let session = unsafe { mnn_sys::Interpreter_createMultiPathSession(self.inner, sc.as_ptr(), sc.len()) };
294 assert!(!session.is_null());
295 Ok(crate::session::Session {
296 inner: session,
297 net: self.inner,
298 __session_internals: crate::SessionInternals::MultiSession(schedules),
299 __marker: PhantomData,
300 })
301 })
302 }
303
304 pub fn model_print_io(path: impl AsRef<Path>) -> Result<()> {
306 let path = path.as_ref();
307 crate::ensure!(path.exists(), ErrorKind::IOError);
308 let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
309 let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?;
310 unsafe { mnn_sys::modelPrintIO(c_path.as_ptr()) }
311 Ok(())
312 }
313
314 pub fn inputs<'i>(&self, session: &'i crate::Session) -> TensorList<'i> {
320 let inputs = unsafe { mnn_sys::Interpreter_getSessionInputAll(self.inner, session.inner) };
321 TensorList::from_ptr(inputs)
322 }
323
324 pub fn input<'s, H: HalideType>(
332 &self,
333 session: &'s crate::Session,
334 name: impl AsRef<str>,
335 ) -> Result<Tensor<RefMut<'s, Device<H>>>> {
336 let name = name.as_ref();
337 let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
338 let input = unsafe {
339 mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr())
340 };
341 ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
342 let tensor = unsafe { Tensor::from_ptr(input) };
343 let shape = tensor.shape();
344 ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError);
345 ensure!(
346 tensor.is_type_of::<H>(),
347 ErrorKind::HalideTypeMismatch {
348 got: std::any::type_name::<H>(),
349 };
350 format!("Input tensor \"{name}\" is not of type {}", std::any::type_name::<H>())
351 );
352 Ok(tensor)
353 }
354
355 pub fn raw_input<'s>(
357 &self,
358 session: &'s crate::Session,
359 name: impl AsRef<str>,
360 ) -> Result<RawTensor<'s>> {
361 let name = name.as_ref();
362 let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
363 let input = unsafe {
364 mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr())
365 };
366 ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
367 Ok(RawTensor::from_ptr(input))
368 }
369
370 pub unsafe fn input_unresized<'s, H: HalideType>(
373 &self,
374 session: &'s crate::Session,
375 name: impl AsRef<str>,
376 ) -> Result<Tensor<RefMut<'s, Device<H>>>> {
377 let name = name.as_ref();
378 let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
379 let input = unsafe {
380 mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr())
381 };
382 ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
383 let tensor = unsafe { Tensor::from_ptr(input) };
384 ensure!(
385 tensor.is_type_of::<H>(),
386 ErrorKind::HalideTypeMismatch {
387 got: std::any::type_name::<H>(),
388 }
389 );
390 Ok(tensor)
391 }
392
393 pub unsafe fn input_unchecked<'s, H: HalideType>(
400 &self,
401 session: &'s crate::Session,
402 name: impl AsRef<str>,
403 ) -> Tensor<RefMut<'s, Device<H>>> {
404 let name = name.as_ref();
405 let c_name = std::ffi::CString::new(name).expect("Input tensor name is not ascii");
406 unsafe {
407 let input =
408 mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr());
409 Tensor::from_ptr(input)
410 }
411 }
412
413 pub fn output<'s, H: HalideType>(
419 &self,
420 session: &'s crate::Session,
421 name: impl AsRef<str>,
422 ) -> Result<Tensor<Ref<'s, Device<H>>>> {
423 let name = name.as_ref();
424 let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
425 let output = unsafe {
426 mnn_sys::Interpreter_getSessionOutput(self.inner, session.inner, c_name.as_ptr())
427 };
428 ensure!(!output.is_null(), ErrorKind::IOError;format!("Output tensor \"{name}\" not found"));
429 let tensor = unsafe { Tensor::from_ptr(output) };
430 let shape = tensor.shape();
431 ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError);
432 ensure!(
433 tensor.is_type_of::<H>(),
434 ErrorKind::HalideTypeMismatch {
435 got: std::any::type_name::<H>(),
436 }
437 );
438 Ok(tensor)
439 }
440
441 pub fn raw_output<'s>(
443 &self,
444 session: &'s crate::Session,
445 name: impl AsRef<str>,
446 ) -> Result<RawTensor<'s>> {
447 let name = name.as_ref();
448 let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
449 let output = unsafe {
450 mnn_sys::Interpreter_getSessionOutput(self.inner, session.inner, c_name.as_ptr())
451 };
452 ensure!(!output.is_null(), ErrorKind::IOError;format!("Output tensor \"{name}\" not found"));
453 Ok(RawTensor::from_ptr(output))
454 }
455
456 pub fn run_session(&mut self, session: &crate::session::Session) -> Result<()> {
458 profile!("Running session"; {
459 let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) };
460 ensure!(
461 ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR,
462 ErrorKind::InternalError(ret)
463 );
464 Ok(())
465 })
466 }
467
468 pub fn run_session_with_callback(
478 &mut self,
479 session: &crate::session::Session,
480 before: impl Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
481 end: impl Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
482 sync: bool,
483 ) -> Result<()> {
484 let sync = sync as libc::c_int;
485 let before = TensorCallback::from(before).into_ptr();
486 let end = TensorCallback::from(end).into_ptr();
487 let ret = unsafe {
488 mnn_sys::Interpreter_runSessionWithCallBackInfo(
489 self.inner,
490 session.inner,
491 before,
492 end,
493 sync,
494 )
495 };
496 ensure!(
497 ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR,
498 ErrorKind::InternalError(ret)
499 );
500 Ok(())
501 }
502
503 pub fn outputs<'o>(&self, session: &'o crate::session::Session) -> TensorList<'o> {
505 let outputs =
506 unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) };
507 TensorList::from_ptr(outputs)
508 }
509
510 pub fn set_cache_file(&mut self, path: impl AsRef<Path>, key_size: usize) -> Result<()> {
522 let path = path.as_ref();
523 let path = dunce::simplified(path);
524 let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
525 let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?;
526 unsafe { mnn_sys::Interpreter_setCacheFile(self.inner, c_path.as_ptr(), key_size) }
527 Ok(())
528 }
529
530 pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> {
532 MNNError::from_error_code(unsafe {
533 mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner)
534 });
535 Ok(())
536 }
537
538 pub fn wait(&self, session: &crate::session::Session) {
540 self.outputs(session).iter().for_each(|tinfo| {
541 tinfo
542 .raw_tensor()
543 .wait(mnn_sys::MapType::MAP_TENSOR_READ, true);
544 });
545 }
546
547 pub fn memory(&self, session: &crate::session::Session) -> Result<f32> {
549 let mut memory = 0f32;
550 let memory_ptr = &mut memory as *mut f32;
551 let ret = unsafe {
552 mnn_sys::Interpreter_getSessionInfo(
553 self.inner,
554 session.inner,
555 mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_MEMORY as _,
556 memory_ptr.cast(),
557 )
558 };
559 ensure!(
560 ret == 1,
561 ErrorKind::InterpreterError;
562 "Failed to get memory usage"
563 );
564 Ok(memory)
565 }
566
567 pub fn flops(&self, session: &crate::Session) -> Result<f32> {
569 let mut flop = 0.0f32;
570 let flop_ptr = &mut flop as *mut f32;
571 let ret = unsafe {
572 mnn_sys::Interpreter_getSessionInfo(
573 self.inner,
574 session.inner,
575 mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_FLOPS as _,
576 flop_ptr.cast::<libc::c_void>(),
577 )
578 };
579 ensure!(
580 ret == 1,
581 ErrorKind::InterpreterError;
582 "Failed to get flops"
583 );
584 Ok(flop)
585 }
586
587 pub fn resize_status(&self, session: &crate::Session) -> Result<ResizeStatus> {
589 let mut resize_status = 0i32;
590 let ptr = &mut resize_status as *mut i32;
591 let ret = unsafe {
592 mnn_sys::Interpreter_getSessionInfo(
593 self.inner,
594 session.inner,
595 mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_RESIZE_STATUS as _,
596 ptr.cast(),
597 )
598 };
599 ensure!(
600 ret == 1,
601 ErrorKind::InterpreterError;
602 "Failed to get resize status"
603 );
604 match resize_status {
605 0 => Ok(ResizeStatus::None),
606 1 => Ok(ResizeStatus::NeedMalloc),
607 2 => Ok(ResizeStatus::NeedResize),
608 _ => Err(error!(ErrorKind::InterpreterError)),
609 }
610 }
611}
612
/// Resize state of a session, as returned by `Interpreter::resize_status`.
///
/// The discriminants match the raw codes of MNN's `RESIZE_STATUS` session-info query.
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ResizeStatus {
    /// Code `0`: nothing pending. Presumably the session is ready to run.
    None,
    /// Code `1`: memory needs to be (re)allocated.
    NeedMalloc,
    /// Code `2`: the session needs to be resized.
    NeedResize,
}
624
625#[unsafe(no_mangle)]
626extern "C" fn rust_closure_callback_runner_op(
627 f: *mut libc::c_void,
628 tensors: *const *mut mnn_sys::Tensor,
629 tensor_count: usize,
630 op: *mut libc::c_void,
631) -> libc::c_int {
632 let tensors = unsafe { std::slice::from_raw_parts(tensors.cast(), tensor_count) };
633 let f: TensorCallback = TensorCallback::from_ptr(f);
634 let op = OperatorInfo {
635 inner: op.cast(),
636 __marker: PhantomData,
637 };
638 let ret = f(tensors, op) as libc::c_int;
639
640 core::mem::forget(f);
641 ret
642}
643
644#[repr(transparent)]
646pub struct OperatorInfo<'op> {
647 pub(crate) inner: *mut libc::c_void,
648 pub(crate) __marker: PhantomData<&'op ()>,
649}
650
651impl core::fmt::Debug for OperatorInfo<'_> {
652 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
653 f.debug_struct("OperatorInfo")
654 .field("name", &self.name())
655 .field("type", &self.type_name())
656 .field("flops", &self.flops())
657 .finish()
658 }
659}
660
661impl OperatorInfo<'_> {
662 pub fn name(&self) -> &CStr {
664 unsafe { CStr::from_ptr(mnn_sys::OperatorInfo_name(self.inner)) }
665 }
666
667 pub fn type_name(&self) -> &CStr {
669 unsafe { CStr::from_ptr(mnn_sys::OperatorInfo_type(self.inner)) }
670 }
671
672 pub fn flops(&self) -> f32 {
674 unsafe { mnn_sys::OperatorInfo_flops(self.inner) }
675 }
676}
677
#[test]
#[ignore = "This test doesn't work in CI"]
fn test_run_session_with_callback_info_api() {
    // Run the bundled model once with no-op callbacks on both hooks, synchronously.
    let model_path = Path::new("tests/assets/realesr.mnn")
        .canonicalize()
        .unwrap();
    let mut net = Interpreter::from_file(&model_path).unwrap();
    let session = net.create_session(ScheduleConfig::new()).unwrap();
    let (before, after) = (TensorCallback::identity(), TensorCallback::identity());
    net.run_session_with_callback(&session, before, after, true)
        .unwrap();
}
695
#[test]
#[ignore = "This test doesn't work in CI"]
fn check_whether_sync_actually_works() {
    // Time one asynchronous and one synchronous callback run of the same session, and expect
    // the two wall-clock durations to differ noticeably.
    let file = Path::new("tests/assets/realesr.mnn")
        .canonicalize()
        .unwrap();
    let mut interpreter = Interpreter::from_file(&file).unwrap();
    let session = interpreter.create_session(ScheduleConfig::new()).unwrap();
    let time = std::time::Instant::now();
    interpreter
        .run_session_with_callback(
            &session,
            TensorCallback::identity(),
            TensorCallback::identity(),
            false,
        )
        .unwrap();
    let time = time.elapsed();
    let time2 = std::time::Instant::now();
    interpreter
        .run_session_with_callback(
            &session,
            TensorCallback::identity(),
            TensorCallback::identity(),
            true,
        )
        .unwrap();
    let time2 = time2.elapsed();
    // `Duration - Duration` panics on underflow, so the old `time - time2` crashed whenever
    // the synchronous run took longer. Compare the absolute difference instead.
    assert!(time.abs_diff(time2) > std::time::Duration::from_millis(50));
}
726
#[test]
#[ignore = "Fails on CI"]
fn try_to_drop_interpreter_before_session() {
    let model_path = Path::new("tests/assets/realesr.mnn")
        .canonicalize()
        .unwrap();
    let mut net = Interpreter::from_file(&model_path).unwrap();
    let session = net.create_session(ScheduleConfig::new()).unwrap();
    // Destroy the interpreter while a session that stores its raw pointer (`net`) is still
    // alive, then the session itself, to exercise this drop order.
    drop(net);
    drop(session);
}
737}