1use crate::prelude::*;
4use std::str::FromStr;
5
6use mnn_sys::*;
7
8#[repr(transparent)]
13pub struct BackendConfig {
14 pub(crate) inner: *mut MNNBackendConfig,
15 __marker: core::marker::PhantomData<()>,
16}
17
18impl core::fmt::Debug for BackendConfig {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("BackendConfig")
21 .field("memory", &self.get_memory_mode())
22 .field("power", &self.get_power_mode())
23 .field("precision", &self.get_precision_mode())
24 .finish()
25 }
26}
27
#[cfg(feature = "serde")]
impl serde::Serialize for BackendConfig {
    /// Serializes as a struct named `BackendConfig` with the fields `memory`,
    /// `power` and `precision`, each read back from the native config.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        // Only the trait's methods are needed, not its name.
        use serde::ser::SerializeStruct as _;

        const FIELD_COUNT: usize = 3;

        let mut out = serializer.serialize_struct("BackendConfig", FIELD_COUNT)?;
        out.serialize_field("memory", &self.get_memory_mode())?;
        out.serialize_field("power", &self.get_power_mode())?;
        out.serialize_field("precision", &self.get_precision_mode())?;
        out.end()
    }
}
42
43impl Clone for BackendConfig {
44 fn clone(&self) -> Self {
45 unsafe {
46 let inner = mnn_sys::mnnbc_clone(self.inner);
47 Self {
48 inner,
49 __marker: core::marker::PhantomData,
50 }
51 }
52 }
53}
54
55impl Drop for BackendConfig {
56 fn drop(&mut self) {
57 unsafe {
58 mnn_sys::mnnbc_destroy(self.inner);
59 }
60 }
61}
62
63impl Default for BackendConfig {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
/// Power mode requested from the MNN backend.
///
/// Each variant maps one-to-one onto the corresponding `mnn_sys::PowerMode`
/// constant (`Power_Low`, `Power_Normal`, `Power_High`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PowerMode {
    /// Maps to `mnn_sys::PowerMode::Power_Low`; textual form `"low"`.
    Low,
    /// Maps to `mnn_sys::PowerMode::Power_Normal`; textual form `"normal"`.
    Normal,
    /// Maps to `mnn_sys::PowerMode::Power_High`; textual form `"high"`.
    High,
}
80
81impl PowerMode {
82 fn to_mnn_sys(self) -> mnn_sys::PowerMode {
83 match self {
84 Self::Low => mnn_sys::PowerMode::Power_Low,
85 Self::Normal => mnn_sys::PowerMode::Power_Normal,
86 Self::High => mnn_sys::PowerMode::Power_High,
87 }
88 }
89
90 pub fn to_str(self) -> &'static str {
92 match self {
93 Self::Low => "low",
94 Self::Normal => "normal",
95 Self::High => "high",
96 }
97 }
98
99 fn from_mnn_sys(mode: mnn_sys::PowerMode) -> Self {
100 match mode {
101 mnn_sys::PowerMode::Power_Low => Self::Low,
102 mnn_sys::PowerMode::Power_Normal => Self::Normal,
103 mnn_sys::PowerMode::Power_High => Self::High,
104 _ => Self::Normal,
105 }
106 }
107}
108
109impl FromStr for PowerMode {
110 type Err = MNNError;
111 fn from_str(s: &str) -> Result<Self, Self::Err> {
112 match s {
113 "low" => Ok(Self::Low),
114 "normal" => Ok(Self::Normal),
115 "high" => Ok(Self::High),
116 _ => {
117 Err(error!(ErrorKind::ParseError)
118 .attach_printable(format!("invalid power mode: {s}")))
119 }
120 }
121 }
122}
123
124impl FromStr for MemoryMode {
125 type Err = MNNError;
126 fn from_str(s: &str) -> Result<Self, Self::Err> {
127 match s {
128 "low" => Ok(Self::Low),
129 "normal" => Ok(Self::Normal),
130 "high" => Ok(Self::High),
131 _ => {
132 Err(error!(ErrorKind::ParseError)
133 .attach_printable(format!("invalid memory mode: {s}")))
134 }
135 }
136 }
137}
138
139impl FromStr for PrecisionMode {
140 type Err = MNNError;
141 fn from_str(s: &str) -> Result<Self, Self::Err> {
142 match s {
143 "low" => Ok(Self::Low),
144 "normal" => Ok(Self::Normal),
145 "high" => Ok(Self::High),
146 "low_bf16" => Ok(Self::LowBf16),
147 _ => Err(error!(ErrorKind::ParseError)
148 .attach_printable(format!("invalid precision mode: {s}"))),
149 }
150 }
151}
152
/// Memory mode requested from the MNN backend.
///
/// Each variant maps one-to-one onto the corresponding `mnn_sys::MemoryMode`
/// constant (`Memory_Low`, `Memory_Normal`, `Memory_High`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MemoryMode {
    /// Maps to `mnn_sys::MemoryMode::Memory_Low`; textual form `"low"`.
    Low,
    /// Maps to `mnn_sys::MemoryMode::Memory_Normal`; textual form `"normal"`.
    Normal,
    /// Maps to `mnn_sys::MemoryMode::Memory_High`; textual form `"high"`.
    High,
}
164
165impl MemoryMode {
166 fn to_mnn_sys(self) -> mnn_sys::MemoryMode {
167 match self {
168 Self::Low => mnn_sys::MemoryMode::Memory_Low,
169 Self::Normal => mnn_sys::MemoryMode::Memory_Normal,
170 Self::High => mnn_sys::MemoryMode::Memory_High,
171 }
172 }
173
174 pub fn to_str(self) -> &'static str {
176 match self {
177 Self::Low => "low",
178 Self::Normal => "normal",
179 Self::High => "high",
180 }
181 }
182
183 fn from_mnn_sys(mode: mnn_sys::MemoryMode) -> Self {
184 match mode {
185 mnn_sys::MemoryMode::Memory_Low => Self::Low,
186 mnn_sys::MemoryMode::Memory_Normal => Self::Normal,
187 mnn_sys::MemoryMode::Memory_High => Self::High,
188 _ => Self::Normal,
189 }
190 }
191}
192
/// Numeric precision requested from the MNN backend.
///
/// Each variant maps one-to-one onto the corresponding
/// `mnn_sys::PrecisionMode` constant. The discriminants are fixed
/// (`Normal = 0`, then `High`, `Low`, `LowBf16` in order), so the variant
/// order must not change.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PrecisionMode {
    /// Maps to `mnn_sys::PrecisionMode::Precision_Normal`; textual form `"normal"`.
    Normal = 0,
    /// Maps to `mnn_sys::PrecisionMode::Precision_High`; textual form `"high"`.
    High,
    /// Maps to `mnn_sys::PrecisionMode::Precision_Low`; textual form `"low"`.
    Low,
    /// Maps to `mnn_sys::PrecisionMode::Precision_Low_BF16`; textual form `"low_bf16"`.
    LowBf16,
}
206impl PrecisionMode {
207 pub(crate) fn to_mnn_sys(self) -> mnn_sys::PrecisionMode {
208 match self {
209 Self::LowBf16 => mnn_sys::PrecisionMode::Precision_Low_BF16,
210 Self::Low => mnn_sys::PrecisionMode::Precision_Low,
211 Self::Normal => mnn_sys::PrecisionMode::Precision_Normal,
212 Self::High => mnn_sys::PrecisionMode::Precision_High,
213 }
214 }
215
216 pub fn to_str(self) -> &'static str {
218 match self {
219 Self::LowBf16 => "low_bf16",
220 Self::Low => "low",
221 Self::Normal => "normal",
222 Self::High => "high",
223 }
224 }
225
226 fn from_mnn_sys(mode: mnn_sys::PrecisionMode) -> Self {
227 match mode {
228 mnn_sys::PrecisionMode::Precision_Low_BF16 => Self::LowBf16,
229 mnn_sys::PrecisionMode::Precision_Low => Self::Low,
230 mnn_sys::PrecisionMode::Precision_Normal => Self::Normal,
231 mnn_sys::PrecisionMode::Precision_High => Self::High,
232 _ => Self::Normal,
233 }
234 }
235}
236
237impl BackendConfig {
238 pub fn new() -> Self {
240 unsafe {
241 let inner = mnnbc_create();
242 Self {
243 inner,
244 __marker: core::marker::PhantomData,
245 }
246 }
247 }
248
249 pub fn set_memory_mode(&mut self, mode: MemoryMode) {
251 unsafe {
252 mnn_sys::mnnbc_set_memory_mode(self.inner, mode.to_mnn_sys());
253 }
254 }
255
256 pub fn with_memory_mode(mut self, mode: MemoryMode) -> Self {
258 self.set_memory_mode(mode);
259 self
260 }
261
262 pub fn get_memory_mode(&self) -> MemoryMode {
264 unsafe { MemoryMode::from_mnn_sys(mnn_sys::mnnbc_get_memory_mode(self.inner)) }
265 }
266
267 pub fn set_power_mode(&mut self, mode: PowerMode) {
269 unsafe {
270 mnn_sys::mnnbc_set_power_mode(self.inner, mode.to_mnn_sys());
271 }
272 }
273
274 pub fn with_power_mode(mut self, mode: PowerMode) -> Self {
276 self.set_power_mode(mode);
277 self
278 }
279
280 pub fn get_power_mode(&self) -> PowerMode {
282 unsafe { PowerMode::from_mnn_sys(mnn_sys::mnnbc_get_power_mode(self.inner)) }
283 }
284
285 pub fn set_precision_mode(&mut self, mode: PrecisionMode) {
287 unsafe {
288 mnn_sys::mnnbc_set_precision_mode(self.inner, mode.to_mnn_sys());
289 }
290 }
291
292 pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
294 self.set_precision_mode(mode);
295 self
296 }
297
298 pub fn get_precision_mode(&self) -> PrecisionMode {
300 unsafe { PrecisionMode::from_mnn_sys(mnn_sys::mnnbc_get_precision_mode(self.inner)) }
301 }
302
303 pub fn set_flags(&mut self, flags: usize) {
306 unsafe {
307 mnn_sys::mnnbc_set_flags(self.inner, flags);
308 }
309 }
310
311 pub fn with_flags(mut self, flags: usize) -> Self {
313 self.set_flags(flags);
314 self
315 }
316
317 pub unsafe fn set_shared_context(&mut self, shared_context: *mut libc::c_void) {
321 unsafe {
322 mnn_sys::mnnbc_set_shared_context(self.inner, shared_context);
323 }
324 }
325
326 pub unsafe fn with_shared_context(mut self, shared_context: *mut libc::c_void) -> Self {
330 unsafe {
331 self.set_shared_context(shared_context);
332 }
333 self
334 }
335}
336
#[test]
fn test_backend_config() {
    // First set every mode through the `&mut self` setters.
    let mut original = BackendConfig::new();
    original.set_memory_mode(MemoryMode::Low);
    original.set_power_mode(PowerMode::Low);
    original.set_precision_mode(PrecisionMode::Low);

    // A clone must carry every setting across. `black_box` keeps the
    // optimizer from eliding the FFI round-trip.
    let cloned = std::hint::black_box(original.clone());
    assert_eq!(cloned.get_memory_mode(), MemoryMode::Low);
    assert_eq!(cloned.get_power_mode(), PowerMode::Low);
    assert_eq!(cloned.get_precision_mode(), PrecisionMode::Low);

    // Then overwrite every mode through the consuming builder methods.
    let rebuilt = cloned
        .with_memory_mode(MemoryMode::Normal)
        .with_power_mode(PowerMode::Normal)
        .with_precision_mode(PrecisionMode::Normal);
    assert_eq!(rebuilt.get_memory_mode(), MemoryMode::Normal);
    assert_eq!(rebuilt.get_power_mode(), PowerMode::Normal);
    assert_eq!(rebuilt.get_precision_mode(), PrecisionMode::Normal);
}