mnn/
tensor.rs

1use crate::prelude::*;
2use core::marker::PhantomData;
3use mnn_sys::*;
4pub(crate) mod list;
5mod raw;
6pub use raw::RawTensor;
7
8use mnn_sys::HalideType;
9
10mod seal {
11    pub trait Sealed {}
12}
13macro_rules! seal {
14        ($($name:ty),*) => {
15            $(
16                impl<T> seal::Sealed for $name {}
17            )*
18        };
19    }
20seal!(Host<T>, Device<T>, Ref<'_, T>, RefMut<'_, T>);
21
22/// A trait to represent the type of a tensor
23pub trait TensorType: seal::Sealed {
24    /// The halide type of the tensor
25    type H;
26    /// Check if the tensor is owned
27    fn owned() -> bool;
28    /// Check if the tensor is borrowed
29    fn borrowed() -> bool {
30        !Self::owned()
31    }
32    /// Check if the tensor is allocated in the host
33    fn host() -> bool;
34    /// Check if the tensor is allocated in the device
35    fn device() -> bool {
36        !Self::host()
37    }
38}
39/// A tensor that is owned
40pub trait OwnedTensorType: TensorType {}
41/// A tensor that is borrowed
42pub trait RefTensorType: TensorType {}
43/// A tensor that is allocated in the cpu / host platform
44pub trait HostTensorType: TensorType {}
45/// A tensor that is allocated in the device / gpu platform
46pub trait DeviceTensorType: TensorType {}
47/// A tensor that is mutable
48pub trait MutableTensorType: TensorType {}
49
50impl<H: HalideType> TensorType for Host<H> {
51    type H = H;
52    fn owned() -> bool {
53        true
54    }
55    fn host() -> bool {
56        true
57    }
58}
59impl<H: HalideType> TensorType for Device<H> {
60    type H = H;
61    fn owned() -> bool {
62        true
63    }
64    fn host() -> bool {
65        false
66    }
67}
68
69impl<T: TensorType> TensorType for Ref<'_, T> {
70    type H = T::H;
71    fn owned() -> bool {
72        false
73    }
74    fn host() -> bool {
75        T::host()
76    }
77}
78
79impl<T: TensorType> TensorType for RefMut<'_, T> {
80    type H = T::H;
81    fn owned() -> bool {
82        false
83    }
84    fn host() -> bool {
85        T::host()
86    }
87}
88
89impl<H: HalideType> DeviceTensorType for Device<H> {}
90impl<H: HalideType> HostTensorType for Host<H> {}
91impl<H: HalideType> OwnedTensorType for Device<H> {}
92impl<H: HalideType> OwnedTensorType for Host<H> {}
93impl<T: DeviceTensorType> DeviceTensorType for Ref<'_, T> {}
94impl<T: DeviceTensorType> DeviceTensorType for RefMut<'_, T> {}
95impl<T: HostTensorType> HostTensorType for Ref<'_, T> {}
96impl<T: HostTensorType> HostTensorType for RefMut<'_, T> {}
97impl<T: OwnedTensorType> MutableTensorType for T {}
98impl<T: TensorType> MutableTensorType for RefMut<'_, T> {}
99impl<T: TensorType> RefTensorType for Ref<'_, T> {}
100impl<T: TensorType> RefTensorType for RefMut<'_, T> {}
101
/// Marker type for a tensor that is owned by the cpu / host platform, with element type `T`
/// (defaults to `f32`)
pub struct Host<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}
/// Marker type for a tensor that is owned by the device / gpu platform, with element type `T`
/// (defaults to `f32`)
pub struct Device<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}
/// Marker type for a shared borrow of any tensor type `T`, valid for the lifetime `'t`
pub struct Ref<'t, T> {
    pub(crate) __marker: PhantomData<&'t [T]>,
}

/// Marker type for a mutable borrow of any tensor type `T`, valid for the lifetime `'t`
pub struct RefMut<'t, T> {
    pub(crate) __marker: PhantomData<&'t mut [T]>,
}

