mnn/tensor/raw.rs

1use crate::prelude::*;
2use core::marker::PhantomData;
3use mnn_sys::HalideType;
4/// A raw tensor type that doesn't have any guarantees
5/// and will be unconditionally dropped
6#[repr(transparent)]
7pub struct RawTensor<'r> {
8    pub(crate) inner: *mut mnn_sys::Tensor,
9    pub(crate) __marker: PhantomData<&'r ()>,
10}
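
// `Drop` is not implemented for `RawTensor` (the impl below is commented out), so a
// `RawTensor` never frees its tensor implicitly; owned tensors must be released with
// `RawTensor::destroy`.
// impl<'r> core::ops::Drop for RawTensor<'r> {
//     fn drop(&mut self) {
//         unsafe {
//             mnn_sys::Tensor_destroy(self.inner);
//         }
//     }
// }
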
20impl RawTensor<'_> {
21    /// Creates a new host tensor from the device tensor
22    pub fn create_host_tensor_from_device(&self, copy_data: bool) -> RawTensor<'static> {
23        let tensor =
24            unsafe { mnn_sys::Tensor_createHostTensorFromDevice(self.inner, copy_data as i32) };
25        // crate::ensure!(!tensor.is_null(), ErrorKind::TensorError);
26        assert!(!tensor.is_null());
27        RawTensor {
28            inner: tensor,
29            __marker: PhantomData,
30        }
31    }
32
33    /// Copies the data from a host tensor to the self tensor
34    pub fn copy_from_host_tensor(&mut self, tensor: &RawTensor) -> Result<()> {
35        let ret = unsafe { mnn_sys::Tensor_copyFromHostTensor(self.inner, tensor.inner) };
36        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
37        Ok(())
38    }
39
40    /// Copies the data from the self tensor to a host tensor
41    pub fn copy_to_host_tensor(&self, tensor: &mut RawTensor) -> Result<()> {
42        let ret = unsafe { mnn_sys::Tensor_copyToHostTensor(self.inner, tensor.inner) };
43        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
44        Ok(())
45    }
46
47    /// Returns the shape of the tensor
48    pub fn shape(&self) -> crate::TensorShape {
49        unsafe { mnn_sys::Tensor_shape(self.inner) }.into()
50    }
51
52    /// Returns the dimension type of the tensor
53    pub fn get_dimension_type(&self) -> super::DimensionType {
54        debug_assert!(!self.inner.is_null());
55        From::from(unsafe { mnn_sys::Tensor_getDimensionType(self.inner) })
56    }
57
58    /// Cleans up the tensor by calling the destructor of the tensor
59    pub fn destroy(self) {
60        unsafe {
61            mnn_sys::Tensor_destroy(self.inner);
62        }
63    }
64
65    /// Returns the size of the tensor when counted by bytes
66    pub fn size(&self) -> usize {
67        unsafe { mnn_sys::Tensor_usize(self.inner) }
68    }
69
70    /// Returns the size of the tensor when counted by elements
71    pub fn element_size(&self) -> usize {
72        unsafe { mnn_sys::Tensor_elementSize(self.inner) as usize }
73    }
74
75    /// Returns the number of dimensions of the tensor
76    pub fn dimensions(&self) -> usize {
77        unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize }
78    }
79
80    /// Returns the width of the tensor
81    pub fn width(&self) -> u32 {
82        unsafe { mnn_sys::Tensor_width(self.inner) as u32 }
83    }
84
85    /// Returns the height of the tensor
86    pub fn height(&self) -> u32 {
87        unsafe { mnn_sys::Tensor_height(self.inner) as u32 }
88    }
89
90    /// Returns the channel of the tensor
91    pub fn channel(&self) -> u32 {
92        unsafe { mnn_sys::Tensor_channel(self.inner) as u32 }
93    }
94
95    /// Returns true if the tensor is unsized and dynamic (needs to be resized to work)
96    pub fn is_dynamic_unsized(&self) -> bool {
97        self.shape().as_ref().contains(&-1)
98    }
99
100    /// Waits for the tensor to be ready
101    pub fn wait(&self, map_type: MapType, finish: bool) {
102        unsafe {
103            mnn_sys::Tensor_wait(self.inner, map_type, finish as i32);
104        }
105    }
106
107    /// # Safety
108    /// This is very unsafe do not use this unless you know what you are doing
109    /// Gives a raw pointer to the tensor's data
110    /// P.S. I don't know what I'm doing
111    pub unsafe fn unchecked_host_ptr(&self) -> *mut c_void {
112        debug_assert!(!self.inner.is_null());
113        let data = unsafe { mnn_sys::Tensor_host_mut(self.inner) };
114        debug_assert!(!data.is_null());
115        data
116    }
117
118    /// # Safety
119    /// This is very unsafe do not use this unless you know what you are doing
120    /// Gives a mutable byte slice to the tensor's data
121    pub unsafe fn unchecked_host_bytes(&mut self) -> &mut [u8] {
122        unsafe { core::slice::from_raw_parts_mut(self.unchecked_host_ptr().cast(), self.size()) }
123    }
124
125    /// # Safety
126    /// This is very unsafe do not use this unless you know what you are doing
127    pub unsafe fn to_concrete<T: super::TensorType>(self) -> super::Tensor<T>
128    where
129        T::H: HalideType,
130    {
131        unsafe { super::Tensor::from_ptr(self.inner) }
132    }
133
134    pub(crate) fn from_ptr(inner: *mut mnn_sys::Tensor) -> Self {
135        Self {
136            inner,
137            __marker: PhantomData,
138        }
139    }
140}