mnn/
tensor.rs

1use crate::prelude::*;
2use core::marker::PhantomData;
3use mnn_sys::*;
4use std::borrow::Borrow;
5pub(crate) mod list;
6mod raw;
7pub use raw::RawTensor;
8
9use mnn_sys::HalideType;
10
/// Private home of the supertrait that seals `TensorType`.
///
/// The module is private, so code outside this module cannot name `seal::Sealed`.
mod seal {
    /// Marker supertrait; implemented only for the tensor-kind markers of this module.
    pub trait Sealed {}
}
14macro_rules! seal {
15        ($($name:ty),*) => {
16            $(
17                impl<T> seal::Sealed for $name {}
18            )*
19        };
20    }
21seal!(Host<T>, Device<T>, Ref<'_, T>, RefMut<'_, T>);
22
23/// A trait to represent the type of a tensor
24pub trait TensorType: seal::Sealed {
25    /// The halide type of the tensor
26    type H;
27    /// Check if the tensor is owned
28    fn owned() -> bool;
29    /// Check if the tensor is borrowed
30    fn borrowed() -> bool {
31        !Self::owned()
32    }
33    /// Check if the tensor is allocated in the host
34    fn host() -> bool;
35    /// Check if the tensor is allocated in the device
36    fn device() -> bool {
37        !Self::host()
38    }
39}
40/// A tensor that is owned
41pub trait OwnedTensorType: TensorType {}
42/// A tensor that is borrowed
43pub trait RefTensorType: TensorType {}
44/// A tensor that is allocated in the cpu / host platform
45pub trait HostTensorType: TensorType {}
46/// A tensor that is allocated in the device / gpu platform
47pub trait DeviceTensorType: TensorType {}
48/// A tensor that is mutable
49pub trait MutableTensorType: TensorType {}
50
51impl<H: HalideType> TensorType for Host<H> {
52    type H = H;
53    fn owned() -> bool {
54        true
55    }
56    fn host() -> bool {
57        true
58    }
59}
60impl<H: HalideType> TensorType for Device<H> {
61    type H = H;
62    fn owned() -> bool {
63        true
64    }
65    fn host() -> bool {
66        false
67    }
68}
69
70impl<T: TensorType> TensorType for Ref<'_, T> {
71    type H = T::H;
72    fn owned() -> bool {
73        false
74    }
75    fn host() -> bool {
76        T::host()
77    }
78}
79
80impl<T: TensorType> TensorType for RefMut<'_, T> {
81    type H = T::H;
82    fn owned() -> bool {
83        false
84    }
85    fn host() -> bool {
86        T::host()
87    }
88}
89
90impl<H: HalideType> DeviceTensorType for Device<H> {}
91impl<H: HalideType> HostTensorType for Host<H> {}
92impl<H: HalideType> OwnedTensorType for Device<H> {}
93impl<H: HalideType> OwnedTensorType for Host<H> {}
94impl<T: DeviceTensorType> DeviceTensorType for Ref<'_, T> {}
95impl<T: DeviceTensorType> DeviceTensorType for RefMut<'_, T> {}
96impl<T: HostTensorType> HostTensorType for Ref<'_, T> {}
97impl<T: HostTensorType> HostTensorType for RefMut<'_, T> {}
98impl<T: OwnedTensorType> MutableTensorType for T {}
99impl<T: TensorType> MutableTensorType for RefMut<'_, T> {}
100impl<T: TensorType> RefTensorType for Ref<'_, T> {}
101impl<T: TensorType> RefTensorType for RefMut<'_, T> {}
102
/// Marker for a tensor that owns storage in CPU / host memory.
///
/// `T` is the element type and defaults to `f32`.
pub struct Host<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}

/// Marker for a tensor that owns storage on a device / GPU.
///
/// `T` is the element type and defaults to `f32`.
pub struct Device<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}

/// Marker for a shared borrow, valid for `'t`, of another tensor kind `T`.
pub struct Ref<'t, T> {
    pub(crate) __marker: PhantomData<&'t [T]>,
}