/// A generic tensor that can be host / device / owned / borrowed, as selected by `T`
///
/// Wraps a raw MNN tensor pointer. The pointer is destroyed on drop only when `T` is an owned
/// tensor type ([`Host`] or [`Device`]).
pub struct Tensor<T: TensorType> {
    // Raw pointer to the underlying MNN tensor
    pub(crate) tensor: *mut mnn_sys::Tensor,
    // Zero-sized marker recording the tensor type `T`
    __marker: PhantomData<T>,
}
125
126impl<T: TensorType> Drop for Tensor<T> {
127    fn drop(&mut self) {
128        if T::owned() {
129            unsafe {
130                mnn_sys::Tensor_destroy(self.tensor);
131            }
132        }
133    }
134}
135
136impl<H: HalideType> Tensor<Host<H>> {
137    /// Get's a reference to an owned host tensor
138    pub fn as_ref(&self) -> Tensor<Ref<'_, Host<H>>> {
139        Tensor {
140            tensor: self.tensor,
141            __marker: PhantomData,
142        }
143    }
144}
145
146impl<H: HalideType> Tensor<Device<H>> {
147    /// Get's a reference to an owned device tensor
148    pub fn as_ref(&self) -> Tensor<Ref<'_, Device<H>>> {
149        Tensor {
150            tensor: self.tensor,
151            __marker: PhantomData,
152        }
153    }
154}
155
/// The dimension layout of a tensor, i.e. the order in which its dimensions are stored
///
/// If you are specifying the shapes manually this mostly doesn't matter. The letters stand for:
/// - `N`: batch size
/// - `C`: channel
/// - `H`: height
/// - `W`: width
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DimensionType {
    /// Caffe style dimensions or NCHW
    Caffe,
    /// Caffe style dimensions with the channel axis packed in groups of 4, or NC4HW4
    CaffeC4,
    /// Tensorflow style dimensions or NHWC
    TensorFlow,
}
171
172impl DimensionType {
173    /// Tensorflow style dimensions or NHWC
174    pub const NHWC: Self = Self::TensorFlow;
175    /// Caffe style dimensions or NCHW
176    pub const NCHW: Self = Self::Caffe;
177    /// Caffe style dimensions with channel packed in 4 bytes or NC4HW4
178    pub const NC4HW4: Self = Self::CaffeC4;
179    pub(crate) fn to_mnn_sys(self) -> mnn_sys::DimensionType {
180        match self {
181            DimensionType::Caffe => mnn_sys::DimensionType::CAFFE,
182            DimensionType::CaffeC4 => mnn_sys::DimensionType::CAFFE_C4,
183            DimensionType::TensorFlow => mnn_sys::DimensionType::TENSORFLOW,
184        }
185    }
186}
187
188impl From<mnn_sys::DimensionType> for DimensionType {
189    fn from(dm: mnn_sys::DimensionType) -> Self {
190        match dm {
191            mnn_sys::DimensionType::CAFFE => DimensionType::Caffe,
192            mnn_sys::DimensionType::CAFFE_C4 => DimensionType::CaffeC4,
193            mnn_sys::DimensionType::TENSORFLOW => DimensionType::TensorFlow,
194        }
195    }
196}
197
198impl<T: TensorType> Tensor<T>
199where
200    T::H: HalideType,
201{
202    /// This function constructs a Tensor type from a raw pointer
203    ///# Safety
204    /// Since this constructs a Tensor from a raw pointer we have no way to guarantee that it's a
205    /// valid tensor or it's lifetime
206    pub unsafe fn from_ptr(tensor: *mut mnn_sys::Tensor) -> Self {
207        assert!(!tensor.is_null());
208        Self {
209            tensor,
210            __marker: PhantomData,
211        }
212    }
213    /// Copies the data from a host tensor to the self tensor
214    pub fn copy_from_host_tensor(&mut self, tensor: &Tensor<Host<T::H>>) -> Result<()> {
215        let ret = unsafe { Tensor_copyFromHostTensor(self.tensor, tensor.tensor) };
216        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
217        Ok(())
218    }
219
220    /// Copies the data from the self tensor to a host tensor
221    pub fn copy_to_host_tensor(&self, tensor: &mut Tensor<Host<T::H>>) -> Result<()> {
222        let ret = unsafe { Tensor_copyToHostTensor(self.tensor, tensor.tensor) };
223        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
224        Ok(())
225    }
226
227    /// Get the device id of the tensor
228    pub fn device_id(&self) -> u64 {
229        unsafe { Tensor_deviceId(self.tensor) }
230    }
231
232    /// Get the shape of the tensor
233    pub fn shape(&self) -> TensorShape {
234        unsafe { Tensor_shape(self.tensor) }.into()
235    }
236
237    /// Get the dimensions of the tensor
238    pub fn dimensions(&self) -> usize {
239        unsafe { Tensor_dimensions(self.tensor) as usize }
240    }
241
242    /// Get the width of the tensor
243    pub fn width(&self) -> u32 {
244        unsafe { Tensor_width(self.tensor) as u32 }
245    }
246
247    /// Get the height of the tensor
248    pub fn height(&self) -> u32 {
249        unsafe { Tensor_height(self.tensor) as u32 }
250    }
251
252    /// Get the channel size of the tensor
253    pub fn channel(&self) -> u32 {
254        unsafe { Tensor_channel(self.tensor) as u32 }
255    }
256
257    /// Get the batch size of the tensor
258    pub fn batch(&self) -> u32 {
259        unsafe { Tensor_batch(self.tensor) as u32 }
260    }
261
262    /// Get the size of the tensor when counted by bytes
263    pub fn size(&self) -> usize {
264        unsafe { Tensor_usize(self.tensor) }
265    }
266
267    /// Get the size of the tensor when counted by elements
268    pub fn element_size(&self) -> usize {
269        unsafe { Tensor_elementSize(self.tensor) as usize }
270    }
271
272    /// Print the shape of the tensor
273    pub fn print_shape(&self) {
274        unsafe {
275            Tensor_printShape(self.tensor);
276        }
277    }
278
279    /// Print the tensor
280    pub fn print(&self) {
281        unsafe {
282            Tensor_print(self.tensor);
283        }
284    }
285
286    /// Check if the tensor is dynamic and needs resizing
287    pub fn is_dynamic_unsized(&self) -> bool {
288        self.shape().as_ref().contains(&-1)
289    }
290
291    /// DO not use this function directly
292    /// # Safety
293    /// This is just provided as a 1:1 compat mostly for possible later use
294    pub unsafe fn halide_buffer(&self) -> *const halide_buffer_t {
295        unsafe { Tensor_buffer(self.tensor) }
296    }
297
298    /// Do not use this function directly
299    /// # Safety
300    /// This is just provided as a 1:1 compat mostly for possible later use
301    pub unsafe fn halide_buffer_mut(&self) -> *mut halide_buffer_t {
302        unsafe { Tensor_buffer_mut(self.tensor) }
303    }
304
305    /// Get the dimension type of the tensor
306    pub fn get_dimension_type(&self) -> DimensionType {
307        debug_assert!(!self.tensor.is_null());
308        From::from(unsafe { Tensor_getDimensionType(self.tensor) })
309    }
310
311    /// Get the data type of the tensor
312    pub fn get_type(&self) -> mnn_sys::halide_type_t {
313        unsafe { Tensor_getType(self.tensor) }
314    }
315
316    /// Check if the tensor is of the specified data type
317    pub fn is_type_of<H: HalideType>(&self) -> bool {
318        let htc = halide_type_of::<H>();
319        unsafe { Tensor_isTypeOf(self.tensor, htc) }
320    }
321
322    /// # Safety
323    /// This is very unsafe do not use this unless you know what you are doing
324    pub unsafe fn into_raw(self) -> RawTensor<'static> {
325        let out = RawTensor {
326            inner: self.tensor,
327            __marker: PhantomData,
328        };
329        core::mem::forget(self);
330        out
331    }
332}
333impl<T: MutableTensorType> Tensor<T>
334where
335    T::H: HalideType,
336{
337    /// Fill the tensor with the specified value
338    pub fn fill(&mut self, value: T::H)
339    where
340        T::H: Copy,
341    {
342        if T::host() {
343            let size = self.element_size();
344            assert!(self.is_type_of::<T::H>());
345            let result: &mut [T::H] = unsafe {
346                let data = mnn_sys::Tensor_host_mut(self.tensor).cast();
347                core::slice::from_raw_parts_mut(data, size)
348            };
349            result.fill(value);
350        } else if T::device() {
351            let shape = self.shape();
352            let dm_type = self.get_dimension_type();
353            let mut host = Tensor::new(shape, dm_type);
354            host.fill(value);
355            self.copy_from_host_tensor(&host)
356                .expect("Failed to copy data from host tensor");
357        } else {
358            unreachable!()
359        }
360    }
361}
362
363impl<T: HostTensorType> Tensor<T>
364where
365    T::H: HalideType,
366{
367    /// Try to map the device tensor to the host memory and get the slice
368    pub fn try_host(&self) -> Result<&[T::H]> {
369        let size = self.element_size();
370        ensure!(
371            self.is_type_of::<T::H>(),
372            ErrorKind::HalideTypeMismatch {
373                got: std::any::type_name::<T::H>(),
374            }
375        );
376        let result = unsafe {
377            let data = mnn_sys::Tensor_host(self.tensor).cast();
378            core::slice::from_raw_parts(data, size)
379        };
380        Ok(result)
381    }
382
383    /// Try to map the device tensor to the host memory and get the mutable slice
384    pub fn try_host_mut(&mut self) -> Result<&mut [T::H]> {
385        let size = self.element_size();
386        ensure!(
387            self.is_type_of::<T::H>(),
388            ErrorKind::HalideTypeMismatch {
389                got: std::any::type_name::<T::H>(),
390            }
391        );
392
393        let result = unsafe {
394            let data: *mut T::H = mnn_sys::Tensor_host_mut(self.tensor).cast();
395            debug_assert!(!data.is_null());
396            core::slice::from_raw_parts_mut(data, size)
397        };
398        Ok(result)
399    }
400
401    /// Get the host memory slice of the tensor
402    pub fn host(&self) -> &[T::H] {
403        self.try_host().expect("Failed to get tensor host")
404    }
405
406    /// Get the mutable host memory slice of the tensor
407    pub fn host_mut(&mut self) -> &mut [T::H] {
408        self.try_host_mut().expect("Failed to get tensor host_mut")
409    }
410}
411
412impl<T: DeviceTensorType> Tensor<T>
413where
414    T::H: HalideType,
415{
416    /// Try to wait for the device tensor to finish processing
417    pub fn wait(&self, map_type: MapType, finish: bool) {
418        unsafe {
419            Tensor_wait(self.tensor, map_type, finish as i32);
420        }
421    }
422
423    /// Create a host tensor from the device tensor with same dimensions and data type and
424    /// optionally copy the data from the device tensor
425    pub fn create_host_tensor_from_device(&self, copy_data: bool) -> Tensor<Host<T::H>> {
426        let shape = self.shape();
427        let dm_type = self.get_dimension_type();
428        let mut out = Tensor::new(shape, dm_type);
429
430        if copy_data {
431            self.copy_to_host_tensor(&mut out)
432                .expect("Failed to copy data from device tensor");
433        }
434        out
435    }
436}
437
438impl<T: OwnedTensorType> Tensor<T>
439where
440    T::H: HalideType,
441{
442    /// Create a new tensor with the specified shape and dimension type
443    pub fn new(shape: impl AsTensorShape, dm_type: DimensionType) -> Self {
444        let shape = shape.as_tensor_shape();
445        let tensor = unsafe {
446            if T::device() {
447                Tensor_createDevice(
448                    shape.shape.as_ptr(),
449                    shape.size,
450                    halide_type_of::<T::H>(),
451                    dm_type.to_mnn_sys(),
452                )
453            } else {
454                Tensor_createWith(
455                    shape.shape.as_ptr(),
456                    shape.size,
457                    halide_type_of::<T::H>(),
458                    core::ptr::null_mut(),
459                    dm_type.to_mnn_sys(),
460                )
461            }
462        };
463        debug_assert!(!tensor.is_null());
464        Self {
465            tensor,
466            __marker: PhantomData,
467        }
468    }
469}
470
471impl<T: OwnedTensorType> Clone for Tensor<T>
472where
473    T::H: HalideType,
474{
475    fn clone(&self) -> Tensor<T> {
476        let tensor_ptr = unsafe { Tensor_clone(self.tensor) };
477        Self {
478            tensor: tensor_ptr,
479            __marker: PhantomData,
480        }
481    }
482}
483
/// A tensor shape with up to 4 dimensions stored inline
///
/// Only the first `size` entries of `shape` are meaningful; `Deref`, `Index` and `Debug` expose
/// exactly those entries.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct TensorShape {
    // Dimension values; entries past `size` are padding
    pub(crate) shape: [i32; 4],
    // Number of dimensions in use (expected to be at most 4)
    pub(crate) size: usize,
}
491
492impl From<mnn_sys::TensorShape> for TensorShape {
493    fn from(value: mnn_sys::TensorShape) -> Self {
494        Self {
495            shape: value.shape,
496            size: value.size,
497        }
498    }
499}
500
501impl From<TensorShape> for mnn_sys::TensorShape {
502    fn from(value: TensorShape) -> Self {
503        Self {
504            shape: value.shape,
505            size: value.size,
506        }
507    }
508}
509
510impl core::ops::Deref for TensorShape {
511    type Target = [i32];
512
513    fn deref(&self) -> &Self::Target {
514        &self.shape[..self.size]
515    }
516}
517
518impl core::ops::Index<usize> for TensorShape {
519    type Output = i32;
520
521    fn index(&self, index: usize) -> &Self::Output {
522        &self.shape[..self.size][index]
523    }
524}
525
526impl core::ops::IndexMut<usize> for TensorShape {
527    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
528        &mut self.shape[..self.size][index]
529    }
530}
531
532impl core::ops::DerefMut for TensorShape {
533    fn deref_mut(&mut self) -> &mut Self::Target {
534        &mut self.shape[..self.size]
535    }
536}
537
538impl core::fmt::Debug for TensorShape {
539    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
540        write!(f, "{:?}", &self.shape[..self.size])
541    }
542}
543
544/// A trait to convert any array-like type to a tensor shape
545pub trait AsTensorShape {
546    /// Convert the array-like type to a tensor shape
547    fn as_tensor_shape(&self) -> TensorShape;
548}
549
550impl<T: AsRef<[i32]>> AsTensorShape for T {
551    fn as_tensor_shape(&self) -> TensorShape {
552        let this = self.as_ref();
553        let size = std::cmp::min(this.len(), 4);
554        let mut shape = [1; 4];
555        shape[..size].copy_from_slice(&this[..size]);
556        TensorShape { shape, size }
557    }
558}
559
560impl AsTensorShape for TensorShape {
561    fn as_tensor_shape(&self) -> TensorShape {
562        *self
563    }
564}
565
#[cfg(test)]
mod as_tensor_shape_tests {
    use super::AsTensorShape;

