1use crate::prelude::*;
2use core::marker::PhantomData;
3use mnn_sys::*;
4use std::borrow::Borrow;
5pub(crate) mod list;
6mod raw;
7pub use raw::RawTensor;
8
9use mnn_sys::HalideType;
10
/// Private home of the sealing supertrait: because this module is not
/// exported, code outside the crate cannot name `seal::Sealed` and therefore
/// cannot implement `TensorType` for its own types.
mod seal {
    /// Marker supertrait of `TensorType`; intentionally has no items.
    pub trait Sealed {}
}
14macro_rules! seal {
15 ($($name:ty),*) => {
16 $(
17 impl<T> seal::Sealed for $name {}
18 )*
19 };
20 }
21seal!(Host<T>, Device<T>, Ref<'_, T>, RefMut<'_, T>);
22
23pub trait TensorType: seal::Sealed {
25 type H;
27 fn owned() -> bool;
29 fn borrowed() -> bool {
31 !Self::owned()
32 }
33 fn host() -> bool;
35 fn device() -> bool {
37 !Self::host()
38 }
39}
40pub trait OwnedTensorType: TensorType {}
42pub trait RefTensorType: TensorType {}
44pub trait HostTensorType: TensorType {}
46pub trait DeviceTensorType: TensorType {}
48pub trait MutableTensorType: TensorType {}
50
51impl<H: HalideType> TensorType for Host<H> {
52 type H = H;
53 fn owned() -> bool {
54 true
55 }
56 fn host() -> bool {
57 true
58 }
59}
60impl<H: HalideType> TensorType for Device<H> {
61 type H = H;
62 fn owned() -> bool {
63 true
64 }
65 fn host() -> bool {
66 false
67 }
68}
69
70impl<T: TensorType> TensorType for Ref<'_, T> {
71 type H = T::H;
72 fn owned() -> bool {
73 false
74 }
75 fn host() -> bool {
76 T::host()
77 }
78}
79
80impl<T: TensorType> TensorType for RefMut<'_, T> {
81 type H = T::H;
82 fn owned() -> bool {
83 false
84 }
85 fn host() -> bool {
86 T::host()
87 }
88}
89
90impl<H: HalideType> DeviceTensorType for Device<H> {}
91impl<H: HalideType> HostTensorType for Host<H> {}
92impl<H: HalideType> OwnedTensorType for Device<H> {}
93impl<H: HalideType> OwnedTensorType for Host<H> {}
94impl<T: DeviceTensorType> DeviceTensorType for Ref<'_, T> {}
95impl<T: DeviceTensorType> DeviceTensorType for RefMut<'_, T> {}
96impl<T: HostTensorType> HostTensorType for Ref<'_, T> {}
97impl<T: HostTensorType> HostTensorType for RefMut<'_, T> {}
98impl<T: OwnedTensorType> MutableTensorType for T {}
99impl<T: TensorType> MutableTensorType for RefMut<'_, T> {}
100impl<T: TensorType> RefTensorType for Ref<'_, T> {}
101impl<T: TensorType> RefTensorType for RefMut<'_, T> {}
102
/// Marker for an owned tensor whose data is in host memory.
///
/// `T` is the element type and defaults to `f32`.
pub struct Host<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}

/// Marker for an owned tensor whose data is on a backend device.
///
/// `T` is the element type and defaults to `f32`.
pub struct Device<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}

/// Marker for a shared, non-owning view described by the inner marker `T`,
/// valid for the lifetime `'t`.
pub struct Ref<'t, T> {
    pub(crate) __marker: PhantomData<&'t [T]>,
}

