1use crate::prelude::*;
2use core::marker::PhantomData;
3use mnn_sys::*;
4pub(crate) mod list;
5mod raw;
6pub use raw::RawTensor;
7
8use mnn_sys::HalideType;
9
10mod seal {
11 pub trait Sealed {}
12}
13macro_rules! seal {
14 ($($name:ty),*) => {
15 $(
16 impl<T> seal::Sealed for $name {}
17 )*
18 };
19 }
20seal!(Host<T>, Device<T>, Ref<'_, T>, RefMut<'_, T>);
21
22pub trait TensorType: seal::Sealed {
24 type H;
26 fn owned() -> bool;
28 fn borrowed() -> bool {
30 !Self::owned()
31 }
32 fn host() -> bool;
34 fn device() -> bool {
36 !Self::host()
37 }
38}
39pub trait OwnedTensorType: TensorType {}
41pub trait RefTensorType: TensorType {}
43pub trait HostTensorType: TensorType {}
45pub trait DeviceTensorType: TensorType {}
47pub trait MutableTensorType: TensorType {}
49
50impl<H: HalideType> TensorType for Host<H> {
51 type H = H;
52 fn owned() -> bool {
53 true
54 }
55 fn host() -> bool {
56 true
57 }
58}
59impl<H: HalideType> TensorType for Device<H> {
60 type H = H;
61 fn owned() -> bool {
62 true
63 }
64 fn host() -> bool {
65 false
66 }
67}
68
69impl<T: TensorType> TensorType for Ref<'_, T> {
70 type H = T::H;
71 fn owned() -> bool {
72 false
73 }
74 fn host() -> bool {
75 T::host()
76 }
77}
78
79impl<T: TensorType> TensorType for RefMut<'_, T> {
80 type H = T::H;
81 fn owned() -> bool {
82 false
83 }
84 fn host() -> bool {
85 T::host()
86 }
87}
88
// Owned tensor types: storage location plus ownership.
impl<H: HalideType> DeviceTensorType for Device<H> {}
impl<H: HalideType> HostTensorType for Host<H> {}
impl<H: HalideType> OwnedTensorType for Device<H> {}
impl<H: HalideType> OwnedTensorType for Host<H> {}
// Views inherit the storage location of the tensor type they borrow.
impl<T: DeviceTensorType> DeviceTensorType for Ref<'_, T> {}
impl<T: DeviceTensorType> DeviceTensorType for RefMut<'_, T> {}
impl<T: HostTensorType> HostTensorType for Ref<'_, T> {}
impl<T: HostTensorType> HostTensorType for RefMut<'_, T> {}
// Mutability: every owned type is mutable, and so is a `RefMut` view; a plain `Ref` is not.
// These two impls do not overlap because no `OwnedTensorType` impl exists for `RefMut`.
impl<T: OwnedTensorType> MutableTensorType for T {}
impl<T: TensorType> MutableTensorType for RefMut<'_, T> {}
// Both view kinds are non-owning references.
impl<T: TensorType> RefTensorType for Ref<'_, T> {}
impl<T: TensorType> RefTensorType for RefMut<'_, T> {}
101
/// Marker for a host-side tensor owned by its [`Tensor`] wrapper; `T` is the element type
/// (default `f32`).
pub struct Host<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}
/// Marker for a device-side tensor owned by its [`Tensor`] wrapper; `T` is the element type
/// (default `f32`).
pub struct Device<T = f32> {
    pub(crate) __marker: PhantomData<T>,
}
/// Marker for a non-owning shared view, valid for `'t`, of a tensor of type `T`
/// (for example `Ref<'t, Host<f32>>`).
pub struct Ref<'t, T> {
    pub(crate) __marker: PhantomData<&'t [T]>,
}

