1use mnn_sys::*;
2use std::{ffi::CString, mem::ManuallyDrop};
3
4use crate::{BackendConfig, prelude::*};
5
/// Backend that MNN should use to run a session.
///
/// Backend-specific variants only exist when the cargo feature of the same
/// (lowercase) name is enabled.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ForwardType {
    /// Maps to `MNN_FORWARD_ALL`.
    All,
    /// Maps to `MNN_FORWARD_AUTO`; this is the `Default` variant.
    #[default]
    Auto,
    /// Maps to `MNN_FORWARD_CPU`.
    CPU,
    /// Maps to `MNN_FORWARD_METAL`.
    #[cfg(feature = "metal")]
    Metal,
    /// Maps to `MNN_FORWARD_OPENCL`.
    #[cfg(feature = "opencl")]
    OpenCL,
    /// Maps to `MNN_FORWARD_OPENGL`.
    ///
    /// The conversion and parsing code for this type already handle `OpenGL` under the
    /// `opengl` feature; the variant must exist here or enabling that feature fails to
    /// compile.
    #[cfg(feature = "opengl")]
    OpenGL,
    /// Maps to `MNN_FORWARD_VULKAN`.
    #[cfg(feature = "vulkan")]
    Vulkan,
    /// Maps to `MNN_FORWARD_NN`.
    #[cfg(feature = "coreml")]
    CoreML,
}
59
60impl ForwardType {
61 fn to_mnn_sys(self) -> MNNForwardType {
63 match self {
64 ForwardType::Auto => MNNForwardType::MNN_FORWARD_AUTO,
65 ForwardType::All => MNNForwardType::MNN_FORWARD_ALL,
66 ForwardType::CPU => MNNForwardType::MNN_FORWARD_CPU,
67 #[cfg(feature = "metal")]
68 ForwardType::Metal => MNNForwardType::MNN_FORWARD_METAL,
69 #[cfg(feature = "opencl")]
70 ForwardType::OpenCL => MNNForwardType::MNN_FORWARD_OPENCL,
71 #[cfg(feature = "opengl")]
72 ForwardType::OpenGL => MNNForwardType::MNN_FORWARD_OPENGL,
73 #[cfg(feature = "vulkan")]
74 ForwardType::Vulkan => MNNForwardType::MNN_FORWARD_VULKAN,
75 #[cfg(feature = "coreml")]
76 ForwardType::CoreML => MNNForwardType::MNN_FORWARD_NN,
77 }
78 }
79
80 fn from_mnn_sys(mode: MNNForwardType) -> Self {
81 match mode {
82 MNNForwardType::MNN_FORWARD_AUTO => ForwardType::Auto,
83 MNNForwardType::MNN_FORWARD_ALL => ForwardType::All,
84 MNNForwardType::MNN_FORWARD_CPU => ForwardType::CPU,
85 #[cfg(feature = "metal")]
86 MNNForwardType::MNN_FORWARD_METAL => ForwardType::Metal,
87 #[cfg(feature = "opencl")]
88 MNNForwardType::MNN_FORWARD_OPENCL => ForwardType::OpenCL,
89 #[cfg(feature = "opengl")]
90 MNNForwardType::MNN_FORWARD_OPENGL => ForwardType::OpenGL,
91 #[cfg(feature = "vulkan")]
92 MNNForwardType::MNN_FORWARD_VULKAN => ForwardType::Vulkan,
93 #[cfg(feature = "coreml")]
94 MNNForwardType::MNN_FORWARD_NN => ForwardType::CoreML,
95 _ => ForwardType::Auto,
96 }
97 }
98
99 fn list() -> Vec<&'static str> {
101 vec![
102 "auto",
103 "all",
104 "cpu",
105 #[cfg(feature = "metal")]
106 "metal",
107 #[cfg(feature = "opencl")]
108 "opencl",
109 #[cfg(feature = "opengl")]
110 "opengl",
111 #[cfg(feature = "vulkan")]
112 "vulkan",
113 #[cfg(feature = "coreml")]
114 "coreml",
115 ]
116 }
117
118 pub fn to_str(self) -> &'static str {
120 match self {
121 ForwardType::Auto => "auto",
122 ForwardType::All => "all",
123 ForwardType::CPU => "cpu",
124 #[cfg(feature = "metal")]
125 ForwardType::Metal => "metal",
126 #[cfg(feature = "opencl")]
127 ForwardType::OpenCL => "opencl",
128 #[cfg(feature = "opengl")]
129 ForwardType::OpenGL => "opengl",
130 #[cfg(feature = "vulkan")]
131 ForwardType::Vulkan => "vulkan",
132 #[cfg(feature = "coreml")]
133 ForwardType::CoreML => "coreml",
134 }
135 }
136}
137
138impl core::str::FromStr for ForwardType {
139 type Err = MNNError;
140
141 fn from_str(s: &str) -> Result<Self, Self::Err> {
142 match s {
143 "auto" => Ok(ForwardType::Auto),
144 "all" => Ok(ForwardType::All),
145 "cpu" => Ok(ForwardType::CPU),
146 #[cfg(feature = "metal")]
147 "metal" => Ok(ForwardType::Metal),
148 #[cfg(feature = "opencl")]
149 "opencl" => Ok(ForwardType::OpenCL),
150 #[cfg(feature = "opengl")]
151 "opengl" => Ok(ForwardType::OpenGL),
152 #[cfg(feature = "vulkan")]
153 "vulkan" => Ok(ForwardType::Vulkan),
154 #[cfg(feature = "coreml")]
155 "coreml" => Ok(ForwardType::CoreML),
156 _ => Err(MNNError::new(crate::ErrorKind::ParseError)
157 .attach_printable(format!(
158 "Invalid ForwardType: {s}, maybe you might need to enable feature {s}"
159 ))
160 .attach_printable(format!(
161 "Valid ForwardType: {}",
162 ForwardType::list().join(", ")
163 ))),
164 }
165 }
166}
167
/// Owning wrapper around a native `MNNScheduleConfig`.
///
/// The native object is created with `mnnsc_create` (or `mnnsc_clone` when cloning) and
/// released with `mnnsc_destroy` when this value is dropped.
pub struct ScheduleConfig {
    /// Native config pointer owned exclusively by this value.
    pub(crate) inner: *mut MNNScheduleConfig,
    /// Backend config whose raw pointer `set_backend_config` hands to `inner`. It is kept
    /// here so it stays alive for as long as `inner` may reference it.
    pub(crate) backend_config: Option<BackendConfig>,
    /// Zero-sized marker. `PhantomData<()>` does not change auto traits; the raw pointer
    /// already makes the type `!Send`/`!Sync`, and `Send` is restored by an explicit
    /// `unsafe impl`.
    pub(crate) __marker: core::marker::PhantomData<()>,
}
216
217impl core::fmt::Debug for ScheduleConfig {
218 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
219 f.debug_struct("ScheduleConfig")
220 .field("type", &self.get_type())
221 .field("backup_type", &self.get_backup_type())
222 .field("backend_config", &self.backend_config)
223 .finish()
224 }
225}
226
#[cfg(feature = "serde")]
impl serde::Serialize for ScheduleConfig {
    /// Serializes as a three-field struct: `type`, `backup_type` and `backend_config`.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        let forward_type = self.get_type();
        let backup_type = self.get_backup_type();