/// Marker for an exclusive, non-owning view described by the inner marker
/// `T`, valid for the lifetime `'t`.
pub struct RefMut<'t, T> {
    pub(crate) __marker: PhantomData<&'t mut [T]>,
}
120
121pub struct Tensor<T: TensorType> {
123 pub(crate) tensor: *mut mnn_sys::Tensor,
124 __marker: PhantomData<T>,
125}
126
127impl<T: TensorType> Drop for Tensor<T> {
128 fn drop(&mut self) {
129 if T::owned() {
130 unsafe {
131 mnn_sys::Tensor_destroy(self.tensor);
132 }
133 }
134 }
135}
136
137impl<H: HalideType> Tensor<Host<H>> {
138 pub fn as_ref(&self) -> Tensor<Ref<'_, Host<H>>> {
140 Tensor {
141 tensor: self.tensor,
142 __marker: PhantomData,
143 }
144 }
145}
146
147impl<H: HalideType> Tensor<Device<H>> {
148 pub fn as_ref(&self) -> Tensor<Ref<'_, Device<H>>> {
150 Tensor {
151 tensor: self.tensor,
152 __marker: PhantomData,
153 }
154 }
155}
156
157impl<H: HalideType, T: TensorType<H = H>> ToOwned for Tensor<Ref<'_, T>> {
158 type Owned = Tensor<T>;
159
160 fn to_owned(&self) -> Self::Owned {
161 let tensor_ptr = unsafe { Tensor_clone(self.tensor) };
162 Tensor {
163 tensor: tensor_ptr,
164 __marker: PhantomData,
165 }
166 }
167}
168
169impl<'t, H: HalideType, T: TensorType<H = H>> Borrow<Tensor<Ref<'t, T>>> for Tensor<T> {
170 fn borrow(&self) -> &Tensor<Ref<'t, T>> {
171 unsafe { &*(self as *const Tensor<T> as *const Tensor<Ref<'t, T>>) }
172 }
173}
174
175impl<'t, H: HalideType, T: TensorType<H = H>> Borrow<Tensor<T>> for Tensor<Ref<'t, T>> {
176 fn borrow(&self) -> &'t Tensor<T> {
177 unsafe { &*(self as *const Tensor<Ref<'_, T>> as *const Tensor<T>) }
178 }
179}
180
/// Dimension layout of a tensor.
///
/// The associated constants `NHWC`, `NCHW` and `NC4HW4` provide
/// layout-oriented aliases for these variants.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DimensionType {
    /// Caffe layout, aliased as `NCHW`.
    Caffe,
    /// Caffe "C4" layout, aliased as `NC4HW4`.
    CaffeC4,
    /// TensorFlow layout, aliased as `NHWC`.
    TensorFlow,
}
215
216impl DimensionType {
217 pub const NHWC: Self = Self::TensorFlow;
219 pub const NCHW: Self = Self::Caffe;
221 pub const NC4HW4: Self = Self::CaffeC4;
223 pub(crate) fn to_mnn_sys(self) -> mnn_sys::DimensionType {
224 match self {
225 DimensionType::Caffe => mnn_sys::DimensionType::CAFFE,
226 DimensionType::CaffeC4 => mnn_sys::DimensionType::CAFFE_C4,
227 DimensionType::TensorFlow => mnn_sys::DimensionType::TENSORFLOW,
228 }
229 }
230}
231
232impl From<mnn_sys::DimensionType> for DimensionType {
233 fn from(dm: mnn_sys::DimensionType) -> Self {
234 match dm {
235 mnn_sys::DimensionType::CAFFE => DimensionType::Caffe,
236 mnn_sys::DimensionType::CAFFE_C4 => DimensionType::CaffeC4,
237 mnn_sys::DimensionType::TENSORFLOW => DimensionType::TensorFlow,
238 }
239 }
240}
241
242impl<T: TensorType> Tensor<T>
243where
244 T::H: HalideType,
245{
246 pub unsafe fn from_ptr(tensor: *mut mnn_sys::Tensor) -> Self {
251 assert!(!tensor.is_null());
252 Self {
253 tensor,
254 __marker: PhantomData,
255 }
256 }
257 pub fn copy_from_host_tensor(&mut self, tensor: &Tensor<Host<T::H>>) -> Result<()> {
259 assert_eq!(self.size(), tensor.size(), "Tensor sizes do not match");
260 let ret = unsafe { Tensor_copyFromHostTensor(self.tensor, tensor.tensor) };
261 crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
262 Ok(())
263 }
264
265 pub fn copy_to_host_tensor(&self, tensor: &mut Tensor<Host<T::H>>) -> Result<()> {
267 assert_eq!(self.size(), tensor.size(), "Tensor sizes do not match");
268 let ret = unsafe { Tensor_copyToHostTensor(self.tensor, tensor.tensor) };
269 crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
270 Ok(())
271 }
272
273 pub fn device_id(&self) -> u64 {
275 unsafe { Tensor_deviceId(self.tensor) }
276 }
277
278 pub fn shape(&self) -> TensorShape {
280 unsafe { Tensor_shape(self.tensor) }.into()
281 }
282
283 pub fn dimensions(&self) -> usize {
285 unsafe { Tensor_dimensions(self.tensor) as usize }
286 }
287
288 pub fn width(&self) -> u32 {
290 unsafe { Tensor_width(self.tensor) as u32 }
291 }
292
293 pub fn height(&self) -> u32 {
295 unsafe { Tensor_height(self.tensor) as u32 }
296 }
297
298 pub fn channel(&self) -> u32 {
300 unsafe { Tensor_channel(self.tensor) as u32 }
301 }
302
303 pub fn batch(&self) -> u32 {
305 unsafe { Tensor_batch(self.tensor) as u32 }
306 }
307
308 pub fn size(&self) -> usize {
310 unsafe { Tensor_usize(self.tensor) }
311 }
312
313 pub fn element_size(&self) -> usize {
315 unsafe { Tensor_elementSize(self.tensor) as usize }
316 }
317
318 pub fn print_shape(&self) {
320 unsafe {
321 Tensor_printShape(self.tensor);
322 }
323 }
324
325 pub fn print(&self) {
327 unsafe {
328 Tensor_print(self.tensor);
329 }
330 }
331
332 pub fn is_dynamic_unsized(&self) -> bool {
334 self.shape().as_ref().contains(&-1)
335 }
336
337 pub unsafe fn halide_buffer(&self) -> *const halide_buffer_t {
341 unsafe { Tensor_buffer(self.tensor) }
342 }
343
344 pub unsafe fn halide_buffer_mut(&self) -> *mut halide_buffer_t {
348 unsafe { Tensor_buffer_mut(self.tensor) }
349 }
350
351 pub fn get_dimension_type(&self) -> DimensionType {
353 debug_assert!(!self.tensor.is_null());
354 From::from(unsafe { Tensor_getDimensionType(self.tensor) })
355 }
356
357 pub fn get_type(&self) -> mnn_sys::halide_type_t {
359 unsafe { Tensor_getType(self.tensor) }
360 }
361
362 pub fn is_type_of<H: HalideType>(&self) -> bool {
364 let htc = halide_type_of::<H>();
365 unsafe { Tensor_isTypeOf(self.tensor, htc) }
366 }
367
368 pub unsafe fn into_raw(self) -> RawTensor<'static> {
371 let out = RawTensor {
372 inner: self.tensor,
373 __marker: PhantomData,
374 };
375 core::mem::forget(self);
376 out
377 }
378}
379impl<T: MutableTensorType> Tensor<T>
380where
381 T::H: HalideType,
382{
383 pub fn fill(&mut self, value: T::H)
385 where
386 T::H: Copy,
387 {
388 if T::host() {
389 let size = self.element_size();
390 assert!(self.is_type_of::<T::H>());
391 let result: &mut [T::H] = unsafe {
392 let data = mnn_sys::Tensor_host_mut(self.tensor).cast();
393 core::slice::from_raw_parts_mut(data, size)
394 };
395 result.fill(value);
396 } else if T::device() {
397 let shape = self.shape();
398 let dm_type = self.get_dimension_type();
399 let mut host = Tensor::new(shape, dm_type);
400 host.fill(value);
401 self.copy_from_host_tensor(&host)
402 .expect("Failed to copy data from host tensor");
403 } else {
404 unreachable!()
405 }
406 }
407}
408
409impl<T: HostTensorType> Tensor<T>
410where
411 T::H: HalideType,
412{
413 pub fn try_host(&self) -> Result<&[T::H]> {
415 let size = self.element_size();
416 ensure!(
417 self.is_type_of::<T::H>(),
418 ErrorKind::HalideTypeMismatch {
419 got: std::any::type_name::<T::H>(),
420 }
421 );
422 let result = unsafe {
423 let data = mnn_sys::Tensor_host(self.tensor).cast();
424 core::slice::from_raw_parts(data, size)
425 };
426 Ok(result)
427 }
428
429 pub fn try_host_mut(&mut self) -> Result<&mut [T::H]> {
431 let size = self.element_size();
432 ensure!(
433 self.is_type_of::<T::H>(),
434 ErrorKind::HalideTypeMismatch {
435 got: std::any::type_name::<T::H>(),
436 }
437 );
438
439 let result = unsafe {
440 let data: *mut T::H = mnn_sys::Tensor_host_mut(self.tensor).cast();
441 debug_assert!(!data.is_null());
442 core::slice::from_raw_parts_mut(data, size)
443 };
444 Ok(result)
445 }
446
447 pub fn host(&self) -> &[T::H] {
449 self.try_host().expect("Failed to get tensor host")
450 }
451
452 pub fn host_mut(&mut self) -> &mut [T::H] {
454 self.try_host_mut().expect("Failed to get tensor host_mut")
455 }
456}
457
458impl<T: DeviceTensorType> Tensor<T>
459where
460 T::H: HalideType,
461{
462 pub fn wait(&self, map_type: MapType, finish: bool) {
464 unsafe {
465 Tensor_wait(self.tensor, map_type, finish as i32);
466 }
467 }
468
469 pub fn create_host_tensor_from_device(&self, copy_data: bool) -> Tensor<Host<T::H>> {
472 let shape = self.shape();
473 let dm_type = self.get_dimension_type();
474 let mut out = Tensor::new(shape, dm_type);
475
476 if copy_data {
477 self.copy_to_host_tensor(&mut out)
478 .expect("Failed to copy data from device tensor");
479 }
480 out
481 }
482}
483
484impl<T: OwnedTensorType> Tensor<T>
485where
486 T::H: HalideType,
487{
488 pub fn new(shape: impl AsTensorShape, dm_type: DimensionType) -> Self {
490 let shape = shape.as_tensor_shape();
491 let tensor = unsafe {
492 if T::device() {
493 Tensor_createDevice(
494 shape.shape.as_ptr(),
495 shape.size,
496 halide_type_of::<T::H>(),
497 dm_type.to_mnn_sys(),
498 )
499 } else {
500 Tensor_createWith(
501 shape.shape.as_ptr(),
502 shape.size,
503 halide_type_of::<T::H>(),
504 core::ptr::null_mut(),
505 dm_type.to_mnn_sys(),
506 )
507 }
508 };
509 debug_assert!(!tensor.is_null());
510 Self {
511 tensor,
512 __marker: PhantomData,
513 }
514 }
515}
516
517impl<T: OwnedTensorType> Clone for Tensor<T>
518where
519 T::H: HalideType,
520{
521 fn clone(&self) -> Tensor<T> {
522 let tensor_ptr = unsafe { Tensor_clone(self.tensor) };
523 Self {
524 tensor: tensor_ptr,
525 __marker: PhantomData,
526 }
527 }
528}
529
/// Tensor shape with room for at most four dimensions, stored inline.
///
/// Only the first `size` entries of `shape` are meaningful. The C layout
/// mirrors `mnn_sys::TensorShape`, which has the same two fields.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct TensorShape {
    /// Dimension storage; entries at index `size` and beyond are unused.
    pub(crate) shape: [i32; 4],
    /// Number of valid leading entries in `shape`.
    pub(crate) size: usize,
}
537
538impl From<mnn_sys::TensorShape> for TensorShape {
539 fn from(value: mnn_sys::TensorShape) -> Self {
540 Self {
541 shape: value.shape,
542 size: value.size,
543 }
544 }
545}
546
547impl From<TensorShape> for mnn_sys::TensorShape {
548 fn from(value: TensorShape) -> Self {
549 Self {
550 shape: value.shape,
551 size: value.size,
552 }
553 }
554}
555
556impl core::ops::Deref for TensorShape {
557 type Target = [i32];
558
559 fn deref(&self) -> &Self::Target {
560 &self.shape[..self.size]
561 }
562}
563
564impl core::ops::Index<usize> for TensorShape {
565 type Output = i32;
566
567 fn index(&self, index: usize) -> &Self::Output {
568 &self.shape[..self.size][index]
569 }
570}
571
572impl core::ops::IndexMut<usize> for TensorShape {
573 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
574 &mut self.shape[..self.size][index]
575 }
576}
577
578impl core::ops::DerefMut for TensorShape {
579 fn deref_mut(&mut self) -> &mut Self::Target {
580 &mut self.shape[..self.size]
581 }
582}
583
584impl core::fmt::Debug for TensorShape {
585 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
586 write!(f, "{:?}", &self.shape[..self.size])
587 }
588}
589
590pub trait AsTensorShape {
592 fn as_tensor_shape(&self) -> TensorShape;
594}
595
596impl<T: AsRef<[i32]>> AsTensorShape for T {
597 fn as_tensor_shape(&self) -> TensorShape {
598 let this = self.as_ref();
599 let size = std::cmp::min(this.len(), 4);
600 let mut shape = [1; 4];
601 shape[..size].copy_from_slice(&this[..size]);
602 TensorShape { shape, size }
603 }
604}
605
606impl AsTensorShape for TensorShape {
607 fn as_tensor_shape(&self) -> TensorShape {
608 *self
609 }
610}
611
#[cfg(test)]
mod as_tensor_shape_tests {
    use super::AsTensorShape;