/// Marker for a non-owning mutable view, valid for `'t`, of a tensor of type `T`.
pub struct RefMut<'t, T> {
    pub(crate) __marker: PhantomData<&'t mut [T]>,
}
119
/// Wrapper around a raw MNN tensor pointer.
///
/// `T` is a [`TensorType`] marker carrying the element type, whether the data is host- or
/// device-side, and whether this wrapper owns the tensor. Owned tensors are destroyed on drop.
pub struct Tensor<T: TensorType> {
    /// Raw MNN tensor handle.
    pub(crate) tensor: *mut mnn_sys::Tensor,
    __marker: PhantomData<T>,
}
125
126impl<T: TensorType> Drop for Tensor<T> {
127 fn drop(&mut self) {
128 if T::owned() {
129 unsafe {
130 mnn_sys::Tensor_destroy(self.tensor);
131 }
132 }
133 }
134}
135
136impl<H: HalideType> Tensor<Host<H>> {
137 pub fn as_ref(&self) -> Tensor<Ref<'_, Host<H>>> {
139 Tensor {
140 tensor: self.tensor,
141 __marker: PhantomData,
142 }
143 }
144}
145
146impl<H: HalideType> Tensor<Device<H>> {
147 pub fn as_ref(&self) -> Tensor<Ref<'_, Device<H>>> {
149 Tensor {
150 tensor: self.tensor,
151 __marker: PhantomData,
152 }
153 }
154}
155
/// Dimension layout of a tensor, mirroring `mnn_sys::DimensionType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DimensionType {
    /// Caffe layout; also exposed as [`DimensionType::NCHW`].
    Caffe,
    /// Caffe C4 layout; also exposed as [`DimensionType::NC4HW4`].
    CaffeC4,
    /// TensorFlow layout; also exposed as [`DimensionType::NHWC`].
    TensorFlow,
}
171
172impl DimensionType {
173 pub const NHWC: Self = Self::TensorFlow;
175 pub const NCHW: Self = Self::Caffe;
177 pub const NC4HW4: Self = Self::CaffeC4;
179 pub(crate) fn to_mnn_sys(self) -> mnn_sys::DimensionType {
180 match self {
181 DimensionType::Caffe => mnn_sys::DimensionType::CAFFE,
182 DimensionType::CaffeC4 => mnn_sys::DimensionType::CAFFE_C4,
183 DimensionType::TensorFlow => mnn_sys::DimensionType::TENSORFLOW,
184 }
185 }
186}
187
188impl From<mnn_sys::DimensionType> for DimensionType {
189 fn from(dm: mnn_sys::DimensionType) -> Self {
190 match dm {
191 mnn_sys::DimensionType::CAFFE => DimensionType::Caffe,
192 mnn_sys::DimensionType::CAFFE_C4 => DimensionType::CaffeC4,
193 mnn_sys::DimensionType::TENSORFLOW => DimensionType::TensorFlow,
194 }
195 }
196}
197
198impl<T: TensorType> Tensor<T>
199where
200 T::H: HalideType,
201{
202 pub unsafe fn from_ptr(tensor: *mut mnn_sys::Tensor) -> Self {
207 assert!(!tensor.is_null());
208 Self {
209 tensor,
210 __marker: PhantomData,
211 }
212 }
213 pub fn copy_from_host_tensor(&mut self, tensor: &Tensor<Host<T::H>>) -> Result<()> {
215 let ret = unsafe { Tensor_copyFromHostTensor(self.tensor, tensor.tensor) };
216 crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
217 Ok(())
218 }
219
220 pub fn copy_to_host_tensor(&self, tensor: &mut Tensor<Host<T::H>>) -> Result<()> {
222 let ret = unsafe { Tensor_copyToHostTensor(self.tensor, tensor.tensor) };
223 crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
224 Ok(())
225 }
226
227 pub fn device_id(&self) -> u64 {
229 unsafe { Tensor_deviceId(self.tensor) }
230 }
231
232 pub fn shape(&self) -> TensorShape {
234 unsafe { Tensor_shape(self.tensor) }.into()
235 }
236
237 pub fn dimensions(&self) -> usize {
239 unsafe { Tensor_dimensions(self.tensor) as usize }
240 }
241
242 pub fn width(&self) -> u32 {
244 unsafe { Tensor_width(self.tensor) as u32 }
245 }
246
247 pub fn height(&self) -> u32 {
249 unsafe { Tensor_height(self.tensor) as u32 }
250 }
251
252 pub fn channel(&self) -> u32 {
254 unsafe { Tensor_channel(self.tensor) as u32 }
255 }
256
257 pub fn batch(&self) -> u32 {
259 unsafe { Tensor_batch(self.tensor) as u32 }
260 }
261
262 pub fn size(&self) -> usize {
264 unsafe { Tensor_usize(self.tensor) }
265 }
266
267 pub fn element_size(&self) -> usize {
269 unsafe { Tensor_elementSize(self.tensor) as usize }
270 }
271
272 pub fn print_shape(&self) {
274 unsafe {
275 Tensor_printShape(self.tensor);
276 }
277 }
278
279 pub fn print(&self) {
281 unsafe {
282 Tensor_print(self.tensor);
283 }
284 }
285
286 pub fn is_dynamic_unsized(&self) -> bool {
288 self.shape().as_ref().contains(&-1)
289 }
290
291 pub unsafe fn halide_buffer(&self) -> *const halide_buffer_t {
295 unsafe { Tensor_buffer(self.tensor) }
296 }
297
298 pub unsafe fn halide_buffer_mut(&self) -> *mut halide_buffer_t {
302 unsafe { Tensor_buffer_mut(self.tensor) }
303 }
304
305 pub fn get_dimension_type(&self) -> DimensionType {
307 debug_assert!(!self.tensor.is_null());
308 From::from(unsafe { Tensor_getDimensionType(self.tensor) })
309 }
310
311 pub fn get_type(&self) -> mnn_sys::halide_type_t {
313 unsafe { Tensor_getType(self.tensor) }
314 }
315
316 pub fn is_type_of<H: HalideType>(&self) -> bool {
318 let htc = halide_type_of::<H>();
319 unsafe { Tensor_isTypeOf(self.tensor, htc) }
320 }
321
322 pub unsafe fn into_raw(self) -> RawTensor<'static> {
325 let out = RawTensor {
326 inner: self.tensor,
327 __marker: PhantomData,
328 };
329 core::mem::forget(self);
330 out
331 }
332}
333impl<T: MutableTensorType> Tensor<T>
334where
335 T::H: HalideType,
336{
337 pub fn fill(&mut self, value: T::H)
339 where
340 T::H: Copy,
341 {
342 if T::host() {
343 let size = self.element_size();
344 assert!(self.is_type_of::<T::H>());
345 let result: &mut [T::H] = unsafe {
346 let data = mnn_sys::Tensor_host_mut(self.tensor).cast();
347 core::slice::from_raw_parts_mut(data, size)
348 };
349 result.fill(value);
350 } else if T::device() {
351 let shape = self.shape();
352 let dm_type = self.get_dimension_type();
353 let mut host = Tensor::new(shape, dm_type);
354 host.fill(value);
355 self.copy_from_host_tensor(&host)
356 .expect("Failed to copy data from host tensor");
357 } else {
358 unreachable!()
359 }
360 }
361}
362
363impl<T: HostTensorType> Tensor<T>
364where
365 T::H: HalideType,
366{
367 pub fn try_host(&self) -> Result<&[T::H]> {
369 let size = self.element_size();
370 ensure!(
371 self.is_type_of::<T::H>(),
372 ErrorKind::HalideTypeMismatch {
373 got: std::any::type_name::<T::H>(),
374 }
375 );
376 let result = unsafe {
377 let data = mnn_sys::Tensor_host(self.tensor).cast();
378 core::slice::from_raw_parts(data, size)
379 };
380 Ok(result)
381 }
382
383 pub fn try_host_mut(&mut self) -> Result<&mut [T::H]> {
385 let size = self.element_size();
386 ensure!(
387 self.is_type_of::<T::H>(),
388 ErrorKind::HalideTypeMismatch {
389 got: std::any::type_name::<T::H>(),
390 }
391 );
392
393 let result = unsafe {
394 let data: *mut T::H = mnn_sys::Tensor_host_mut(self.tensor).cast();
395 debug_assert!(!data.is_null());
396 core::slice::from_raw_parts_mut(data, size)
397 };
398 Ok(result)
399 }
400
401 pub fn host(&self) -> &[T::H] {
403 self.try_host().expect("Failed to get tensor host")
404 }
405
406 pub fn host_mut(&mut self) -> &mut [T::H] {
408 self.try_host_mut().expect("Failed to get tensor host_mut")
409 }
410}
411
412impl<T: DeviceTensorType> Tensor<T>
413where
414 T::H: HalideType,
415{
416 pub fn wait(&self, map_type: MapType, finish: bool) {
418 unsafe {
419 Tensor_wait(self.tensor, map_type, finish as i32);
420 }
421 }
422
423 pub fn create_host_tensor_from_device(&self, copy_data: bool) -> Tensor<Host<T::H>> {
426 let shape = self.shape();
427 let dm_type = self.get_dimension_type();
428 let mut out = Tensor::new(shape, dm_type);
429
430 if copy_data {
431 self.copy_to_host_tensor(&mut out)
432 .expect("Failed to copy data from device tensor");
433 }
434 out
435 }
436}
437
438impl<T: OwnedTensorType> Tensor<T>
439where
440 T::H: HalideType,
441{
442 pub fn new(shape: impl AsTensorShape, dm_type: DimensionType) -> Self {
444 let shape = shape.as_tensor_shape();
445 let tensor = unsafe {
446 if T::device() {
447 Tensor_createDevice(
448 shape.shape.as_ptr(),
449 shape.size,
450 halide_type_of::<T::H>(),
451 dm_type.to_mnn_sys(),
452 )
453 } else {
454 Tensor_createWith(
455 shape.shape.as_ptr(),
456 shape.size,
457 halide_type_of::<T::H>(),
458 core::ptr::null_mut(),
459 dm_type.to_mnn_sys(),
460 )
461 }
462 };
463 debug_assert!(!tensor.is_null());
464 Self {
465 tensor,
466 __marker: PhantomData,
467 }
468 }
469}
470
471impl<T: OwnedTensorType> Clone for Tensor<T>
472where
473 T::H: HalideType,
474{
475 fn clone(&self) -> Tensor<T> {
476 let tensor_ptr = unsafe { Tensor_clone(self.tensor) };
477 Self {
478 tensor: tensor_ptr,
479 __marker: PhantomData,
480 }
481 }
482}
483
/// Fixed-capacity tensor shape: four dimension slots, of which the first `size` are in use.
///
/// Field-for-field convertible to and from `mnn_sys::TensorShape`.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct TensorShape {
    /// Dimension slots; only `shape[..size]` is exposed through `Deref`/`Index`.
    pub(crate) shape: [i32; 4],
    /// Number of dimension slots in use.
    pub(crate) size: usize,
}
491
492impl From<mnn_sys::TensorShape> for TensorShape {
493 fn from(value: mnn_sys::TensorShape) -> Self {
494 Self {
495 shape: value.shape,
496 size: value.size,
497 }
498 }
499}
500
501impl From<TensorShape> for mnn_sys::TensorShape {
502 fn from(value: TensorShape) -> Self {
503 Self {
504 shape: value.shape,
505 size: value.size,
506 }
507 }
508}
509
510impl core::ops::Deref for TensorShape {
511 type Target = [i32];
512
513 fn deref(&self) -> &Self::Target {
514 &self.shape[..self.size]
515 }
516}
517
518impl core::ops::Index<usize> for TensorShape {
519 type Output = i32;
520
521 fn index(&self, index: usize) -> &Self::Output {
522 &self.shape[..self.size][index]
523 }
524}
525
526impl core::ops::IndexMut<usize> for TensorShape {
527 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
528 &mut self.shape[..self.size][index]
529 }
530}
531
532impl core::ops::DerefMut for TensorShape {
533 fn deref_mut(&mut self) -> &mut Self::Target {
534 &mut self.shape[..self.size]
535 }
536}
537
538impl core::fmt::Debug for TensorShape {
539 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
540 write!(f, "{:?}", &self.shape[..self.size])
541 }
542}
543
/// Conversion of shape-like values into a [`TensorShape`].
pub trait AsTensorShape {
    /// Builds a `TensorShape` from `self`.
    fn as_tensor_shape(&self) -> TensorShape;
}
549
550impl<T: AsRef<[i32]>> AsTensorShape for T {
551 fn as_tensor_shape(&self) -> TensorShape {
552 let this = self.as_ref();
553 let size = std::cmp::min(this.len(), 4);
554 let mut shape = [1; 4];
555 shape[..size].copy_from_slice(&this[..size]);
556 TensorShape { shape, size }
557 }
558}
559
560impl AsTensorShape for TensorShape {
561 fn as_tensor_shape(&self) -> TensorShape {
562 *self
563 }
564}
565
#[cfg(test)]
mod as_tensor_shape_tests {
    use super::AsTensorShape;