    /// Converts `$value`, logs what is being tested, and checks that the stored dimensions equal
    /// `$dims` and that the unused slots are padded with `1`.
    macro_rules! shape_test {
        ($t:ty, $kind:expr, $value:expr, $dims:expr) => {{
            eprintln!("Testing {} with {} shape", stringify!($t), $kind);
            let shape = $value.as_tensor_shape();
            let expected: &[i32] = &$dims;
            assert_eq!(shape.size, expected.len());
            assert_eq!(&shape.shape[..shape.size], expected);
            assert!(shape.shape[shape.size..].iter().all(|&d| d == 1));
        }};
    }
    #[test]
    fn as_tensor_shape_test_vec() {
        shape_test!(Vec<i32>, "small", vec![1, 2, 3], [1, 2, 3]);
        // Shapes longer than 4 dimensions are truncated to the first 4.
        shape_test!(Vec<i32>, "large", vec![12, 23, 34, 45, 67], [12, 23, 34, 45]);
    }
    #[test]
    fn as_tensor_shape_test_array() {
        shape_test!([i32; 3], "small", [1, 2, 3], [1, 2, 3]);
        shape_test!([i32; 5], "large", [12, 23, 34, 45, 67], [12, 23, 34, 45]);
    }
    #[test]
    fn as_tensor_shape_test_ref() {
        shape_test!(&[i32], "small", &[1, 2, 3], [1, 2, 3]);
        shape_test!(&[i32], "large", &[12, 23, 34, 45, 67], [12, 23, 34, 45]);
    }
}
591
#[cfg(test)]
mod tensor_tests {
    use super::{Host, Tensor};