    /// Binds `$value` with the declared type `$t` (so that the case named in
    /// the log is really the one exercised), converts it, and checks both the
    /// reported dimensions and the padded four-slot storage.
    macro_rules! shape_test {
        ($t:ty, $kind:expr, $value:expr, $dims:expr, $storage:expr) => {{
            eprintln!("Testing {} with {} shape", stringify!($t), $kind);
            let value: $t = $value;
            let shape = value.as_tensor_shape();
            assert_eq!(&*shape, &$dims[..]);
            assert_eq!(shape.shape, $storage);
        }};
    }

    #[test]
    fn as_tensor_shape_test_vec() {
        shape_test!(Vec<i32>, "small", vec![1, 2, 3], [1, 2, 3], [1, 2, 3, 1]);
        shape_test!(
            Vec<i32>,
            "large",
            vec![12, 23, 34, 45, 67],
            [12, 23, 34, 45],
            [12, 23, 34, 45]
        );
    }

    #[test]
    fn as_tensor_shape_test_array() {
        shape_test!([i32; 3], "small", [1, 2, 3], [1, 2, 3], [1, 2, 3, 1]);
        shape_test!(
            [i32; 5],
            "large",
            [12, 23, 34, 45, 67],
            [12, 23, 34, 45],
            [12, 23, 34, 45]
        );
    }

