// mnn/schedule.rs

1use mnn_sys::*;
2use std::{ffi::CString, mem::ManuallyDrop};
3
4use crate::{BackendConfig, prelude::*};
5
/// Backend used for running the model.
///
/// `ForwardType` selects the backend that MNN uses for forward computation. Variants for
/// optional backends only exist when the matching Cargo feature is enabled, so the set of
/// variants depends on the build configuration.
///
/// # Variants
///
/// - `All`: Use all available backends.
/// - `Auto`: Automatically select the best backend based on the current environment and
///   hardware. This is the default.
/// - `CPU`: Use the CPU for computation.
/// - `Metal`: Use the Metal backend for computation (requires the `metal` feature).
/// - `OpenCL`: Use the OpenCL backend for computation (requires the `opencl` feature).
/// - `OpenGL`: Use the OpenGL backend for computation (requires the `opengl` feature).
/// - `Vulkan`: Use the Vulkan backend for computation (requires the `vulkan` feature).
/// - `CoreML`: Use the CoreML backend for computation (requires the `coreml` feature).
///
/// # Example
///
/// ```rust
/// use mnn::schedule::ForwardType;
///
/// let forward_type = ForwardType::Auto;
/// println!("Selected forward type: {:?}", forward_type);
/// ```
///
/// # Note
///
/// The availability of certain variants depends on the features enabled during the build.
/// For example, the `Metal` variant is only available if the `metal` feature is enabled.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ForwardType {
    /// Use all available backends.
    All,
    /// Try to automatically select the best backend based on the current environment and hardware.
    #[default]
    Auto,
    /// Use the CPU for computation.
    CPU,
    /// Use the Metal backend for computation.
    #[cfg(feature = "metal")]
    Metal,
    /// Use the OpenCL backend for computation.
    #[cfg(feature = "opencl")]
    OpenCL,
    /// Use the OpenGL backend for computation.
    // Every conversion on this type already has an `opengl`-gated arm that names this variant,
    // so it has to exist for the crate to build with that feature enabled.
    #[cfg(feature = "opengl")]
    OpenGL,
    /// Use the Vulkan backend for computation.
    #[cfg(feature = "vulkan")]
    Vulkan,
    /// Use the CoreML backend for computation.
    #[cfg(feature = "coreml")]
    CoreML,
}
59
60impl ForwardType {
61    /// Convert the `ForwardType` enum to the corresponding C++ `MNNForwardType` enum.
62    fn to_mnn_sys(self) -> MNNForwardType {
63        match self {
64            ForwardType::Auto => MNNForwardType::MNN_FORWARD_AUTO,
65            ForwardType::All => MNNForwardType::MNN_FORWARD_ALL,
66            ForwardType::CPU => MNNForwardType::MNN_FORWARD_CPU,
67            #[cfg(feature = "metal")]
68            ForwardType::Metal => MNNForwardType::MNN_FORWARD_METAL,
69            #[cfg(feature = "opencl")]
70            ForwardType::OpenCL => MNNForwardType::MNN_FORWARD_OPENCL,
71            #[cfg(feature = "opengl")]
72            ForwardType::OpenGL => MNNForwardType::MNN_FORWARD_OPENGL,
73            #[cfg(feature = "vulkan")]
74            ForwardType::Vulkan => MNNForwardType::MNN_FORWARD_VULKAN,
75            #[cfg(feature = "coreml")]
76            ForwardType::CoreML => MNNForwardType::MNN_FORWARD_NN,
77        }
78    }
79
80    fn from_mnn_sys(mode: MNNForwardType) -> Self {
81        match mode {
82            MNNForwardType::MNN_FORWARD_AUTO => ForwardType::Auto,
83            MNNForwardType::MNN_FORWARD_ALL => ForwardType::All,
84            MNNForwardType::MNN_FORWARD_CPU => ForwardType::CPU,
85            #[cfg(feature = "metal")]
86            MNNForwardType::MNN_FORWARD_METAL => ForwardType::Metal,
87            #[cfg(feature = "opencl")]
88            MNNForwardType::MNN_FORWARD_OPENCL => ForwardType::OpenCL,
89            #[cfg(feature = "opengl")]
90            MNNForwardType::MNN_FORWARD_OPENGL => ForwardType::OpenGL,
91            #[cfg(feature = "vulkan")]
92            MNNForwardType::MNN_FORWARD_VULKAN => ForwardType::Vulkan,
93            #[cfg(feature = "coreml")]
94            MNNForwardType::MNN_FORWARD_NN => ForwardType::CoreML,
95            _ => ForwardType::Auto,
96        }
97    }
98
99    /// List all available `ForwardType` variants as string slices.
100    fn list() -> Vec<&'static str> {
101        vec![
102            "auto",
103            "all",
104            "cpu",
105            #[cfg(feature = "metal")]
106            "metal",
107            #[cfg(feature = "opencl")]
108            "opencl",
109            #[cfg(feature = "opengl")]
110            "opengl",
111            #[cfg(feature = "vulkan")]
112            "vulkan",
113            #[cfg(feature = "coreml")]
114            "coreml",
115        ]
116    }
117
118    /// Convert the `ForwardType` enum to a string slice.
119    pub fn to_str(self) -> &'static str {
120        match self {
121            ForwardType::Auto => "auto",
122            ForwardType::All => "all",
123            ForwardType::CPU => "cpu",
124            #[cfg(feature = "metal")]
125            ForwardType::Metal => "metal",
126            #[cfg(feature = "opencl")]
127            ForwardType::OpenCL => "opencl",
128            #[cfg(feature = "opengl")]
129            ForwardType::OpenGL => "opengl",
130            #[cfg(feature = "vulkan")]
131            ForwardType::Vulkan => "vulkan",
132            #[cfg(feature = "coreml")]
133            ForwardType::CoreML => "coreml",
134        }
135    }
136}
137
138impl core::str::FromStr for ForwardType {
139    type Err = MNNError;
140
141    fn from_str(s: &str) -> Result<Self, Self::Err> {
142        match s {
143            "auto" => Ok(ForwardType::Auto),
144            "all" => Ok(ForwardType::All),
145            "cpu" => Ok(ForwardType::CPU),
146            #[cfg(feature = "metal")]
147            "metal" => Ok(ForwardType::Metal),
148            #[cfg(feature = "opencl")]
149            "opencl" => Ok(ForwardType::OpenCL),
150            #[cfg(feature = "opengl")]
151            "opengl" => Ok(ForwardType::OpenGL),
152            #[cfg(feature = "vulkan")]
153            "vulkan" => Ok(ForwardType::Vulkan),
154            #[cfg(feature = "coreml")]
155            "coreml" => Ok(ForwardType::CoreML),
156            _ => Err(MNNError::new(crate::ErrorKind::ParseError)
157                .attach_printable(format!(
158                    "Invalid ForwardType: {s}, maybe you might need to enable feature {s}"
159                ))
160                .attach_printable(format!(
161                    "Valid ForwardType: {}",
162                    ForwardType::list().join(", ")
163                ))),
164        }
165    }
166}
167
168/// Configuration for scheduling the forward computation in MNN.
169///
170/// The `ScheduleConfig` struct is used to configure various parameters for scheduling the forward
171/// computation in the MNN framework. It allows setting the type of backend, the number of threads,
172/// the mode of computation, and other options.
173///
174/// # Example
175///
176/// ```rust
177/// use mnn::schedule::{ScheduleConfig, ForwardType};
178///
179/// let mut config = ScheduleConfig::new();
180/// config.set_type(ForwardType::Auto);
181/// config.set_num_threads(4);
182/// config.set_mode(0);
183/// ```
184///
185/// # Fields
186///
187/// - `inner`: A raw pointer to the underlying `MNNScheduleConfig` structure.
188/// - `backend_config`: Specifies backend-specific configurations.
189/// - `__marker`: A marker to ensure the struct is `!Send` by default.
190///
191/// # Methods
192///
193/// - `new() -> Self`: Creates a new `ScheduleConfig` with default settings.
194/// - `as_ptr_mut(&mut self) -> *mut MNNScheduleConfig`: Returns a mutable raw pointer to the underlying `MNNScheduleConfig`.
195/// - `set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<()>`: Sets the tensors to be saved during computation.
196/// - `set_type(&mut self, forward_type: ForwardType)`: Sets the type of backend to be used for computation.
197/// - `set_num_threads(&mut self, num_threads: i32)`: Sets the number of threads to be used for computation.
198/// - `set_mode(&mut self, mode: i32)`: Sets the mode of computation.
199/// - `set_backup_type(&mut self, backup_type: ForwardType)`: Sets the backup type of backend to be used if the primary backend fails.
200/// - `set_backend_config(&mut self, backend_config: impl Into<Option<BackendConfig>>)`: Sets the backend-specific configuration.
201///
202/// # Safety
203///
204/// The `ScheduleConfig` struct contains raw pointers and interacts with the underlying C API of MNN.
205/// Users should be cautious when using this struct to avoid undefined behavior.
206///
207/// # Warning
208///
209/// **Warning:** The `Drop` implementation for `ScheduleConfig` ensures that the underlying `MNNScheduleConfig`
210/// is properly destroyed when the struct goes out of scope. Users should not manually free the `inner` pointer.
211pub struct ScheduleConfig {
212    pub(crate) inner: *mut MNNScheduleConfig,
213    pub(crate) backend_config: Option<BackendConfig>,
214    pub(crate) __marker: core::marker::PhantomData<()>,
215}
216
217impl core::fmt::Debug for ScheduleConfig {
218    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
219        f.debug_struct("ScheduleConfig")
220            .field("type", &self.get_type())
221            .field("backup_type", &self.get_backup_type())
222            .field("backend_config", &self.backend_config)
223            .finish()
224    }
225}
226
#[cfg(feature = "serde")]
impl serde::Serialize for ScheduleConfig {
    /// Serializes the configuration as a three-field struct: `type`, `backup_type` and
    /// `backend_config`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct as _;

