1use mnn_sys::*;
2use std::{ffi::CString, mem::ManuallyDrop};
3
4use crate::{BackendConfig, prelude::*};
5
/// Backend MNN should use for forward computation.
///
/// `All`, `Auto` and `CPU` are always available; every other variant only exists
/// when the cargo feature of the same (lowercase) name is enabled. The default is
/// [`ForwardType::Auto`].
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ForwardType {
    /// Maps to `MNN_FORWARD_ALL` (string name `"all"`).
    All,
    /// Maps to `MNN_FORWARD_AUTO` (string name `"auto"`); the default variant.
    #[default]
    Auto,
    /// Maps to `MNN_FORWARD_CPU` (string name `"cpu"`).
    CPU,
    /// Maps to `MNN_FORWARD_METAL`; requires the `metal` feature.
    #[cfg(feature = "metal")]
    Metal,
    /// Maps to `MNN_FORWARD_OPENCL`; requires the `opencl` feature.
    #[cfg(feature = "opencl")]
    OpenCL,
    /// Maps to `MNN_FORWARD_OPENGL`; requires the `opengl` feature.
    ///
    /// The conversion, naming and parsing code for this type already handles this
    /// variant under the same feature gate, so it must exist for that feature to build.
    #[cfg(feature = "opengl")]
    OpenGL,
    /// Maps to `MNN_FORWARD_VULKAN`; requires the `vulkan` feature.
    #[cfg(feature = "vulkan")]
    Vulkan,
    /// Maps to `MNN_FORWARD_NN`; requires the `coreml` feature.
    #[cfg(feature = "coreml")]
    CoreML,
}
59
60impl ForwardType {
61    fn to_mnn_sys(self) -> MNNForwardType {
63        match self {
64            ForwardType::Auto => MNNForwardType::MNN_FORWARD_AUTO,
65            ForwardType::All => MNNForwardType::MNN_FORWARD_ALL,
66            ForwardType::CPU => MNNForwardType::MNN_FORWARD_CPU,
67            #[cfg(feature = "metal")]
68            ForwardType::Metal => MNNForwardType::MNN_FORWARD_METAL,
69            #[cfg(feature = "opencl")]
70            ForwardType::OpenCL => MNNForwardType::MNN_FORWARD_OPENCL,
71            #[cfg(feature = "opengl")]
72            ForwardType::OpenGL => MNNForwardType::MNN_FORWARD_OPENGL,
73            #[cfg(feature = "vulkan")]
74            ForwardType::Vulkan => MNNForwardType::MNN_FORWARD_VULKAN,
75            #[cfg(feature = "coreml")]
76            ForwardType::CoreML => MNNForwardType::MNN_FORWARD_NN,
77        }
78    }
79
80    fn from_mnn_sys(mode: MNNForwardType) -> Self {
81        match mode {
82            MNNForwardType::MNN_FORWARD_AUTO => ForwardType::Auto,
83            MNNForwardType::MNN_FORWARD_ALL => ForwardType::All,
84            MNNForwardType::MNN_FORWARD_CPU => ForwardType::CPU,
85            #[cfg(feature = "metal")]
86            MNNForwardType::MNN_FORWARD_METAL => ForwardType::Metal,
87            #[cfg(feature = "opencl")]
88            MNNForwardType::MNN_FORWARD_OPENCL => ForwardType::OpenCL,
89            #[cfg(feature = "opengl")]
90            MNNForwardType::MNN_FORWARD_OPENGL => ForwardType::OpenGL,
91            #[cfg(feature = "vulkan")]
92            MNNForwardType::MNN_FORWARD_VULKAN => ForwardType::Vulkan,
93            #[cfg(feature = "coreml")]
94            MNNForwardType::MNN_FORWARD_NN => ForwardType::CoreML,
95            _ => ForwardType::Auto,
96        }
97    }
98
99    fn list() -> Vec<&'static str> {
101        vec![
102            "auto",
103            "all",
104            "cpu",
105            #[cfg(feature = "metal")]
106            "metal",
107            #[cfg(feature = "opencl")]
108            "opencl",
109            #[cfg(feature = "opengl")]
110            "opengl",
111            #[cfg(feature = "vulkan")]
112            "vulkan",
113            #[cfg(feature = "coreml")]
114            "coreml",
115        ]
116    }
117
118    pub fn to_str(self) -> &'static str {
120        match self {
121            ForwardType::Auto => "auto",
122            ForwardType::All => "all",
123            ForwardType::CPU => "cpu",
124            #[cfg(feature = "metal")]
125            ForwardType::Metal => "metal",
126            #[cfg(feature = "opencl")]
127            ForwardType::OpenCL => "opencl",
128            #[cfg(feature = "opengl")]
129            ForwardType::OpenGL => "opengl",
130            #[cfg(feature = "vulkan")]
131            ForwardType::Vulkan => "vulkan",
132            #[cfg(feature = "coreml")]
133            ForwardType::CoreML => "coreml",
134        }
135    }
136}
137
138impl core::str::FromStr for ForwardType {
139    type Err = MNNError;
140
141    fn from_str(s: &str) -> Result<Self, Self::Err> {
142        match s {
143            "auto" => Ok(ForwardType::Auto),
144            "all" => Ok(ForwardType::All),
145            "cpu" => Ok(ForwardType::CPU),
146            #[cfg(feature = "metal")]
147            "metal" => Ok(ForwardType::Metal),
148            #[cfg(feature = "opencl")]
149            "opencl" => Ok(ForwardType::OpenCL),
150            #[cfg(feature = "opengl")]
151            "opengl" => Ok(ForwardType::OpenGL),
152            #[cfg(feature = "vulkan")]
153            "vulkan" => Ok(ForwardType::Vulkan),
154            #[cfg(feature = "coreml")]
155            "coreml" => Ok(ForwardType::CoreML),
156            _ => Err(MNNError::new(crate::ErrorKind::ParseError)
157                .attach_printable(format!(
158                    "Invalid ForwardType: {s}, maybe you might need to enable feature {s}"
159                ))
160                .attach_printable(format!(
161                    "Valid ForwardType: {}",
162                    ForwardType::list().join(", ")
163                ))),
164        }
165    }
166}
167
/// Owning wrapper around a raw `MNNScheduleConfig` created through the MNN C API.
///
/// Configure it with the `set_*` / `with_*` methods, e.g.
/// `ScheduleConfig::new().with_type(ForwardType::CPU).with_num_threads(4)`.
pub struct ScheduleConfig {
    /// Raw config handle. It is created by `mnnsc_create` (in `new`) or `mnnsc_clone`
    /// (in `clone`) and released with `mnnsc_destroy` when this value is dropped.
    pub(crate) inner: *mut MNNScheduleConfig,
    /// Backend config whose raw pointer `set_backend_config` stores inside `inner`.
    /// It is owned here so it lives at least as long as this value.
    pub(crate) backend_config: Option<BackendConfig>,
    /// Zero-sized marker carrying no data.
    /// NOTE(review): `PhantomData<()>` does not change auto traits here (the raw
    /// pointer already makes the type `!Send`/`!Sync`). Its purpose is unclear; confirm.
    pub(crate) __marker: core::marker::PhantomData<()>,
}
216
217impl core::fmt::Debug for ScheduleConfig {
218    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
219        f.debug_struct("ScheduleConfig")
220            .field("type", &self.get_type())
221            .field("backup_type", &self.get_backup_type())
222            .field("backend_config", &self.backend_config)
223            .finish()
224    }
225}
226
#[cfg(feature = "serde")]
impl serde::Serialize for ScheduleConfig {
    /// Serializes the config as a 3-field struct: `type`, `backup_type`, `backend_config`.
    ///
    /// The two backend types are read back from the underlying C config first.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct as _;