/// Marker for an exclusive (mutable) borrow, valid for `'t`, of another tensor kind `T`.
pub struct RefMut<'t, T> {
    pub(crate) __marker: PhantomData<&'t mut [T]>,
}
120
121/// A generic tensor that can of host / device / owned / borrowed
122pub struct Tensor<T: TensorType> {
123    pub(crate) tensor: *mut mnn_sys::Tensor,
124    __marker: PhantomData<T>,
125}
126
127impl<T: TensorType> Drop for Tensor<T> {
128    fn drop(&mut self) {
129        if T::owned() {
130            unsafe {
131                mnn_sys::Tensor_destroy(self.tensor);
132            }
133        }
134    }
135}
136
137impl<H: HalideType> Tensor<Host<H>> {
138    /// Get's a reference to an owned host tensor
139    pub fn as_ref(&self) -> Tensor<Ref<'_, Host<H>>> {
140        Tensor {
141            tensor: self.tensor,
142            __marker: PhantomData,
143        }
144    }
145}
146
147impl<H: HalideType> Tensor<Device<H>> {
148    /// Get's a reference to an owned device tensor
149    pub fn as_ref(&self) -> Tensor<Ref<'_, Device<H>>> {
150        Tensor {
151            tensor: self.tensor,
152            __marker: PhantomData,
153        }
154    }
155}
156
157impl<H: HalideType, T: TensorType<H = H>> ToOwned for Tensor<Ref<'_, T>> {
158    type Owned = Tensor<T>;
159
160    fn to_owned(&self) -> Self::Owned {
161        let tensor_ptr = unsafe { Tensor_clone(self.tensor) };
162        Tensor {
163            tensor: tensor_ptr,
164            __marker: PhantomData,
165        }
166    }
167}
168
169impl<'t, H: HalideType, T: TensorType<H = H>> Borrow<Tensor<Ref<'t, T>>> for Tensor<T> {
170    fn borrow(&self) -> &Tensor<Ref<'t, T>> {
171        unsafe { &*(self as *const Tensor<T> as *const Tensor<Ref<'t, T>>) }
172    }
173}
174
175impl<'t, H: HalideType, T: TensorType<H = H>> Borrow<Tensor<T>> for Tensor<Ref<'t, T>> {
176    fn borrow(&self) -> &'t Tensor<T> {
177        unsafe { &*(self as *const Tensor<Ref<'_, T>> as *const Tensor<T>) }
178    }
179}
180
181// impl<'h, H: HalideType, O> AsMut<Tensor<RefMut<'h, O>>> for Tensor<O>
182// where
183//     H: HalideType,
184//     O: OwnedTensorType + TensorType<H = H>,
185// {
186//     fn as_mut(&mut self) -> &mut Tensor<RefMut<'h, O>> {
187//
188//     }
189// }
190
191// impl<H: HalideType, O: OwnedTensorType> AsRef<Tensor<Ref<'_, O>>> for Tensor<O> {
192//     fn as_ref(&self) -> &Tensor<Ref<O>> {
193//         unsafe {
194//             // SAFETY: The tensor is guaranteed to be valid and immutable
195//             &*(self as *const Self as *const Tensor<Ref<O>>)
196//         }
197//     }
198// }
199
/// Memory layout (dimension order) of a tensor.
///
/// If you specify the shapes yourself this mostly does not matter. The letters stand for:
///
/// - `N`: batch size
/// - `C`: channels
/// - `H`: height
/// - `W`: width
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DimensionType {
    /// Caffe-style `NCHW` layout.
    Caffe,
    /// Caffe-style layout with channels packed in blocks of four (`NC4HW4`).
    CaffeC4,
    /// TensorFlow-style `NHWC` layout.
    TensorFlow,
}
215
216impl DimensionType {
217    /// Tensorflow style dimensions or NHWC
218    pub const NHWC: Self = Self::TensorFlow;
219    /// Caffe style dimensions or NCHW
220    pub const NCHW: Self = Self::Caffe;
221    /// Caffe style dimensions with channel packed in 4 bytes or NC4HW4
222    pub const NC4HW4: Self = Self::CaffeC4;
223    pub(crate) fn to_mnn_sys(self) -> mnn_sys::DimensionType {
224        match self {
225            DimensionType::Caffe => mnn_sys::DimensionType::CAFFE,
226            DimensionType::CaffeC4 => mnn_sys::DimensionType::CAFFE_C4,
227            DimensionType::TensorFlow => mnn_sys::DimensionType::TENSORFLOW,
228        }
229    }
230}
231
232impl From<mnn_sys::DimensionType> for DimensionType {
233    fn from(dm: mnn_sys::DimensionType) -> Self {
234        match dm {
235            mnn_sys::DimensionType::CAFFE => DimensionType::Caffe,
236            mnn_sys::DimensionType::CAFFE_C4 => DimensionType::CaffeC4,
237            mnn_sys::DimensionType::TENSORFLOW => DimensionType::TensorFlow,
238        }
239    }
240}
241
242impl<T: TensorType> Tensor<T>
243where
244    T::H: HalideType,
245{
246    /// This function constructs a Tensor type from a raw pointer
247    ///# Safety
248    /// Since this constructs a Tensor from a raw pointer we have no way to guarantee that it's a
249    /// valid tensor or it's lifetime
250    pub unsafe fn from_ptr(tensor: *mut mnn_sys::Tensor) -> Self {
251        assert!(!tensor.is_null());
252        Self {
253            tensor,
254            __marker: PhantomData,
255        }
256    }
257    /// Copies the data from a host tensor to the self tensor
258    pub fn copy_from_host_tensor(&mut self, tensor: &Tensor<Host<T::H>>) -> Result<()> {
259        assert_eq!(self.size(), tensor.size(), "Tensor sizes do not match");
260        let ret = unsafe { Tensor_copyFromHostTensor(self.tensor, tensor.tensor) };
261        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
262        Ok(())
263    }
264
265    /// Copies the data from the self tensor to a host tensor
266    pub fn copy_to_host_tensor(&self, tensor: &mut Tensor<Host<T::H>>) -> Result<()> {
267        assert_eq!(self.size(), tensor.size(), "Tensor sizes do not match");
268        let ret = unsafe { Tensor_copyToHostTensor(self.tensor, tensor.tensor) };
269        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
270        Ok(())
271    }
272
273    /// Get the device id of the tensor
274    pub fn device_id(&self) -> u64 {
275        unsafe { Tensor_deviceId(self.tensor) }
276    }
277
278    /// Get the shape of the tensor
279    pub fn shape(&self) -> TensorShape {
280        unsafe { Tensor_shape(self.tensor) }.into()
281    }
282
283    /// Get the dimensions of the tensor
284    pub fn dimensions(&self) -> usize {
285        unsafe { Tensor_dimensions(self.tensor) as usize }
286    }
287
288    /// Get the width of the tensor
289    pub fn width(&self) -> u32 {
290        unsafe { Tensor_width(self.tensor) as u32 }
291    }
292
293    /// Get the height of the tensor
294    pub fn height(&self) -> u32 {
295        unsafe { Tensor_height(self.tensor) as u32 }
296    }
297
298    /// Get the channel size of the tensor
299    pub fn channel(&self) -> u32 {
300        unsafe { Tensor_channel(self.tensor) as u32 }
301    }
302
303    /// Get the batch size of the tensor
304    pub fn batch(&self) -> u32 {
305        unsafe { Tensor_batch(self.tensor) as u32 }
306    }
307
308    /// Get the size of the tensor when counted by bytes
309    pub fn size(&self) -> usize {
310        unsafe { Tensor_usize(self.tensor) }
311    }
312
313    /// Get the size of the tensor when counted by elements
314    pub fn element_size(&self) -> usize {
315        unsafe { Tensor_elementSize(self.tensor) as usize }
316    }
317
318    /// Print the shape of the tensor
319    pub fn print_shape(&self) {
320        unsafe {
321            Tensor_printShape(self.tensor);
322        }
323    }
324
325    /// Print the tensor
326    pub fn print(&self) {
327        unsafe {
328            Tensor_print(self.tensor);
329        }
330    }
331
332    /// Check if the tensor is dynamic and needs resizing
333    pub fn is_dynamic_unsized(&self) -> bool {
334        self.shape().as_ref().contains(&-1)
335    }
336
337    /// DO not use this function directly
338    /// # Safety
339    /// This is just provided as a 1:1 compat mostly for possible later use
340    pub unsafe fn halide_buffer(&self) -> *const halide_buffer_t {
341        unsafe { Tensor_buffer(self.tensor) }
342    }
343
344    /// Do not use this function directly
345    /// # Safety
346    /// This is just provided as a 1:1 compat mostly for possible later use
347    pub unsafe fn halide_buffer_mut(&self) -> *mut halide_buffer_t {
348        unsafe { Tensor_buffer_mut(self.tensor) }
349    }
350
351    /// Get the dimension type of the tensor
352    pub fn get_dimension_type(&self) -> DimensionType {
353        debug_assert!(!self.tensor.is_null());
354        From::from(unsafe { Tensor_getDimensionType(self.tensor) })
355    }
356
357    /// Get the data type of the tensor
358    pub fn get_type(&self) -> mnn_sys::halide_type_t {
359        unsafe { Tensor_getType(self.tensor) }
360    }
361
362    /// Check if the tensor is of the specified data type
363    pub fn is_type_of<H: HalideType>(&self) -> bool {
364        let htc = halide_type_of::<H>();
365        unsafe { Tensor_isTypeOf(self.tensor, htc) }
366    }
367
368    /// # Safety
369    /// This is very unsafe do not use this unless you know what you are doing
370    pub unsafe fn into_raw(self) -> RawTensor<'static> {
371        let out = RawTensor {
372            inner: self.tensor,
373            __marker: PhantomData,
374        };
375        core::mem::forget(self);
376        out
377    }
378}
379impl<T: MutableTensorType> Tensor<T>
380where
381    T::H: HalideType,
382{
383    /// Fill the tensor with the specified value
384    pub fn fill(&mut self, value: T::H)
385    where
386        T::H: Copy,
387    {
388        if T::host() {
389            let size = self.element_size();
390            assert!(self.is_type_of::<T::H>());
391            let result: &mut [T::H] = unsafe {
392                let data = mnn_sys::Tensor_host_mut(self.tensor).cast();
393                core::slice::from_raw_parts_mut(data, size)
394            };
395            result.fill(value);
396        } else if T::device() {
397            let shape = self.shape();
398            let dm_type = self.get_dimension_type();
399            let mut host = Tensor::new(shape, dm_type);
400            host.fill(value);
401            self.copy_from_host_tensor(&host)
402                .expect("Failed to copy data from host tensor");
403        } else {
404            unreachable!()
405        }
406    }
407}
408
409impl<T: HostTensorType> Tensor<T>
410where
411    T::H: HalideType,
412{
413    /// Try to map the device tensor to the host memory and get the slice
414    pub fn try_host(&self) -> Result<&[T::H]> {
415        let size = self.element_size();
416        ensure!(
417            self.is_type_of::<T::H>(),
418            ErrorKind::HalideTypeMismatch {
419                got: std::any::type_name::<T::H>(),
420            }
421        );
422        let result = unsafe {
423            let data = mnn_sys::Tensor_host(self.tensor).cast();
424            core::slice::from_raw_parts(data, size)
425        };
426        Ok(result)
427    }
428
429    /// Try to map the device tensor to the host memory and get the mutable slice
430    pub fn try_host_mut(&mut self) -> Result<&mut [T::H]> {
431        let size = self.element_size();
432        ensure!(
433            self.is_type_of::<T::H>(),
434            ErrorKind::HalideTypeMismatch {
435                got: std::any::type_name::<T::H>(),
436            }
437        );
438
439        let result = unsafe {
440            let data: *mut T::H = mnn_sys::Tensor_host_mut(self.tensor).cast();
441            debug_assert!(!data.is_null());
442            core::slice::from_raw_parts_mut(data, size)
443        };
444        Ok(result)
445    }
446
447    /// Get the host memory slice of the tensor
448    pub fn host(&self) -> &[T::H] {
449        self.try_host().expect("Failed to get tensor host")
450    }
451
452    /// Get the mutable host memory slice of the tensor
453    pub fn host_mut(&mut self) -> &mut [T::H] {
454        self.try_host_mut().expect("Failed to get tensor host_mut")
455    }
456}
457
458impl<T: DeviceTensorType> Tensor<T>
459where
460    T::H: HalideType,
461{
462    /// Try to wait for the device tensor to finish processing
463    pub fn wait(&self, map_type: MapType, finish: bool) {
464        unsafe {
465            Tensor_wait(self.tensor, map_type, finish as i32);
466        }
467    }
468
469    /// Create a host tensor from the device tensor with same dimensions and data type and
470    /// optionally copy the data from the device tensor
471    pub fn create_host_tensor_from_device(&self, copy_data: bool) -> Tensor<Host<T::H>> {
472        let shape = self.shape();
473        let dm_type = self.get_dimension_type();
474        let mut out = Tensor::new(shape, dm_type);
475
476        if copy_data {
477            self.copy_to_host_tensor(&mut out)
478                .expect("Failed to copy data from device tensor");
479        }
480        out
481    }
482}
483
484impl<T: OwnedTensorType> Tensor<T>
485where
486    T::H: HalideType,
487{
488    /// Create a new tensor with the specified shape and dimension type
489    pub fn new(shape: impl AsTensorShape, dm_type: DimensionType) -> Self {
490        let shape = shape.as_tensor_shape();
491        let tensor = unsafe {
492            if T::device() {
493                Tensor_createDevice(
494                    shape.shape.as_ptr(),
495                    shape.size,
496                    halide_type_of::<T::H>(),
497                    dm_type.to_mnn_sys(),
498                )
499            } else {
500                Tensor_createWith(
501                    shape.shape.as_ptr(),
502                    shape.size,
503                    halide_type_of::<T::H>(),
504                    core::ptr::null_mut(),
505                    dm_type.to_mnn_sys(),
506                )
507            }
508        };
509        debug_assert!(!tensor.is_null());
510        Self {
511            tensor,
512            __marker: PhantomData,
513        }
514    }
515}
516
517impl<T: OwnedTensorType> Clone for Tensor<T>
518where
519    T::H: HalideType,
520{
521    fn clone(&self) -> Tensor<T> {
522        let tensor_ptr = unsafe { Tensor_clone(self.tensor) };
523        Self {
524            tensor: tensor_ptr,
525            __marker: PhantomData,
526        }
527    }
528}
529
/// Up to four tensor dimensions stored inline, plus a count of how many are in use.
///
/// `#[repr(C)]` keeps the fields in declaration order.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct TensorShape {
    /// Dimension storage; only the first `size` entries are meaningful.
    pub(crate) shape: [i32; 4],
    /// Number of dimensions in use.
    pub(crate) size: usize,
}
537
538impl From<mnn_sys::TensorShape> for TensorShape {
539    fn from(value: mnn_sys::TensorShape) -> Self {
540        Self {
541            shape: value.shape,
542            size: value.size,
543        }
544    }
545}
546
547impl From<TensorShape> for mnn_sys::TensorShape {
548    fn from(value: TensorShape) -> Self {
549        Self {
550            shape: value.shape,
551            size: value.size,
552        }
553    }
554}
555
556impl core::ops::Deref for TensorShape {
557    type Target = [i32];
558
559    fn deref(&self) -> &Self::Target {
560        &self.shape[..self.size]
561    }
562}
563
564impl core::ops::Index<usize> for TensorShape {
565    type Output = i32;
566
567    fn index(&self, index: usize) -> &Self::Output {
568        &self.shape[..self.size][index]
569    }
570}
571
572impl core::ops::IndexMut<usize> for TensorShape {
573    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
574        &mut self.shape[..self.size][index]
575    }
576}
577
578impl core::ops::DerefMut for TensorShape {
579    fn deref_mut(&mut self) -> &mut Self::Target {
580        &mut self.shape[..self.size]
581    }
582}
583
584impl core::fmt::Debug for TensorShape {
585    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
586        write!(f, "{:?}", &self.shape[..self.size])
587    }
588}
589
590/// A trait to convert any array-like type to a tensor shape
591pub trait AsTensorShape {
592    /// Convert the array-like type to a tensor shape
593    fn as_tensor_shape(&self) -> TensorShape;
594}
595
596impl<T: AsRef<[i32]>> AsTensorShape for T {
597    fn as_tensor_shape(&self) -> TensorShape {
598        let this = self.as_ref();
599        let size = std::cmp::min(this.len(), 4);
600        let mut shape = [1; 4];
601        shape[..size].copy_from_slice(&this[..size]);
602        TensorShape { shape, size }
603    }
604}
605
606impl AsTensorShape for TensorShape {
607    fn as_tensor_shape(&self) -> TensorShape {
608        *self
609    }
610}
611
#[cfg(test)]
mod as_tensor_shape_tests {
    use super::AsTensorShape;