        let mut fields = serializer.serialize_struct("ScheduleConfig", 3)?;
        fields.serialize_field("type", &forward_type)?;
        fields.serialize_field("backup_type", &backup_type)?;
        fields.serialize_field("backend_config", &self.backend_config)?;
        fields.end()
    }
}
238
239impl Clone for ScheduleConfig {
240 fn clone(&self) -> Self {
241 unsafe {
242 let inner = mnnsc_clone(self.inner);
243 Self {
244 inner,
245 backend_config: self.backend_config.clone(),
246 __marker: core::marker::PhantomData,
247 }
248 }
249 }
250}
251
252impl Drop for ScheduleConfig {
253 fn drop(&mut self) {
254 unsafe {
255 mnn_sys::mnnsc_destroy(self.inner);
256 }
257 }
258}
259
// SAFETY: `ScheduleConfig` owns `inner`; the only public way to reach the pointer is
// `as_ptr_mut`, which needs `&mut self`. Moving the value to another thread therefore
// moves sole ownership with it.
// NOTE(review): this also assumes `BackendConfig` is safe to send and that the native
// config has no thread affinity — confirm against mnn-sys.
unsafe impl Send for ScheduleConfig {}
261
262impl Default for ScheduleConfig {
263 fn default() -> Self {
264 Self::new()
265 }
266}
267
268impl ScheduleConfig {
269 pub fn as_ptr_mut(&mut self) -> *mut MNNScheduleConfig {
271 self.inner
272 }
273
274 pub fn new() -> Self {
276 unsafe {
277 let inner = mnnsc_create();
278 Self {
279 inner,
280 backend_config: None,
281 __marker: core::marker::PhantomData,
282 }
283 }
284 }
285
286 pub fn set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<&mut Self> {
296 let vec_cstring = save_tensors
297 .iter()
298 .map(|s| std::ffi::CString::new(*s).map_err(|e| error!(ErrorKind::AsciiError, e)))
299 .collect::<Result<Vec<_>>>()?;
300 let vec_cstr = vec_cstring
301 .iter()
302 .map(|s: &CString| s.as_c_str().as_ptr())
303 .collect::<Vec<_>>();
304 unsafe { mnnsc_set_save_tensors(self.inner, vec_cstr.as_ptr(), vec_cstr.len()) }
305 Ok(self)
306 }
307
308 pub fn set_type(&mut self, forward_type: ForwardType) -> &mut Self {
314 unsafe {
315 mnnsc_set_type(self.inner, forward_type.to_mnn_sys());
316 }
317 self
318 }
319
320 pub fn with_type(mut self, forward_type: ForwardType) -> Self {
322 self.set_type(forward_type);
323 self
324 }
325
326 pub fn get_type(&self) -> ForwardType {
328 unsafe { ForwardType::from_mnn_sys(mnnsc_get_type(self.inner)) }
329 }
330
331 pub fn set_num_threads(&mut self, num_threads: i32) -> &mut Self {
337 unsafe {
338 mnnsc_set_num_threads(self.inner, num_threads);
339 }
340 self
341 }
342
343 pub fn with_num_threads(mut self, num_threads: i32) -> Self {
345 self.set_num_threads(num_threads);
346 self
347 }
348
349 pub fn set_mode(&mut self, mode: i32) -> &mut Self {
355 unsafe {
356 mnnsc_set_mode(self.inner, mode);
357 }
358 self
359 }
360
361 pub fn with_mode(mut self, mode: i32) -> Self {
363 self.set_mode(mode);
364 self
365 }
366
367 pub fn set_backup_type(&mut self, backup_type: ForwardType) -> &mut Self {
373 unsafe {
374 mnnsc_set_backup_type(self.inner, backup_type.to_mnn_sys());
375 }
376 self
377 }
378
379 pub fn with_backup_type(mut self, backup_type: ForwardType) -> Self {
381 self.set_backup_type(backup_type);
382 self
383 }
384
385 pub fn get_backup_type(&self) -> ForwardType {
387 unsafe { ForwardType::from_mnn_sys(mnnsc_get_backup_type(self.inner)) }
388 }
389
390 pub fn set_backend_config(
396 &mut self,
397 backend_config: impl Into<Option<BackendConfig>>,
398 ) -> &mut Self {
399 self.backend_config = backend_config.into();
400 let ptr = if let Some(ref b) = self.backend_config {
401 b.inner
402 } else {
403 core::ptr::null_mut()
404 };
405 unsafe {
406 mnnsc_set_backend_config(self.inner, ptr);
407 }
408 self
409 }
410
411 pub fn with_backend_config(mut self, backend_config: impl Into<Option<BackendConfig>>) -> Self {
413 self.set_backend_config(backend_config);
414 self
415 }
416}
417
/// Owning collection of native schedule configs, built from `ScheduleConfig` values via
/// `push` or `FromIterator`.
#[derive(Debug)]
pub struct ScheduleConfigs {
    /// Native config pointers whose ownership was taken over from pushed `ScheduleConfig`s.
    pub(crate) inner: Vec<*const MNNScheduleConfig>,
    /// Backend configs moved out of the pushed `ScheduleConfig`s. `push` appends exactly
    /// one entry per element of `inner` (`None` when none was attached), and each entry is
    /// kept alive alongside the native configs.
    pub(crate) backend_configs: Vec<Option<BackendConfig>>,
}
424
425impl Drop for ScheduleConfigs {
426 fn drop(&mut self) {
427 unsafe {
428 for i in self.inner.iter() {
429 mnnsc_destroy(*i.cast());
430 }
431 }
432 }
433}
434
435impl ScheduleConfigs {
436 pub fn push(&mut self, config: ScheduleConfig) {
438 let mut config = ManuallyDrop::new(config);
439 self.inner.push(config.inner);
440 self.backend_configs.push(config.backend_config.take());
441 }
442
443 pub fn with_capacity(capacity: usize) -> Self {
445 Self {
446 inner: Vec::with_capacity(capacity),
447 backend_configs: Vec::with_capacity(capacity),
448 }
449 }
450
451 pub const fn new() -> Self {
453 Self {
454 inner: Vec::new(),
455 backend_configs: Vec::new(),
456 }
457 }
458}
459
460impl Default for ScheduleConfigs {
461 fn default() -> Self {
462 Self::new()
463 }
464}
465
466impl FromIterator<ScheduleConfig> for ScheduleConfigs {
467 fn from_iter<T: IntoIterator<Item = ScheduleConfig>>(iter: T) -> Self {
468 let iter = iter.into_iter();
469 let mut ret = Self::with_capacity(iter.size_hint().1.unwrap_or_default());
470 iter.for_each(|item| {
471 ret.push(item);
472 });
473 ret
474 }
475}
476
// SAFETY: the collection owns every native config in `inner` (taken over by `push`) and
// every backend config in `backend_configs`. The raw pointers are not shared with any
// other Rust owner, so moving the collection to another thread moves sole ownership.
// NOTE(review): this also assumes `BackendConfig` is safe to send and that native configs
// have no thread affinity — confirm against mnn-sys.
unsafe impl Send for ScheduleConfigs {}