        let forward = self.get_type();
        let backup = self.get_backup_type();

        let mut fields = serializer.serialize_struct("ScheduleConfig", 3)?;
        fields.serialize_field("type", &forward)?;
        fields.serialize_field("backup_type", &backup)?;
        fields.serialize_field("backend_config", &self.backend_config)?;
        fields.end()
    }
}
238
239impl Clone for ScheduleConfig {
240    fn clone(&self) -> Self {
241        unsafe {
242            let inner = mnnsc_clone(self.inner);
243            Self {
244                inner,
245                backend_config: self.backend_config.clone(),
246                __marker: core::marker::PhantomData,
247            }
248        }
249    }
250}
251
252impl Drop for ScheduleConfig {
253    fn drop(&mut self) {
254        unsafe {
255            mnn_sys::mnnsc_destroy(self.inner);
256        }
257    }
258}
259
// SAFETY: `ScheduleConfig` is the sole owner of `inner` (created in `new`/`clone`,
// destroyed in `Drop`) and of `backend_config`. Moving the value to another thread
// moves that ownership along with it.
// NOTE(review): this assumes the MNN schedule/backend config objects have no
// thread affinity — confirm against MNN.
unsafe impl Send for ScheduleConfig {}
261
262impl Default for ScheduleConfig {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268impl ScheduleConfig {
269    pub fn as_ptr_mut(&mut self) -> *mut MNNScheduleConfig {
271        self.inner
272    }
273
274    pub fn new() -> Self {
276        unsafe {
277            let inner = mnnsc_create();
278            Self {
279                inner,
280                backend_config: None,
281                __marker: core::marker::PhantomData,
282            }
283        }
284    }
285
286    pub fn set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<&mut Self> {
296        let vec_cstring = save_tensors
297            .iter()
298            .map(|s| std::ffi::CString::new(*s).map_err(|e| error!(ErrorKind::AsciiError, e)))
299            .collect::<Result<Vec<_>>>()?;
300        let vec_cstr = vec_cstring
301            .iter()
302            .map(|s: &CString| s.as_c_str().as_ptr())
303            .collect::<Vec<_>>();
304        unsafe { mnnsc_set_save_tensors(self.inner, vec_cstr.as_ptr(), vec_cstr.len()) }
305        Ok(self)
306    }
307
308    pub fn set_type(&mut self, forward_type: ForwardType) -> &mut Self {
314        unsafe {
315            mnnsc_set_type(self.inner, forward_type.to_mnn_sys());
316        }
317        self
318    }
319
320    pub fn with_type(mut self, forward_type: ForwardType) -> Self {
322        self.set_type(forward_type);
323        self
324    }
325
326    pub fn get_type(&self) -> ForwardType {
328        unsafe { ForwardType::from_mnn_sys(mnnsc_get_type(self.inner)) }
329    }
330
331    pub fn set_num_threads(&mut self, num_threads: i32) -> &mut Self {
337        unsafe {
338            mnnsc_set_num_threads(self.inner, num_threads);
339        }
340        self
341    }
342
343    pub fn with_num_threads(mut self, num_threads: i32) -> Self {
345        self.set_num_threads(num_threads);
346        self
347    }
348
349    pub fn set_mode(&mut self, mode: i32) -> &mut Self {
355        unsafe {
356            mnnsc_set_mode(self.inner, mode);
357        }
358        self
359    }
360
361    pub fn with_mode(mut self, mode: i32) -> Self {
363        self.set_mode(mode);
364        self
365    }
366
367    pub fn set_backup_type(&mut self, backup_type: ForwardType) -> &mut Self {
373        unsafe {
374            mnnsc_set_backup_type(self.inner, backup_type.to_mnn_sys());
375        }
376        self
377    }
378
379    pub fn with_backup_type(mut self, backup_type: ForwardType) -> Self {
381        self.set_backup_type(backup_type);
382        self
383    }
384
385    pub fn get_backup_type(&self) -> ForwardType {
387        unsafe { ForwardType::from_mnn_sys(mnnsc_get_backup_type(self.inner)) }
388    }
389
390    pub fn set_backend_config(
396        &mut self,
397        backend_config: impl Into<Option<BackendConfig>>,
398    ) -> &mut Self {
399        self.backend_config = backend_config.into();
400        let ptr = if let Some(ref b) = self.backend_config {
401            b.inner
402        } else {
403            core::ptr::null_mut()
404        };
405        unsafe {
406            mnnsc_set_backend_config(self.inner, ptr);
407        }
408        self
409    }
410
411    pub fn with_backend_config(mut self, backend_config: impl Into<Option<BackendConfig>>) -> Self {
413        self.set_backend_config(backend_config);
414        self
415    }
416}
417
/// A collection of raw schedule configs plus the backend configs they refer to.
#[derive(Debug)]
pub struct ScheduleConfigs {
    /// Raw configs whose ownership `push` takes over from `ScheduleConfig`. They are
    /// released in this type's `Drop` impl.
    pub(crate) inner: Vec<*const MNNScheduleConfig>,
    /// Backend config belonging to the config at the same index in `inner`. `push`
    /// appends to both vectors together, and the configs are kept alive here because
    /// the raw config may point at them.
    pub(crate) backend_configs: Vec<Option<BackendConfig>>,
}
424
425impl Drop for ScheduleConfigs {
426    fn drop(&mut self) {
427        unsafe {
428            for i in self.inner.iter() {
429                mnnsc_destroy(*i.cast());
430            }
431        }
432    }
433}
434
435impl ScheduleConfigs {
436    pub fn push(&mut self, config: ScheduleConfig) {
438        let mut config = ManuallyDrop::new(config);
439        self.inner.push(config.inner);
440        self.backend_configs.push(config.backend_config.take());
441    }
442
443    pub fn with_capacity(capacity: usize) -> Self {
445        Self {
446            inner: Vec::with_capacity(capacity),
447            backend_configs: Vec::with_capacity(capacity),
448        }
449    }
450
451    pub const fn new() -> Self {
453        Self {
454            inner: Vec::new(),
455            backend_configs: Vec::new(),
456        }
457    }
458}
459
460impl Default for ScheduleConfigs {
461    fn default() -> Self {
462        Self::new()
463    }
464}
465
466impl FromIterator<ScheduleConfig> for ScheduleConfigs {
467    fn from_iter<T: IntoIterator<Item = ScheduleConfig>>(iter: T) -> Self {
468        let iter = iter.into_iter();
469        let mut ret = Self::with_capacity(iter.size_hint().1.unwrap_or_default());
470        iter.for_each(|item| {
471            ret.push(item);
472        });
473        ret
474    }
475}
476
// SAFETY: `ScheduleConfigs` owns the raw configs in `inner` (taken over from
// `ScheduleConfig` in `push`, released in `Drop`) together with the backend configs
// they may point at. Sending the collection moves all of that ownership at once.
// NOTE(review): like the `Send` impl for `ScheduleConfig`, this assumes the MNN
// objects have no thread affinity — confirm.
unsafe impl Send for ScheduleConfigs {}