    /// Converts `$value` with `AsTensorShape` and asserts that the used dimensions equal
    /// `$expected`. `$t` and `$kind` only label the diagnostic output.
    macro_rules! shape_test {
        ($t:ty, $kind:expr, $value:expr, $expected:expr) => {{
            eprintln!("Testing {} with {} shape", stringify!($t), $kind);
            let shape = $value.as_tensor_shape();
            let dims: &[i32] = &shape;
            let expected: &[i32] = &$expected;
            assert_eq!(dims, expected);
        }};
    }

    #[test]
    fn as_tensor_shape_test_vec() {
        shape_test!(Vec<i32>, "small", vec![1, 2, 3], [1, 2, 3]);
        // Only the first four dimensions are kept.
        shape_test!(Vec<i32>, "large", vec![12, 23, 34, 45, 67], [12, 23, 34, 45]);
    }

    #[test]
    fn as_tensor_shape_test_array() {
        shape_test!([i32; 3], "small", [1, 2, 3], [1, 2, 3]);
        shape_test!([i32; 5], "large", [12, 23, 34, 45, 67], [12, 23, 34, 45]);
    }

    #[test]
    fn as_tensor_shape_test_ref() {
        shape_test!(&[i32], "small", &[1, 2, 3], [1, 2, 3]);
        shape_test!(&[i32], "large", &[12, 23, 34, 45, 67], [12, 23, 34, 45]);
    }
}
591
#[cfg(test)]
mod tensor_tests {
    use super::{Host, Tensor};