    /// `from_ptr` must reject a null pointer by panicking.
    #[test]
    #[should_panic]
    fn unsafe_nullptr_tensor() {
        let null = core::ptr::null_mut();
        // SAFETY: deliberately invalid input; the null check inside `from_ptr` must fire first.
        let _tensor = unsafe { Tensor::<Host<i32>>::from_ptr(null) };
    }
}
602
603impl<T: HostTensorType + RefTensorType> Tensor<T>
604where
605    T::H: HalideType,
606{
607    /// Try to create a ref tensor from any array-like type
608    pub fn borrowed(shape: impl AsTensorShape, input: impl AsRef<[T::H]>) -> Self {
609        let shape = shape.as_tensor_shape();
610        let input = input.as_ref();
611        let tensor = unsafe {
612            Tensor_createWith(
613                shape.shape.as_ptr(),
614                shape.size,
615                halide_type_of::<T::H>(),
616                input.as_ptr().cast_mut().cast(),
617                DimensionType::Caffe.to_mnn_sys(),
618            )
619        };
620        debug_assert!(!tensor.is_null());
621        Self {
622            tensor,
623            __marker: PhantomData,
624        }
625    }
626
627    /// Try to create a mutable ref tensor from any array-like type
628    pub fn borrowed_mut(shape: impl AsTensorShape, mut input: impl AsMut<[T::H]>) -> Self {
629        let shape = shape.as_tensor_shape();
630        let input = input.as_mut();
631        let tensor = unsafe {
632            Tensor_createWith(
633                shape.shape.as_ptr(),
634                shape.size,
635                halide_type_of::<T::H>(),
636                input.as_mut_ptr().cast(),
637                DimensionType::Caffe.to_mnn_sys(),
638            )
639        };
640        debug_assert!(!tensor.is_null());
641        Self {
642            tensor,
643            __marker: PhantomData,
644        }
645    }
646}
647
#[test]
fn test_tensor_borrowed() {
    // A 1x2x3 tensor viewing six consecutive integers.
    let dims = [1, 2, 3];
    let values: Vec<i32> = (1..=6).collect();
    let tensor = Tensor::<Ref<Host<i32>>>::borrowed(&dims, &values);
    assert_eq!(tensor.shape().as_ref(), dims);
    assert_eq!(tensor.host(), values.as_slice());
}
656
#[test]
fn test_tensor_borrow_mut() {
    // Writes through the tensor must land in the borrowed buffer.
    let dims = [1, 2, 3];
    let mut values = vec![1, 2, 3, 4, 5, 6];
    let mut tensor = Tensor::<RefMut<Host<i32>>>::borrowed_mut(&dims, &mut values);
    for element in tensor.host_mut() {
        *element = 1;
    }
    assert_eq!(values, [1; 6]);
}