        let mut fields = serializer.serialize_struct("ScheduleConfig", 3)?;
        let primary = self.get_type();
        fields.serialize_field("type", &primary)?;
        let backup = self.get_backup_type();
        fields.serialize_field("backup_type", &backup)?;
        fields.serialize_field("backend_config", &self.backend_config)?;
        fields.end()
    }
}
238
239impl Clone for ScheduleConfig {
240    fn clone(&self) -> Self {
241        unsafe {
242            let inner = mnnsc_clone(self.inner);
243            Self {
244                inner,
245                backend_config: self.backend_config.clone(),
246                __marker: core::marker::PhantomData,
247            }
248        }
249    }
250}
251
252impl Drop for ScheduleConfig {
253    fn drop(&mut self) {
254        unsafe {
255            mnn_sys::mnnsc_destroy(self.inner);
256        }
257    }
258}
259
260unsafe impl Send for ScheduleConfig {}
261
262impl Default for ScheduleConfig {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268impl ScheduleConfig {
269    /// Returns a mutable raw pointer to the underlying `MNNScheduleConfig`.
270    pub fn as_ptr_mut(&mut self) -> *mut MNNScheduleConfig {
271        self.inner
272    }
273
274    /// Creates a new `ScheduleConfig` with default settings.
275    pub fn new() -> Self {
276        unsafe {
277            let inner = mnnsc_create();
278            Self {
279                inner,
280                backend_config: None,
281                __marker: core::marker::PhantomData,
282            }
283        }
284    }
285
286    /// Sets the tensors to be saved during computation.
287    ///
288    /// # Arguments
289    ///
290    /// - `save_tensors`: A slice of tensor names to be saved.
291    ///
292    /// # Errors
293    ///
294    /// Returns an error if any of the tensor names contain null bytes.
295    pub fn set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<&mut Self> {
296        let vec_cstring = save_tensors
297            .iter()
298            .map(|s| std::ffi::CString::new(*s).map_err(|e| error!(ErrorKind::AsciiError, e)))
299            .collect::<Result<Vec<_>>>()?;
300        let vec_cstr = vec_cstring
301            .iter()
302            .map(|s: &CString| s.as_c_str().as_ptr())
303            .collect::<Vec<_>>();
304        unsafe { mnnsc_set_save_tensors(self.inner, vec_cstr.as_ptr(), vec_cstr.len()) }
305        Ok(self)
306    }
307
308    /// Sets the type of backend to be used for computation.
309    ///
310    /// # Arguments
311    ///
312    /// - `forward_type`: The type of backend to be used.
313    pub fn set_type(&mut self, forward_type: ForwardType) -> &mut Self {
314        unsafe {
315            mnnsc_set_type(self.inner, forward_type.to_mnn_sys());
316        }
317        self
318    }
319
320    /// Sets the type of backend to be used for computation.
321    pub fn with_type(mut self, forward_type: ForwardType) -> Self {
322        self.set_type(forward_type);
323        self
324    }
325
326    /// Gets the type of backend to be used for computation.
327    pub fn get_type(&self) -> ForwardType {
328        unsafe { ForwardType::from_mnn_sys(mnnsc_get_type(self.inner)) }
329    }
330
331    /// Sets the number of threads to be used for computation.
332    ///
333    /// # Arguments
334    ///
335    /// - `num_threads`: The number of threads to be used.
336    pub fn set_num_threads(&mut self, num_threads: i32) -> &mut Self {
337        unsafe {
338            mnnsc_set_num_threads(self.inner, num_threads);
339        }
340        self
341    }
342
343    /// Sets the number of threads to be used for computation.
344    pub fn with_num_threads(mut self, num_threads: i32) -> Self {
345        self.set_num_threads(num_threads);
346        self
347    }
348
349    /// Sets the mode of computation.
350    ///
351    /// # Arguments
352    ///
353    /// - `mode`: The mode of computation.
354    pub fn set_mode(&mut self, mode: i32) -> &mut Self {
355        unsafe {
356            mnnsc_set_mode(self.inner, mode);
357        }
358        self
359    }
360
361    /// Sets the mode of computation.
362    pub fn with_mode(mut self, mode: i32) -> Self {
363        self.set_mode(mode);
364        self
365    }
366
367    /// Sets the backup type of backend to be used if the primary backend fails.
368    ///
369    /// # Arguments
370    ///
371    /// - `backup_type`: The backup type of backend to be used.
372    pub fn set_backup_type(&mut self, backup_type: ForwardType) -> &mut Self {
373        unsafe {
374            mnnsc_set_backup_type(self.inner, backup_type.to_mnn_sys());
375        }
376        self
377    }
378
379    /// Sets the backup type of backend to be used if the primary backend fails.
380    pub fn with_backup_type(mut self, backup_type: ForwardType) -> Self {
381        self.set_backup_type(backup_type);
382        self
383    }
384
385    /// Gets the backup type of backend to be used if the primary backend fails.
386    pub fn get_backup_type(&self) -> ForwardType {
387        unsafe { ForwardType::from_mnn_sys(mnnsc_get_backup_type(self.inner)) }
388    }
389
390    /// Sets the backend-specific configuration.
391    ///
392    /// # Arguments
393    ///
394    /// - `backend_config`: specifies additional backend-specific configurations.
395    pub fn set_backend_config(
396        &mut self,
397        backend_config: impl Into<Option<BackendConfig>>,
398    ) -> &mut Self {
399        self.backend_config = backend_config.into();
400        let ptr = if let Some(ref b) = self.backend_config {
401            b.inner
402        } else {
403            core::ptr::null_mut()
404        };
405        unsafe {
406            mnnsc_set_backend_config(self.inner, ptr);
407        }
408        self
409    }
410
411    /// Sets the backend-specific configuration.
412    pub fn with_backend_config(mut self, backend_config: impl Into<Option<BackendConfig>>) -> Self {
413        self.set_backend_config(backend_config);
414        self
415    }
416}
417
418/// A list of `ScheduleConfig` objects to be used for scheduling the forward computation in MNN.
419#[derive(Debug)]
420pub struct ScheduleConfigs {
421    pub(crate) inner: Vec<*const MNNScheduleConfig>,
422    pub(crate) backend_configs: Vec<Option<BackendConfig>>,
423}
424
425impl Drop for ScheduleConfigs {
426    fn drop(&mut self) {
427        unsafe {
428            for i in self.inner.iter() {
429                mnnsc_destroy(*i.cast());
430            }
431        }
432    }
433}
434
435impl ScheduleConfigs {
436    /// Pushed a new `ScheduleConfig` to the list of configurations.
437    pub fn push(&mut self, config: ScheduleConfig) {
438        let mut config = ManuallyDrop::new(config);
439        self.inner.push(config.inner);
440        self.backend_configs.push(config.backend_config.take());
441    }
442
443    /// Creates a new (empty) `ScheduleConfigs` with the specified capacity.
444    pub fn with_capacity(capacity: usize) -> Self {
445        Self {
446            inner: Vec::with_capacity(capacity),
447            backend_configs: Vec::with_capacity(capacity),
448        }
449    }
450
451    /// Creates a new (empty) `ScheduleConfigs` with default settings.
452    pub const fn new() -> Self {
453        Self {
454            inner: Vec::new(),
455            backend_configs: Vec::new(),
456        }
457    }
458}
459
460impl Default for ScheduleConfigs {
461    fn default() -> Self {
462        Self::new()
463    }
464}
465
466impl FromIterator<ScheduleConfig> for ScheduleConfigs {
467    fn from_iter<T: IntoIterator<Item = ScheduleConfig>>(iter: T) -> Self {
468        let iter = iter.into_iter();
469        let mut ret = Self::with_capacity(iter.size_hint().1.unwrap_or_default());
470        iter.for_each(|item| {
471            ret.push(item);
472        });
473        ret
474    }
475}
476
477unsafe impl Send for ScheduleConfigs {}