    /// `Tensor::from_ptr` must reject a null pointer by panicking.
    #[test]
    #[should_panic]
    fn unsafe_nullptr_tensor() {
        let null = core::ptr::null_mut();
        let _tensor = unsafe { Tensor::<Host<i32>>::from_ptr(null) };
    }
}
602
603impl<T: HostTensorType + RefTensorType> Tensor<T>
604where
605 T::H: HalideType,
606{
607 pub fn borrowed(shape: impl AsTensorShape, input: impl AsRef<[T::H]>) -> Self {
609 let shape = shape.as_tensor_shape();
610 let input = input.as_ref();
611 let tensor = unsafe {
612 Tensor_createWith(
613 shape.shape.as_ptr(),
614 shape.size,
615 halide_type_of::<T::H>(),
616 input.as_ptr().cast_mut().cast(),
617 DimensionType::Caffe.to_mnn_sys(),
618 )
619 };
620 debug_assert!(!tensor.is_null());
621 Self {
622 tensor,
623 __marker: PhantomData,
624 }
625 }
626
627 pub fn borrowed_mut(shape: impl AsTensorShape, mut input: impl AsMut<[T::H]>) -> Self {
629 let shape = shape.as_tensor_shape();
630 let input = input.as_mut();
631 let tensor = unsafe {
632 Tensor_createWith(
633 shape.shape.as_ptr(),
634 shape.size,
635 halide_type_of::<T::H>(),
636 input.as_mut_ptr().cast(),
637 DimensionType::Caffe.to_mnn_sys(),
638 )
639 };
640 debug_assert!(!tensor.is_null());
641 Self {
642 tensor,
643 __marker: PhantomData,
644 }
645 }
646}
647
/// A shared borrowed tensor reports the given shape and exposes the caller's data unchanged.
#[test]
fn test_tensor_borrowed() {
    let dims = [1, 2, 3];
    let values = vec![1, 2, 3, 4, 5, 6];
    let view = Tensor::<Ref<Host<i32>>>::borrowed(&dims, &values);
    assert_eq!(view.shape().as_ref(), dims);
    assert_eq!(view.host(), &values[..]);
}

/// Writing through a mutably borrowed tensor updates the caller's buffer in place.
#[test]
fn test_tensor_borrow_mut() {
    let dims = [1, 2, 3];
    let mut values = vec![1, 2, 3, 4, 5, 6];
    let mut view = Tensor::<RefMut<Host<i32>>>::borrowed_mut(&dims, &mut values);
    view.host_mut().fill(1);
    assert_eq!(values, &[1, 1, 1, 1, 1, 1]);
}