    #[test]
    fn as_tensor_shape_test_ref() {
        shape_test!(&[i32], "small", &[1, 2, 3], [1, 2, 3], [1, 2, 3, 1]);
        shape_test!(
            &[i32],
            "large",
            &[12, 23, 34, 45, 67],
            [12, 23, 34, 45],
            [12, 23, 34, 45]
        );
    }
}
637
#[cfg(test)]
mod tensor_tests {
    use super::{Host, Tensor};

    /// `from_ptr` must reject a null pointer rather than wrap it.
    #[test]
    #[should_panic]
    fn unsafe_nullptr_tensor() {
        let null = core::ptr::null_mut();
        // SAFETY: the call is expected to panic in its null check before the
        // pointer is ever used.
        let _tensor = unsafe { Tensor::<Host<i32>>::from_ptr(null) };
    }
}
648
649impl<T: HostTensorType + RefTensorType> Tensor<T>
650where
651 T::H: HalideType,
652{
653 pub fn borrowed(shape: impl AsTensorShape, input: impl AsRef<[T::H]>) -> Self {
655 let shape = shape.as_tensor_shape();
656 let input = input.as_ref();
657 let tensor = unsafe {
658 Tensor_createWith(
659 shape.shape.as_ptr(),
660 shape.size,
661 halide_type_of::<T::H>(),
662 input.as_ptr().cast_mut().cast(),
663 DimensionType::Caffe.to_mnn_sys(),
664 )
665 };
666 debug_assert!(!tensor.is_null());
667 Self {
668 tensor,
669 __marker: PhantomData,
670 }
671 }
672
673 pub fn borrowed_mut(shape: impl AsTensorShape, mut input: impl AsMut<[T::H]>) -> Self {
675 let shape = shape.as_tensor_shape();
676 let input = input.as_mut();
677 let tensor = unsafe {
678 Tensor_createWith(
679 shape.shape.as_ptr(),
680 shape.size,
681 halide_type_of::<T::H>(),
682 input.as_mut_ptr().cast(),
683 DimensionType::Caffe.to_mnn_sys(),
684 )
685 };
686 debug_assert!(!tensor.is_null());
687 Self {
688 tensor,
689 __marker: PhantomData,
690 }
691 }
692}
693
694#[test]
695fn test_tensor_borrowed() {
696 let shape = [1, 2, 3];
697 let data = vec![1, 2, 3, 4, 5, 6];
698 let tensor = Tensor::<Ref<Host<i32>>>::borrowed(&shape, &data);
699 assert_eq!(tensor.shape().as_ref(), shape);
700 assert_eq!(tensor.host(), data.as_slice());
701}
702
703#[test]
704fn test_tensor_borrow_mut() {
705 let shape = [1, 2, 3];
706 let mut data = vec![1, 2, 3, 4, 5, 6];
707 let mut tensor = Tensor::<RefMut<Host<i32>>>::borrowed_mut(&shape, &mut data);
708 tensor.host_mut().fill(1);
709 assert_eq!(data, &[1, 1, 1, 1, 1, 1]);
710}