mnn/
error.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
use mnn_sys::ErrorCode;

// Crate-wide shorthand: the error type defaults to `MNNError`, but a different
// error type can still be named explicitly as the second parameter.
#[doc(hidden)]
pub type Result<T, E = MNNError> = std::result::Result<T, E>;

/// Error type container for MNN
pub struct MNNError {
    kind: error_stack::Report<ErrorKind>,
}

// Both `Display` and `Debug` render the wrapped report through its `Debug`
// formatting. The report is formatted with a fixed `{:?}` spec, so flags on the
// outer formatter (e.g. `{:#}` / `{:#?}`) are not forwarded to it.
impl core::fmt::Display for MNNError {
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let report = &self.kind;
        formatter.write_fmt(format_args!("{:?}", report))
    }
}

impl core::fmt::Debug for MNNError {
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let report = &self.kind;
        formatter.write_fmt(format_args!("{:?}", report))
    }
}

/// Lets `MNNError` be used wherever a `std::error::Error` is expected.
///
/// `source()` keeps its default implementation, so it returns `None`; the full
/// chain lives inside the wrapped report (see `MNNError::into_inner`).
impl std::error::Error for MNNError {}

/// Error types for MNN
#[derive(thiserror::Error, Debug)]
pub enum ErrorKind {
    /// Internal error (from MNN library)
    #[error("Internal error: {0:?}")]
    InternalError(ErrorCode),
    /// Mismatching Size for input
    #[error("Invalid input: expected {expected}, got {got}")]
    SizeMismatch {
        /// Expected size
        expected: usize,
        /// Provided size
        got: usize,
    },
    /// Failed to copy tensor
    #[error("Failed to copy tensor")]
    TensorCopyFailed(i32),
    /// I/O Error
    #[error("IO Error")]
    IOError,
    /// Interpreter Error
    #[error("Interpreter Error")]
    InterpreterError,
    /// ASCII Error (path, name, etc had invalid characters)
    #[error("Ascii Error")]
    AsciiError,
    /// HalideType mismatch (e.g. trying to convert from a float tensor to an int tensor)
    #[error("HalideType mismatch: got {got}")]
    HalideTypeMismatch {
        /// HalideType that was
        got: &'static str,
    },
    /// Failed to parse the Argument
    #[error("Parse Error")]
    ParseError,
    /// Error with mnn-sync crate
    #[error("Sync Error")]
    SyncError,
    /// Error with some tensor
    #[error("Tensor Error")]
    TensorError,
    /// Tried to run a dynamic tensor without resizing it first
    #[error("Dynamic Tensor Error: Tensor needs to be resized before using")]
    DynamicTensorError,
}

impl MNNError {
    /// Create an error of the given [`ErrorKind`], starting a new report.
    ///
    /// Marked `#[track_caller]` so that the caller-tracking
    /// `error_stack::Report::new` observes the location of whoever called this
    /// function rather than this line.
    #[track_caller]
    #[doc(hidden)]
    pub fn new(kind: ErrorKind) -> Self {
        let kind = error_stack::Report::new(kind);
        Self { kind }
    }

    /// Create an [`ErrorKind::InternalError`] from an MNN [`ErrorCode`].
    #[track_caller]
    pub(crate) fn from_error_code(code: ErrorCode) -> Self {
        Self::new(ErrorKind::InternalError(code))
    }

    /// Return the inner [`error_stack::Report`] containing the error.
    ///
    /// Consumes `self`; the report is returned unchanged.
    // A plain `#[inline]` is enough for cross-crate inlining of this trivial
    // accessor; forcing it with `inline(always)` buys nothing.
    #[inline]
    pub fn into_inner(self) -> error_stack::Report<ErrorKind> {
        self.kind
    }
}

impl From<ErrorKind> for MNNError {
    #[track_caller]
    fn from(kind: ErrorKind) -> Self {
        Self::new(kind)
    }
}

/// Early-return an error from the enclosing function when `$cond` is false.
///
/// Forms:
/// - `ensure!(cond, kind)` — returns `Err(MNNError::new(kind))`.
/// - `ensure!(cond, kind; a, b, ..)` — same, with each expression attached via
///   `attach_printable`.
/// - `ensure!(cond, from, to)` — builds a report from `from`, changes its context
///   to `to`, and converts it into the function's error type with `Into`
///   (an identity conversion when the function already returns the report type,
///   `MNNError` via `From<Report<ErrorKind>>` when it returns the crate `Result`).
/// - `ensure!(cond, from, to; a, b, ..)` — same, with printable attachments.
macro_rules! ensure {
    ($cond:expr, $kind:expr) => {
        if !($cond) {
            return Err($crate::error::MNNError::new($kind));
        }
    };

    ($cond:expr, $kind:expr; $($printable:expr),* $(,)?) => {
        if !($cond) {
            return Err($crate::error::MNNError::new($kind)
                $(.attach_printable($printable))*
            );
        }
    };

    ($cond:expr, $from:expr, $to:expr) => {
        if !($cond) {
            return Err(::error_stack::Report::new($from)
                .change_context($to)
                .into());
        }
    };

    ($cond:expr, $from:expr, $to:expr; $($printable:expr),* $(,)?) => {
        if !($cond) {
            return Err(::error_stack::Report::new($from)
                .change_context($to)
                $(.attach_printable($printable))*
                .into());
        }
    };
}

/// Build (but do not return) an `MNNError`.
///
/// - `error!(kind, from)` — creates a report from `from`, then switches its
///   context to `kind`.
/// - `error!(kind)` — wraps `kind` in a new report directly.
macro_rules! error {
    ($kind:expr, $from:expr) => {{
        let changed = ::error_stack::Report::new($from).change_context($kind);
        $crate::error::MNNError::from(changed)
    }};
    ($kind:expr) => {
        $crate::error::MNNError::new($kind)
    };
}

pub(crate) use ensure;
pub(crate) use error;

impl From<error_stack::Report<ErrorKind>> for MNNError {
    #[track_caller]
    fn from(report: error_stack::Report<ErrorKind>) -> Self {
        Self { kind: report }
    }
}

impl MNNError {
    /// Attach a printable value to the wrapped report and return the updated error.
    ///
    /// `#[track_caller]` matches the other constructors in this module. It also
    /// forwards the caller's location into the caller-tracking
    /// `error_stack::Report::attach_printable`. Without it, every attachment would
    /// be attributed to this wrapper instead of the call site (e.g. an `ensure!`
    /// invocation).
    #[track_caller]
    pub(crate) fn attach_printable(
        self,
        printable: impl core::fmt::Display + core::fmt::Debug + Send + Sync + 'static,
    ) -> Self {
        let kind = self.kind.attach_printable(printable);
        Self { kind }
    }
}