    /// Converts `$value`, then checks the visible dimensions (`$dims`) and the padded
    /// four-slot storage (`$padded`). `$t` and `$kind` only label the diagnostic line.
    macro_rules! shape_test {
        ($t:ty, $kind:expr, $value:expr, $dims:expr, $padded:expr) => {{
            eprintln!("Testing {} with {} shape", stringify!($t), $kind);
            let shape = $value.as_tensor_shape();
            let expected: &[i32] = &$dims;
            assert_eq!(&*shape, expected);
            assert_eq!(shape.size, expected.len());
            assert_eq!(shape.shape, $padded);
        }};
    }

    #[test]
    fn as_tensor_shape_test_vec() {
        shape_test!(Vec<i32>, "small", vec![1, 2, 3], [1, 2, 3], [1, 2, 3, 1]);
        shape_test!(
            Vec<i32>,
            "large",
            vec![12, 23, 34, 45, 67],
            [12, 23, 34, 45],
            [12, 23, 34, 45]
        );
    }

    #[test]
    fn as_tensor_shape_test_array() {
        shape_test!([i32; 3], "small", [1, 2, 3], [1, 2, 3], [1, 2, 3, 1]);
        shape_test!(
            [i32; 5],
            "large",
            [12, 23, 34, 45, 67],
            [12, 23, 34, 45],
            [12, 23, 34, 45]
        );
    }

