use crate::prelude::*;
use core::marker::PhantomData;
use mnn_sys::HalideType;
/// Untyped wrapper around a raw MNN tensor pointer (`*mut mnn_sys::Tensor`).
///
/// `#[repr(transparent)]` gives this struct the same layout as the `inner`
/// pointer (the `PhantomData` field is zero-sized). The lifetime `'r` is only
/// carried by `__marker`; presumably it ties a borrowed tensor to whatever owns
/// it (e.g. an interpreter/session) — TODO confirm against the constructors.
#[repr(transparent)]
pub struct RawTensor<'r> {
/// Raw pointer to the underlying MNN tensor.
pub(crate) inner: *mut mnn_sys::Tensor,
/// Zero-sized marker that carries the lifetime `'r`.
pub(crate) __marker: PhantomData<&'r ()>,
}
impl RawTensor<'_> {
    /// Creates a host tensor from this tensor via
    /// `mnn_sys::Tensor_createHostTensorFromDevice`.
    ///
    /// `copy_data` is forwarded to MNN as `0`/`1`; presumably it controls whether
    /// the contents are copied into the new host tensor — confirm against the
    /// MNN documentation.
    ///
    /// The result is `RawTensor<'static>`, i.e. it is not tied to the borrow of
    /// `self`. NOTE(review): the caller presumably owns the new tensor and is
    /// responsible for releasing it (e.g. with [`RawTensor::destroy`]) — confirm.
    ///
    /// # Panics
    ///
    /// Panics if the FFI call returns a null pointer.
    pub fn create_host_tensor_from_device(&self, copy_data: bool) -> RawTensor<'static> {
        // SAFETY: assumes `self.inner` is a valid MNN tensor pointer (not checked here).
        let tensor =
            unsafe { mnn_sys::Tensor_createHostTensorFromDevice(self.inner, copy_data as i32) };
        assert!(!tensor.is_null());
        RawTensor {
            inner: tensor,
            __marker: PhantomData,
        }
    }

    /// Copies the contents of the host tensor `tensor` into `self` via
    /// `mnn_sys::Tensor_copyFromHostTensor`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::TensorCopyFailed` carrying the raw FFI return code
    /// when the call's result fails the `ret != 0` check.
    pub fn copy_from_host_tensor(&mut self, tensor: &RawTensor) -> Result<()> {
        // SAFETY: assumes both `self.inner` and `tensor.inner` are valid MNN tensors.
        let ret = unsafe { mnn_sys::Tensor_copyFromHostTensor(self.inner, tensor.inner) };
        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
        Ok(())
    }

    /// Copies the contents of `self` into the host tensor `tensor` via
    /// `mnn_sys::Tensor_copyToHostTensor`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::TensorCopyFailed` carrying the raw FFI return code
    /// when the call's result fails the `ret != 0` check.
    pub fn copy_to_host_tensor(&self, tensor: &mut RawTensor) -> Result<()> {
        // SAFETY: assumes both `self.inner` and `tensor.inner` are valid MNN tensors.
        let ret = unsafe { mnn_sys::Tensor_copyToHostTensor(self.inner, tensor.inner) };
        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
        Ok(())
    }

    /// Returns the tensor's shape as reported by `mnn_sys::Tensor_shape`,
    /// converted into [`crate::TensorShape`].
    pub fn shape(&self) -> crate::TensorShape {
        unsafe { mnn_sys::Tensor_shape(self.inner) }.into()
    }

    /// Returns the tensor's dimension type (`mnn_sys::Tensor_getDimensionType`),
    /// converted into `super::DimensionType`.
    pub fn get_dimension_type(&self) -> super::DimensionType {
        debug_assert!(!self.inner.is_null());
        From::from(unsafe { mnn_sys::Tensor_getDimensionType(self.inner) })
    }

    /// Releases the underlying tensor by calling `mnn_sys::Tensor_destroy`,
    /// consuming `self`.
    ///
    /// NOTE(review): no `Drop` impl is visible alongside this type, so the tensor
    /// is presumably only freed when this is called explicitly. It should
    /// presumably not be called on tensors owned by another MNN object — confirm.
    pub fn destroy(self) {
        unsafe {
            mnn_sys::Tensor_destroy(self.inner);
        }
    }

    /// Returns the value of `mnn_sys::Tensor_usize`.
    ///
    /// [`RawTensor::unchecked_host_bytes`] uses this as the length of the
    /// tensor's host buffer in bytes.
    pub fn size(&self) -> usize {
        unsafe { mnn_sys::Tensor_usize(self.inner) }
    }

    /// Returns `mnn_sys::Tensor_elementSize` as a `usize` (presumably the number
    /// of elements — confirm).
    pub fn element_size(&self) -> usize {
        unsafe { mnn_sys::Tensor_elementSize(self.inner) as usize }
    }

    /// Returns the number of dimensions reported by `mnn_sys::Tensor_dimensions`.
    pub fn dimensions(&self) -> usize {
        unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize }
    }

    /// Returns the width reported by `mnn_sys::Tensor_width`.
    ///
    /// NOTE(review): the FFI value is converted with `as u32`; if it is signed, a
    /// dynamic `-1` extent (see [`RawTensor::is_dynamic_unsized`]) wraps to `u32::MAX`.
    pub fn width(&self) -> u32 {
        unsafe { mnn_sys::Tensor_width(self.inner) as u32 }
    }

    /// Returns the height reported by `mnn_sys::Tensor_height`.
    ///
    /// NOTE(review): converted with `as u32`; see the note on [`RawTensor::width`].
    pub fn height(&self) -> u32 {
        unsafe { mnn_sys::Tensor_height(self.inner) as u32 }
    }

    /// Returns the channel count reported by `mnn_sys::Tensor_channel`.
    ///
    /// NOTE(review): converted with `as u32`; see the note on [`RawTensor::width`].
    pub fn channel(&self) -> u32 {
        unsafe { mnn_sys::Tensor_channel(self.inner) as u32 }
    }

    /// Returns `true` if any extent in [`RawTensor::shape`] is `-1`.
    pub fn is_dynamic_unsized(&self) -> bool {
        self.shape().as_ref().contains(&-1)
    }

    /// Forwards to `mnn_sys::Tensor_wait` with `map_type` and `finish` (as `0`/`1`).
    ///
    /// Presumably blocks until pending work on the tensor for the given map type
    /// has completed — confirm against the MNN documentation.
    pub fn wait(&self, map_type: MapType, finish: bool) {
        unsafe {
            mnn_sys::Tensor_wait(self.inner, map_type, finish as i32);
        }
    }

    /// Returns the tensor's raw host data pointer (`mnn_sys::Tensor_host_mut`).
    ///
    /// # Safety
    ///
    /// `self.inner` must point to a valid tensor. The returned pointer is only
    /// checked for null with `debug_assert!`, so in release builds it may be
    /// null; the caller must not dereference it unless it is non-null and the
    /// tensor is still alive.
    pub unsafe fn unchecked_host_ptr(&self) -> *mut c_void {
        debug_assert!(!self.inner.is_null());
        // SAFETY: the caller guarantees `self.inner` is a valid tensor.
        let data = unsafe { mnn_sys::Tensor_host_mut(self.inner) };
        debug_assert!(!data.is_null());
        data
    }

    /// Views the tensor's host memory as a mutable byte slice of
    /// [`RawTensor::size`] bytes.
    ///
    /// When `size()` is `0`, an empty slice is returned without reading the host
    /// pointer: `core::slice::from_raw_parts_mut` requires a non-null pointer even
    /// for zero-length slices, and an empty tensor may have no host buffer.
    ///
    /// # Safety
    ///
    /// For a non-empty tensor, `self.inner` must be a valid tensor whose host
    /// pointer is non-null and valid for reads and writes of `size()` bytes for
    /// the lifetime of the returned slice. No other pointer or reference may
    /// access that memory while the slice is alive.
    pub unsafe fn unchecked_host_bytes(&mut self) -> &mut [u8] {
        let len = self.size();
        if len == 0 {
            return &mut [];
        }
        // SAFETY: validity and exclusivity of `len` bytes at the host pointer are
        // guaranteed by the caller (see `# Safety`); `u8` has alignment 1.
        unsafe { core::slice::from_raw_parts_mut(self.unchecked_host_ptr().cast(), len) }
    }

    /// Reinterprets this untyped tensor as a typed `super::Tensor<T>` built from
    /// the same raw pointer via `super::Tensor::from_ptr`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the tensor's element type actually matches `T`.
    /// NOTE(review): `super::Tensor::from_ptr` is defined elsewhere; any further
    /// requirements it has (e.g. about ownership of the pointer) must also be
    /// upheld — confirm.
    pub unsafe fn to_concrete<T: super::TensorType>(self) -> super::Tensor<T>
    where
        T::H: HalideType,
    {
        super::Tensor::from_ptr(self.inner)
    }

    /// Wraps a raw MNN tensor pointer.
    ///
    /// The pointer is not checked for null, and the lifetime is chosen by the
    /// caller's context.
    pub(crate) fn from_ptr(inner: *mut mnn_sys::Tensor) -> Self {
        Self {
            inner,
            __marker: PhantomData,
        }
    }
}