mnn/
backend.rs

//! Data types for configuring an MNN backend: [`BackendConfig`] and the [`MemoryMode`], [`PowerMode`] and [`PrecisionMode`] settings it carries.
2
3use crate::prelude::*;
4use std::str::FromStr;
5
6use mnn_sys::*;
7
8/// BackendConfig is a struct that holds the configuration for the backend
9/// memory: [MemoryMode]
10/// power: [PowerMode]
11/// precision: [PrecisionMode]
12#[repr(transparent)]
13pub struct BackendConfig {
14    pub(crate) inner: *mut MNNBackendConfig,
15    __marker: core::marker::PhantomData<()>,
16}
17
18impl core::fmt::Debug for BackendConfig {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("BackendConfig")
21            .field("memory", &self.get_memory_mode())
22            .field("power", &self.get_power_mode())
23            .field("precision", &self.get_precision_mode())
24            .finish()
25    }
26}
27
#[cfg(feature = "serde")]
impl serde::Serialize for BackendConfig {
    /// Serializes the config as a three-field struct: `memory`, `power` and `precision`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        use serde::ser::SerializeStruct as _;

        const FIELD_COUNT: usize = 3;
        let mut fields = serializer.serialize_struct("BackendConfig", FIELD_COUNT)?;
        fields.serialize_field("memory", &self.get_memory_mode())?;
        fields.serialize_field("power", &self.get_power_mode())?;
        fields.serialize_field("precision", &self.get_precision_mode())?;
        fields.end()
    }
}
42
43impl Clone for BackendConfig {
44    fn clone(&self) -> Self {
45        unsafe {
46            let inner = mnn_sys::mnnbc_clone(self.inner);
47            Self {
48                inner,
49                __marker: core::marker::PhantomData,
50            }
51        }
52    }
53}
54
55impl Drop for BackendConfig {
56    fn drop(&mut self) {
57        unsafe {
58            mnn_sys::mnnbc_destroy(self.inner);
59        }
60    }
61}
62
63impl Default for BackendConfig {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
/// Power mode requested from the backend.
///
/// Whether a mode has any effect depends on the specific backend supporting it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PowerMode {
    /// Low power; corresponds to MNN's `Power_Low`.
    Low,
    /// Normal power; corresponds to MNN's `Power_Normal`.
    Normal,
    /// High power; corresponds to MNN's `Power_High`.
    High,
}
80
81impl PowerMode {
82    fn to_mnn_sys(self) -> mnn_sys::PowerMode {
83        match self {
84            Self::Low => mnn_sys::PowerMode::Power_Low,
85            Self::Normal => mnn_sys::PowerMode::Power_Normal,
86            Self::High => mnn_sys::PowerMode::Power_High,
87        }
88    }
89
90    /// Returns a string representation of the power mode
91    pub fn to_str(self) -> &'static str {
92        match self {
93            Self::Low => "low",
94            Self::Normal => "normal",
95            Self::High => "high",
96        }
97    }
98
99    fn from_mnn_sys(mode: mnn_sys::PowerMode) -> Self {
100        match mode {
101            mnn_sys::PowerMode::Power_Low => Self::Low,
102            mnn_sys::PowerMode::Power_Normal => Self::Normal,
103            mnn_sys::PowerMode::Power_High => Self::High,
104            _ => Self::Normal,
105        }
106    }
107}
108
109impl FromStr for PowerMode {
110    type Err = MNNError;
111    fn from_str(s: &str) -> Result<Self, Self::Err> {
112        match s {
113            "low" => Ok(Self::Low),
114            "normal" => Ok(Self::Normal),
115            "high" => Ok(Self::High),
116            _ => {
117                Err(error!(ErrorKind::ParseError)
118                    .attach_printable(format!("invalid power mode: {s}")))
119            }
120        }
121    }
122}
123
124impl FromStr for MemoryMode {
125    type Err = MNNError;
126    fn from_str(s: &str) -> Result<Self, Self::Err> {
127        match s {
128            "low" => Ok(Self::Low),
129            "normal" => Ok(Self::Normal),
130            "high" => Ok(Self::High),
131            _ => {
132                Err(error!(ErrorKind::ParseError)
133                    .attach_printable(format!("invalid memory mode: {s}")))
134            }
135        }
136    }
137}
138
139impl FromStr for PrecisionMode {
140    type Err = MNNError;
141    fn from_str(s: &str) -> Result<Self, Self::Err> {
142        match s {
143            "low" => Ok(Self::Low),
144            "normal" => Ok(Self::Normal),
145            "high" => Ok(Self::High),
146            "low_bf16" => Ok(Self::LowBf16),
147            _ => Err(error!(ErrorKind::ParseError)
148                .attach_printable(format!("invalid precision mode: {s}"))),
149        }
150    }
151}
152
/// Memory mode requested from the backend.
///
/// Whether a mode has any effect depends on the specific backend supporting it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MemoryMode {
    /// Low memory; corresponds to MNN's `Memory_Low`.
    Low,
    /// Normal memory; corresponds to MNN's `Memory_Normal`.
    Normal,
    /// High memory; corresponds to MNN's `Memory_High`.
    High,
}
164
165impl MemoryMode {
166    fn to_mnn_sys(self) -> mnn_sys::MemoryMode {
167        match self {
168            Self::Low => mnn_sys::MemoryMode::Memory_Low,
169            Self::Normal => mnn_sys::MemoryMode::Memory_Normal,
170            Self::High => mnn_sys::MemoryMode::Memory_High,
171        }
172    }
173
174    /// Returns a string representation of the memory mode
175    pub fn to_str(self) -> &'static str {
176        match self {
177            Self::Low => "low",
178            Self::Normal => "normal",
179            Self::High => "high",
180        }
181    }
182
183    fn from_mnn_sys(mode: mnn_sys::MemoryMode) -> Self {
184        match mode {
185            mnn_sys::MemoryMode::Memory_Low => Self::Low,
186            mnn_sys::MemoryMode::Memory_Normal => Self::Normal,
187            mnn_sys::MemoryMode::Memory_High => Self::High,
188            _ => Self::Normal,
189        }
190    }
191}
192
/// Numeric precision mode requested from the backend.
///
/// Whether a mode has any effect depends on the specific backend supporting it.
/// Discriminants start at `Normal = 0`, in declaration order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PrecisionMode {
    /// Normal precision; corresponds to MNN's `Precision_Normal`.
    Normal = 0,
    /// High precision; corresponds to MNN's `Precision_High`.
    High,
    /// Low precision; corresponds to MNN's `Precision_Low`.
    Low,
    /// Low precision using BF16; corresponds to MNN's `Precision_Low_BF16`.
    LowBf16,
}
206impl PrecisionMode {
207    pub(crate) fn to_mnn_sys(self) -> mnn_sys::PrecisionMode {
208        match self {
209            Self::LowBf16 => mnn_sys::PrecisionMode::Precision_Low_BF16,
210            Self::Low => mnn_sys::PrecisionMode::Precision_Low,
211            Self::Normal => mnn_sys::PrecisionMode::Precision_Normal,
212            Self::High => mnn_sys::PrecisionMode::Precision_High,
213        }
214    }
215
216    /// Returns a string representation of the precision mode
217    pub fn to_str(self) -> &'static str {
218        match self {
219            Self::LowBf16 => "low_bf16",
220            Self::Low => "low",
221            Self::Normal => "normal",
222            Self::High => "high",
223        }
224    }
225
226    fn from_mnn_sys(mode: mnn_sys::PrecisionMode) -> Self {
227        match mode {
228            mnn_sys::PrecisionMode::Precision_Low_BF16 => Self::LowBf16,
229            mnn_sys::PrecisionMode::Precision_Low => Self::Low,
230            mnn_sys::PrecisionMode::Precision_Normal => Self::Normal,
231            mnn_sys::PrecisionMode::Precision_High => Self::High,
232            _ => Self::Normal,
233        }
234    }
235}
236
237impl BackendConfig {
238    /// Create a new backend config
239    pub fn new() -> Self {
240        unsafe {
241            let inner = mnnbc_create();
242            Self {
243                inner,
244                __marker: core::marker::PhantomData,
245            }
246        }
247    }
248
249    /// Sets the [MemoryMode] for the backend
250    pub fn set_memory_mode(&mut self, mode: MemoryMode) {
251        unsafe {
252            mnn_sys::mnnbc_set_memory_mode(self.inner, mode.to_mnn_sys());
253        }
254    }
255
256    /// Sets the [MemoryMode] for the backend
257    pub fn with_memory_mode(mut self, mode: MemoryMode) -> Self {
258        self.set_memory_mode(mode);
259        self
260    }
261
262    /// Gets the [MemoryMode] for the backend
263    pub fn get_memory_mode(&self) -> MemoryMode {
264        unsafe { MemoryMode::from_mnn_sys(mnn_sys::mnnbc_get_memory_mode(self.inner)) }
265    }
266
267    /// Sets the [PowerMode] for the backend
268    pub fn set_power_mode(&mut self, mode: PowerMode) {
269        unsafe {
270            mnn_sys::mnnbc_set_power_mode(self.inner, mode.to_mnn_sys());
271        }
272    }
273
274    /// Sets the [PowerMode] for the backend
275    pub fn with_power_mode(mut self, mode: PowerMode) -> Self {
276        self.set_power_mode(mode);
277        self
278    }
279
280    /// Gets the [PowerMode] for the backend
281    pub fn get_power_mode(&self) -> PowerMode {
282        unsafe { PowerMode::from_mnn_sys(mnn_sys::mnnbc_get_power_mode(self.inner)) }
283    }
284
285    /// Sets the [PrecisionMode] for the backend
286    pub fn set_precision_mode(&mut self, mode: PrecisionMode) {
287        unsafe {
288            mnn_sys::mnnbc_set_precision_mode(self.inner, mode.to_mnn_sys());
289        }
290    }
291
292    /// Sets the [PrecisionMode] for the backend
293    pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
294        self.set_precision_mode(mode);
295        self
296    }
297
298    /// Gets the [PrecisionMode] for the backend
299    pub fn get_precision_mode(&self) -> PrecisionMode {
300        unsafe { PrecisionMode::from_mnn_sys(mnn_sys::mnnbc_get_precision_mode(self.inner)) }
301    }
302
303    /// Sets the flags for the backend
304    /// What the flag represents is depends on each backend or isn't documented
305    pub fn set_flags(&mut self, flags: usize) {
306        unsafe {
307            mnn_sys::mnnbc_set_flags(self.inner, flags);
308        }
309    }
310
311    /// Sets the flags for the backend
312    pub fn with_flags(mut self, flags: usize) -> Self {
313        self.set_flags(flags);
314        self
315    }
316
317    /// # Safety
318    /// This just binds to the underlying unsafe api and should be used only if you know what you
319    /// are doing
320    pub unsafe fn set_shared_context(&mut self, shared_context: *mut libc::c_void) {
321        unsafe {
322            mnn_sys::mnnbc_set_shared_context(self.inner, shared_context);
323        }
324    }
325
326    /// # Safety
327    /// This just binds to the underlying unsafe api and should be used only if you know what you
328    /// are doing
329    pub unsafe fn with_shared_context(mut self, shared_context: *mut libc::c_void) -> Self {
330        unsafe {
331            self.set_shared_context(shared_context);
332        }
333        self
334    }
335}
336
#[test]
fn test_backend_config() {
    // Put every setting to `Low` through the `set_*` API.
    let mut original = BackendConfig::new();
    original.set_memory_mode(MemoryMode::Low);
    original.set_power_mode(PowerMode::Low);
    original.set_precision_mode(PrecisionMode::Low);

    // A clone must report the same values as the config it was taken from.
    let cloned = std::hint::black_box(original.clone());
    assert_eq!(cloned.get_memory_mode(), MemoryMode::Low);
    assert_eq!(cloned.get_power_mode(), PowerMode::Low);
    assert_eq!(cloned.get_precision_mode(), PrecisionMode::Low);

    // Switch the clone to `Normal` through the builder-style `with_*` API.
    let rebuilt = cloned
        .with_memory_mode(MemoryMode::Normal)
        .with_power_mode(PowerMode::Normal)
        .with_precision_mode(PrecisionMode::Normal);
    assert_eq!(rebuilt.get_memory_mode(), MemoryMode::Normal);
    assert_eq!(rebuilt.get_power_mode(), PowerMode::Normal);
    assert_eq!(rebuilt.get_precision_mode(), PrecisionMode::Normal);
}