    #[test]
    fn as_tensor_shape_test_ref() {
        shape_test!(&[i32], "small", &[1, 2, 3], [1, 2, 3], [1, 2, 3, 1]);
        shape_test!(
            &[i32],
            "large",
            &[12, 23, 34, 45, 67],
            [12, 23, 34, 45],
            [12, 23, 34, 45]
        );
    }
}
637
#[cfg(test)]
mod tensor_tests {
    use super::{Host, Tensor};

    /// `from_ptr` must reject a null pointer by panicking.
    #[test]
    #[should_panic]
    fn unsafe_nullptr_tensor() {
        let null = core::ptr::null_mut();
        let _tensor = unsafe { Tensor::<Host<i32>>::from_ptr(null) };
    }
}
648
649impl<T: HostTensorType + RefTensorType> Tensor<T>
650where
651    T::H: HalideType,
652{
653    /// Try to create a ref tensor from any array-like type
654    pub fn borrowed(shape: impl AsTensorShape, input: impl AsRef<[T::H]>) -> Self {
655        let shape = shape.as_tensor_shape();
656        let input = input.as_ref();
657        let tensor = unsafe {
658            Tensor_createWith(
659                shape.shape.as_ptr(),
660                shape.size,
661                halide_type_of::<T::H>(),
662                input.as_ptr().cast_mut().cast(),
663                DimensionType::Caffe.to_mnn_sys(),
664            )
665        };
666        debug_assert!(!tensor.is_null());
667        Self {
668            tensor,
669            __marker: PhantomData,
670        }
671    }
672
673    /// Try to create a mutable ref tensor from any array-like type
674    pub fn borrowed_mut(shape: impl AsTensorShape, mut input: impl AsMut<[T::H]>) -> Self {
675        let shape = shape.as_tensor_shape();
676        let input = input.as_mut();
677        let tensor = unsafe {
678            Tensor_createWith(
679                shape.shape.as_ptr(),
680                shape.size,
681                halide_type_of::<T::H>(),
682                input.as_mut_ptr().cast(),
683                DimensionType::Caffe.to_mnn_sys(),
684            )
685        };
686        debug_assert!(!tensor.is_null());
687        Self {
688            tensor,
689            __marker: PhantomData,
690        }
691    }
692}
693
/// A borrowed host tensor reports the shape it was built with and exposes the
/// original data unchanged.
#[test]
fn test_tensor_borrowed() {
    let dims = [1, 2, 3];
    let values = vec![1, 2, 3, 4, 5, 6];
    let view = Tensor::<Ref<Host<i32>>>::borrowed(&dims, &values);
    assert_eq!(view.shape().as_ref(), dims);
    assert_eq!(view.host(), values.as_slice());
}
702
/// Writes made through a mutably borrowed host tensor show up in the original buffer.
#[test]
fn test_tensor_borrow_mut() {
    let dims = [1, 2, 3];
    let mut values = vec![1, 2, 3, 4, 5, 6];
    let mut view = Tensor::<RefMut<Host<i32>>>::borrowed_mut(&dims, &mut values);
    view.host_mut().fill(1);
    assert_eq!(values, &[1, 1, 1, 1, 1, 1]);
}