binarybana commented on a change in pull request #5527: URL: https://github.com/apache/incubator-tvm/pull/5527#discussion_r437744020
########## File path: rust/tvm-rt/src/function.rs ########## @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use lazy_static::lazy_static; +use std::convert::TryFrom; +use std::{ + collections::BTreeMap, + ffi::{CStr, CString}, + mem::{self, MaybeUninit}, + os::raw::{c_char, c_int}, + ptr, slice, str, + sync::Mutex, +}; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use crate::errors::Error; + +use super::to_boxed_fn::ToBoxedFn; +use super::to_function::{ToFunction, Typed}; + +pub type Result<T> = std::result::Result<T, Error>; + +lazy_static! 
{ + static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<String, Option<Function>>> = { + let mut out_size = 0 as c_int; + let mut names_ptr = ptr::null_mut() as *mut *const c_char; + check_call!(ffi::TVMFuncListGlobalNames( + &mut out_size as *mut _, + &mut names_ptr as *mut _, + )); + let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) }; + + let names_list: Vec<String> = + names_list + .iter() + .map(|&p| unsafe { CStr::from_ptr(p).to_str().unwrap().into() }) + .collect(); + + // println!("{:?}", &names_list); + + let names_list = names_list + .into_iter() + .map(|p| (p, None)) + .collect(); + + Mutex::new(names_list) + }; +} + +/// Wrapper around TVM function handle which includes `is_global` Review comment: Is this the same as or different from `PackedFunc`? It'd be nice to clarify one way or the other for people new to the project. ########## File path: rust/tvm-rt/src/context.rs ########## @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::os::raw::c_void; +use std::ptr; + +use crate::errors::Error; + +use tvm_sys::ffi; + +pub use tvm_sys::context::*; + +trait ContextExt { + /// Checks whether the context exists or not.
+ fn exist(&self) -> bool; + fn sync(&self) -> Result<(), Error>; + fn max_threads_per_block(&self) -> isize; + fn warp_size(&self) -> isize; + fn max_shared_memory_per_block(&self) -> isize; + fn compute_version(&self) -> isize; + fn device_name(&self) -> isize; + fn max_clock_rate(&self) -> isize; + fn multi_processor_count(&self) -> isize; + fn max_thread_dimensions(&self) -> isize; +} + +macro_rules! impl_device_attrs { + ($(($attr_name:ident, $attr_kind:expr));+) => { + $( + fn $attr_name(&self) -> isize { + get_device_attr(self.device_type as i32, self.device_id as i32, 0) + .expect("should not fail") as isize + } + + )+ + }; +} + +crate::external! { + #[name("runtime.GetDeviceAttr")] + fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32; +} + +impl ContextExt for Context { + fn exist(&self) -> bool { Review comment: ```suggestion fn exists(&self) -> bool { ``` ########## File path: rust/macros/src/external.rs ########## @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +use proc_macro2::Span; +use quote::quote; +use syn::parse::{Parse, ParseStream, Result}; + +use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; + +struct External { + tvm_name: String, + ident: Ident, + generics: Generics, + inputs: Vec<FnArg>, + ret_type: ReturnType, +} + +impl Parse for External { + fn parse(input: ParseStream) -> Result<Self> { + let method: TraitItemMethod = input.parse()?; + assert_eq!(method.attrs.len(), 1); + let sig = method.sig; + let tvm_name = method.attrs[0].parse_meta()?; + let tvm_name = match tvm_name { + Meta::List(meta_list) => { + let name = meta_list.path.get_ident().expect("name"); + assert_eq!(name.to_string(), "name".to_string()); + match meta_list.nested.first() { + Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(), + _ => panic!(), Review comment: Can you return `Err` here instead with some kind of breadcrumb message? Also 4 lines below. ########## File path: rust/macros/src/external.rs ########## @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +use proc_macro2::Span; +use quote::quote; +use syn::parse::{Parse, ParseStream, Result}; + +use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; + +struct External { Review comment: Even though this is private, a module or struct level rustdoc as to what this is and why it's here would be helpful. ########## File path: rust/tvm-rt/README.md ########## @@ -0,0 +1,60 @@ +<!--- Licensed to the Apache Software Foundation (ASF) under one --> +<!--- or more contributor license agreements. See the NOTICE file --> +<!--- distributed with this work for additional information --> +<!--- regarding copyright ownership. The ASF licenses this file --> +<!--- to you under the Apache License, Version 2.0 (the --> +<!--- "License"); you may not use this file except in compliance --> +<!--- with the License. You may obtain a copy of the License at --> + +<!--- http://www.apache.org/licenses/LICENSE-2.0 --> + +<!--- Unless required by applicable law or agreed to in writing, --> +<!--- software distributed under the License is distributed on an --> +<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> +<!--- KIND, either express or implied. See the License for the --> +<!--- specific language governing permissions and limitations --> +<!--- under the License. --> + +# TVM Runtime Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime. +Currently this is tested on `1.42.0` and above. + +## What Does This Crate Offer? + +TVM is an end-to-end deep learning compiler which takes high level machine learning +models or tensor computations and lowers them into executable code for a variety +of heterogenous devices (e.g., CPU, GPU). + +This crate provides access to the APIs for manipulating runtime data structures, +as well as TVM's cross-language Object system which functions similarly to systems +such as COM, enabling cross-language interoperability. 
+ +## Installations + +Please follow TVM [installation](https://tvm.apache.org/docs/install/index.html) instructions, +`export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +### Example of registering a cross-language closure. + +One can use `register!` macro to expose a Rust closure with arguments which implement `TryFrom<ArgValue>` +and return types which implement `Into<RetValue>`. Once registered with TVM these functions can be +accessed via Python or C++, or any other language which implements the packed function convention Review comment: ```suggestion accessed via Python or C++, or any other language which implements the TVM packed function convention ``` ########## File path: rust/tvm-rt/src/to_function.rs ########## @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. 
+ +use std::convert::{TryFrom, TryInto}; +use std::{ + mem::MaybeUninit, + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use super::{function::Result, Function}; +use crate::errors::Error; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +/// A trait representing whether the function arguments +/// and return type can be assigned to a TVM packed function. +/// +/// By splitting the conversion to function into two traits +/// we are able to improve error reporting, by splitting the +/// conversion of inputs and outputs to this trait. +/// +/// And the implementation of it to `ToFunction`. +pub trait Typed<I, O> { + fn args(i: &[ArgValue<'static>]) -> Result<I>; + fn ret(o: O) -> RetValue; +} + +impl<F, O: Into<RetValue>> Typed<(), O> for F +where + F: Fn() -> O, +{ + fn args(_args: &[ArgValue<'static>]) -> Result<()> { + debug_assert!(_args.len() == 0); + Ok(()) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl<F, A, O: Into<RetValue>, E> Typed<(A,), O> for F +where + F: Fn(A) -> O, + Error: From<E>, + A: TryFrom<ArgValue<'static>, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A,)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + Ok((a,)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl<F, A, B, O: Into<RetValue>, E> Typed<(A, B), O> for F +where + F: Fn(A, B) -> O, + Error: From<E>, + A: TryFrom<ArgValue<'static>, Error = E>, + B: TryFrom<ArgValue<'static>, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> { + debug_assert!(args.len() == 2); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + Ok((a, b)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl<F, A, B, C, O: Into<RetValue>, E> Typed<(A, B, C), O> for F +where + F: Fn(A, B, C) -> O, + Error: From<E>, + A: TryFrom<ArgValue<'static>, Error = E>, + B: TryFrom<ArgValue<'static>, Error = E>, + C: TryFrom<ArgValue<'static>, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) 
-> Result<(A, B, C)> { + debug_assert!(args.len() == 3); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + let c: C = args[2].clone().try_into()?; + Ok((a, b, c)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +pub trait ToFunction<I, O>: Sized { + type Handle; + + fn into_raw(self) -> *mut Self::Handle; + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> + where + Self: Typed<I, O>; + + fn drop(handle: *mut Self::Handle); + + fn to_function(self) -> Function + where + Self: Typed<I, O>, + { + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + let resource_handle = self.into_raw(); + check_call!(ffi::TVMFuncCreateFromCFunc( + Some(Self::tvm_callback), + resource_handle as *mut _, + Some(Self::tvm_finalizer), + &mut fhandle as *mut _ + )); + Function::new(fhandle) + } + + /// The callback function which is wrapped converted by TVM + /// into a packed function stored in fhandle. + unsafe extern "C" fn tvm_callback( + args: *mut ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: ffi::TVMRetValueHandle, + fhandle: *mut c_void, + ) -> c_int + where + Self: Typed<I, O>, + { + // turning off the incorrect linter complaints + #![allow(unused_assignments, unused_unsafe)] + let len = num_args as usize; + let args_list = slice::from_raw_parts_mut(args, len); + let type_codes_list = slice::from_raw_parts_mut(type_codes, len); + let mut local_args: Vec<ArgValue> = Vec::new(); + let mut value = MaybeUninit::uninit().assume_init(); + let mut tcode = MaybeUninit::uninit().assume_init(); + let rust_fn = fhandle as *mut Self::Handle; + for i in 0..len { + value = args_list[i]; + println!("{:?}", value.v_handle); + tcode = type_codes_list[i]; + if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int + { + check_call!(ffi::TVMCbArgToReturn( + &mut value as *mut _, 
+ &mut tcode as *mut _ + )); + println!("{:?}", value.v_handle); + } + let arg_value = ArgValue::from_tvm_value(value, tcode as u32); + println!("{:?}", arg_value); + local_args.push(arg_value); + } + + let rv = match Self::call(rust_fn, local_args.as_slice()) { + Ok(v) => v, + Err(msg) => { + crate::set_last_error(&msg); + return -1; + } + }; + + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); + let mut ret_type_code = ret_tcode as c_int; + check_call!(ffi::TVMCFuncSetReturn( + ret, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _, + 1 as c_int + )); + 0 + } + + /// The finalizer which is invoked when the packed function's + /// reference count is zero. + unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) { + let handle = std::mem::transmute(fhandle); + Self::drop(handle) + } +} + +// /// A wrapper that is used to work around inference issues for bare functions. Review comment: Remove this commented-out code? ########## File path: rust/tvm-rt/src/ndarray.rs ########## @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements the [`NDArray`] type for working with *TVM tensors* or +//! coverting from a Rust's ndarray to TVM `NDArray`. +//! +//!
One can create an empty NDArray given the shape, context and dtype using [`empty`]. +//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +//! To copy an NDArray to different context use [`copy_to_ctx`]. +//! +//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +//! +//! # Example +//! +//! ``` +//! # use tvm_rt::{NDArray, Context, DataType}; +//! # use ndarray::{Array, ArrayD}; +//! # use std::str::FromStr; +//! use std::convert::TryFrom; +//! +//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +//! .unwrap() +//! .into_dyn(); // Rust's ndarray +//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); +//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap(); +//! assert!(rnd.all_close(&a, 1e-8f32)); +//! ``` +//! +//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +use std::convert::TryInto; +use std::ffi::c_void; +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; + +use crate::errors::NDArrayError; + +use tvm_sys::ffi::DLTensor; +use tvm_sys::{ffi, ByteArray, Context, DataType}; + +use ndarray::{Array, ArrayD}; +use num_traits::Num; + +/// See the [`module-level documentation`](../ndarray/index.html) for more details. +/// +/// Wrapper around TVM array handle. 
+#[derive(Debug)] +pub enum NDArray { + Borrowed { handle: ffi::TVMArrayHandle }, + Owned { handle: *mut c_void }, +} + +impl NDArray { + pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { + NDArray::Borrowed { handle } + } + + pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self { + NDArray::Owned { handle } + } + + pub fn as_dltensor(&self) -> &DLTensor { + unsafe { + match self { + NDArray::Borrowed { ref handle } => std::mem::transmute(*handle), + NDArray::Owned { ref handle } => std::mem::transmute(*handle), + } + } + } + + pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { + unsafe { + match self { + NDArray::Borrowed { ref handle } => std::mem::transmute(*handle), + NDArray::Owned { ref handle } => std::mem::transmute(*handle), + } + } + } + + pub fn is_view(&self) -> bool { + if let &NDArray::Borrowed { .. } = self { + true + } else { + false + } + } + + /// Returns the shape of the NDArray. + pub fn shape(&self) -> Option<&mut [usize]> { + let arr = self.as_dltensor(); + if arr.shape.is_null() || arr.data.is_null() { + return None; + }; + let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; + Some(slc) + } + + /// Returns the total number of entries of the NDArray. + pub fn size(&self) -> Option<usize> { + self.shape().map(|v| v.iter().product()) + } + + /// Returns the context which the NDArray was defined. + pub fn ctx(&self) -> Context { + self.as_dltensor().ctx.into() + } + + /// Returns the type of the entries of the NDArray. + pub fn dtype(&self) -> DataType { + self.as_dltensor().dtype.into() + } + + /// Returns the number of dimensions of the NDArray. + pub fn ndim(&self) -> usize { + self.as_dltensor() + .ndim + .try_into() + .expect("number of dimensions must always be positive") + } + + /// Returns the strides of the underlying NDArray. 
+ pub fn strides(&self) -> Option<&[usize]> { + unsafe { + let sz = self.ndim() * mem::size_of::<usize>(); + let strides_ptr = self.as_dltensor().strides as *const usize; + let slc = slice::from_raw_parts(strides_ptr, sz); + Some(slc) + } + } + + /// Shows whether the underlying ndarray is contiguous in memory or not. + pub fn is_contiguous(&self) -> Result<bool, crate::errors::Error> { + Ok(match self.strides() { + None => true, + Some(strides) => { + // NDArrayError::MissingShape in case shape is not determined + self.shape() + .ok_or(NDArrayError::MissingShape)? + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + }) + } + + pub fn byte_offset(&self) -> isize { + self.as_dltensor().byte_offset as isize + } + + /// Flattens the NDArray to a `Vec` of the same type in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let mut shape = [4]; + /// let mut data = vec![1i32, 2, 3, 4]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); + /// ``` + pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> { + if !self.shape().is_some() { + return Err(NDArrayError::EmptyArray); + } + let earr = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + Context::cpu(0), + self.dtype(), + ); + let target = self.copy_to_ndarray(earr)?; + let arr = target.as_dltensor(); + let sz = self.size().ok_or(NDArrayError::MissingShape)?; + let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>()); + unsafe { + v.as_mut_ptr() + .copy_from_nonoverlapping(arr.data as *const T, sz); + v.set_len(sz); + } + Ok(v) 
+ } + + /// Converts the NDArray to [`ByteArray`]. + pub fn to_bytearray(&self) -> Result<ByteArray, NDArrayError> { + let v = self.to_vec::<u8>()?; + Ok(ByteArray::from(v)) + } + + /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let shape = &mut [2]; + /// let mut data = vec![1f32, 2.0]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// ``` + /// + /// *Note*: if something goes wrong during the copy, it will panic + /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. + pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) { + check_call!(ffi::TVMArrayCopyFromBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + data.len() * mem::size_of::<T>() + )); + } + + /// Copies the NDArray to another target NDArray. + pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, NDArrayError> { + if self.dtype() != target.dtype() { + return Err(NDArrayError::DataTypeMismatch { + expected: self.dtype(), + actual: target.dtype(), + }); + } + + check_call!(ffi::TVMArrayCopyFromTo( + self.as_raw_dltensor(), + target.as_raw_dltensor(), + ptr::null_mut() as ffi::TVMStreamHandle + )); + Ok(target) + } + + /// Copies the NDArray to a target context. + pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray, NDArrayError> { + let tmp = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + *target, + self.dtype(), + ); + let copy = self.copy_to_ndarray(tmp)?; + Ok(copy) + } + + /// Converts a Rust's ndarray to TVM NDArray. 
+ pub fn from_rust_ndarray<T: Num32 + Copy>( + rnd: &ArrayD<T>, + ctx: Context, + dtype: DataType, + ) -> Result<Self, NDArrayError> { + let shape = rnd.shape().to_vec(); + let mut nd = NDArray::empty(&shape, ctx, dtype); + let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); + Ok(nd) + } + + /// Allocates and creates an empty NDArray given the shape, context and dtype. + pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + let dtype: tvm_sys::ffi::DLDataType = dtype.into(); + check_call!(ffi::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + i32::from(dtype.code) as c_int, + i32::from(dtype.bits) as c_int, + i32::from(dtype.lanes) as c_int, + ctx.device_type as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + NDArray::Borrowed { handle: handle } + } +} + +macro_rules! impl_from_ndarray_rustndarray { + ($type:ty, $type_name:tt) => { + impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + } + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) + } + } + + impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + }; + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) 
+ } + } + }; +} + +impl_from_ndarray_rustndarray!(i32, "int"); +impl_from_ndarray_rustndarray!(u32, "uint"); +impl_from_ndarray_rustndarray!(f32, "float"); + +impl Drop for NDArray { + fn drop(&mut self) { + if let &mut NDArray::Owned { .. } = self { + check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); + } + } +} + +mod sealed { + /// Private trait to prevent other traits from being implemeneted in downstream crates. + pub trait Sealed {} Review comment: Do you want to use `pub(crate)` to prevent accidental leakage/exporting? ########## File path: rust/tvm-rt/src/context.rs ########## @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::os::raw::c_void; +use std::ptr; + +use crate::errors::Error; + +use tvm_sys::ffi; + +pub use tvm_sys::context::*; + +trait ContextExt { + /// Checks whether the context exists or not. + fn exist(&self) -> bool; Review comment: ```suggestion fn exists(&self) -> bool; ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at: [email protected]
