This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch rust-tvm in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit fc6fac254d02f91dae146c81c618cd17f8bf9d3c Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Sat Jun 6 21:54:58 2020 -0700 Reworking errors and proc macros --- rust/macros/src/lib.rs | 7 +++++ rust/tvm-rt/src/errors.rs | 5 +++- rust/tvm-rt/src/function.rs | 53 ++++++++++-------------------------- rust/tvm-rt/src/ndarray.rs | 23 ++++++++++------ rust/tvm-rt/src/object/mod.rs | 2 +- rust/tvm-rt/src/object/object_ptr.rs | 2 +- rust/tvm-rt/src/to_function.rs | 2 +- rust/tvm-rt/src/value.rs | 6 +--- rust/tvm/src/ir/array.rs | 5 ++-- rust/tvm/src/lib.rs | 9 ++---- 10 files changed, 48 insertions(+), 66 deletions(-) diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs index e9ddc25..d0ac1ca 100644 --- a/rust/macros/src/lib.rs +++ b/rust/macros/src/lib.rs @@ -18,6 +18,8 @@ */ use proc_macro::TokenStream; + +mod external; mod import_module; mod object; @@ -31,3 +33,8 @@ pub fn macro_impl(input: TokenStream) -> TokenStream { // let input = proc_macro2::TokenStream::from(input); TokenStream::from(object::macro_impl(input)) } + +#[proc_macro] +pub fn external(input: TokenStream) -> TokenStream { + external::macro_impl(input) +} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index 41e873f..f081258 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -48,7 +48,10 @@ pub enum NDArrayError { #[error("a shape error occurred in the Rust ndarray library")] ShapeError(#[from] ndarray::ShapeError), #[error("Expected type `{expected}` but found `{actual}`")] - DataTypeMismatch { expected: DataType, actual: DataType } + DataTypeMismatch { + expected: DataType, + actual: DataType, + }, } #[derive(Debug, Error)] diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 17f5f6e..b0122ff 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -25,6 +25,9 @@ //! //! See the tests and examples repository for more examples. +use anyhow::Result; +use lazy_static::lazy_static; +use std::convert::TryFrom; use std::{ collections::BTreeMap, ffi::{CStr, CString}, @@ -33,9 +36,6 @@ use std::{ ptr, slice, str, sync::Mutex, }; -use std::convert::{TryFrom}; -use anyhow::Result; -use lazy_static::lazy_static; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -194,7 +194,10 @@ impl TryFrom<RetValue> for Function { fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> { match ret_value { RetValue::FuncHandle(handle) => Ok(Function::new(handle)), - _ => Err(Error::downcast(format!("{:?}", ret_value), "FunctionHandle")) + _ => Err(Error::downcast( + format!("{:?}", ret_value), + "FunctionHandle", + )), } } } @@ -211,7 +214,10 @@ impl<'a> TryFrom<ArgValue<'a>> for Function { fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error> { match arg_value { ArgValue::FuncHandle(handle) => Ok(Function::new(handle)), - _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")), + _ => Err(Error::downcast( + format!("{:?}", arg_value), + "FunctionHandle", + )), } } } @@ -222,7 +228,10 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error> { match arg_value { ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)), - _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")), + _ => Err(Error::downcast( + format!("{:?}", arg_value), + "FunctionHandle", + )), } } } @@ -286,38 +295,6 @@ where Ok(()) } -#[macro_export] -macro_rules! external_func_impl { - ($name:ident , $($ty_param:tt)* , ( $($arg:ident : $ty:ty),* ), $ret_type:ty, $ext_name:literal) => { - ::paste::item! { - #[allow(non_upper_case_globals)] - static [<global_ $name>]: ::once_cell::sync::Lazy<&'static $crate::Function> = - ::once_cell::sync::Lazy::new(|| { - $crate::Function::get($ext_name) - .expect(concat!("unable to load external function", stringify!($ext_name), "from TVM registry.")) - }); - } - - pub fn $name<$($ty_param),*>($($arg : $ty),*) -> anyhow::Result<$ret_type> w,* { - let func_ref: &$crate::Function = ::paste::expr! { &*[<global_ $name>] }; - let func_ref: Box<dyn Fn($($ty),*) -> anyhow::Result<$ret_type>> = func_ref.to_boxed_fn(); - let res: $ret_type = func_ref($($arg),*)?; - Ok(res) - } - } -} - - -#[macro_export] -macro_rules! external_func { - (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { - $crate::external_func_impl!($name, , ( $($arg : $ty),* ) , $ret_type, $ext_name); - }; - (fn $name:ident < $($ty_param:ident),* > ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { - $crate::external_func_impl!($name, $($ty_param:ident),* , ( $($arg : $ty),* ) , $ret_type, $ext_name); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index f97b3a4..593154d 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -47,9 +47,9 @@ //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx -use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; use std::convert::TryInto; use std::ffi::c_void; +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; use crate::errors::NDArrayError; @@ -190,7 +190,9 @@ impl NDArray { /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); /// ``` pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> { - if self.shape().is_some() { return Err(NDArrayError::EmptyArray); } + if self.shape().is_some() { + return Err(NDArrayError::EmptyArray); + } let earr = NDArray::empty( self.shape().ok_or(NDArrayError::MissingShape)?, Context::cpu(0), @@ -241,11 +243,10 @@ impl NDArray { /// Copies the NDArray to another target NDArray. pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, NDArrayError> { if self.dtype() != target.dtype() { - return Err( - NDArrayError::DataTypeMismatch { - expected: self.dtype(), - actual: target.dtype() - }); + return Err(NDArrayError::DataTypeMismatch { + expected: self.dtype(), + actual: target.dtype(), + }); } check_call!(ffi::TVMArrayCopyFromTo( @@ -307,7 +308,9 @@ macro_rules! impl_from_ndarray_rustndarray { type Error = NDArrayError; fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> { - if nd.shape().is_some() { return Err(NDArrayError::MissingShape); } + if nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + } assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( &*nd.shape().ok_or(NDArrayError::MissingShape)?, @@ -320,7 +323,9 @@ macro_rules! impl_from_ndarray_rustndarray { type Error = NDArrayError; fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> { - if nd.shape().is_some() { return Err(NDArrayError::MissingShape) }; + if nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + }; assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( &*nd.shape().ok_or(NDArrayError::MissingShape)?, diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 32da18e..9dcf836 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -2,8 +2,8 @@ use std::convert::TryFrom; use std::convert::TryInto; use std::ffi::CString; -use crate::external_func; use crate::errors::Error; +use crate::external_func; use tvm_sys::{ArgValue, RetValue}; diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 8e91878..ead37e3 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -194,7 +194,7 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> { let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; optr.downcast() } - _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")) + _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), } } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index dac37c8..0527b0c 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -32,8 +32,8 @@ use std::{ ptr, slice, }; -use crate::errors::Error; use super::Function; +use crate::errors::Error; pub use tvm_sys::{ffi, ArgValue, RetValue}; diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs index d9436b1..1812c0c 100644 --- a/rust/tvm-rt/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -25,11 +25,7 @@ use std::convert::TryFrom; // use std::ffi::c_void; use crate::{ArgValue, Module, NDArray, RetValue}; -use tvm_sys::{ - errors::ValueDowncastError, - ffi::{TVMModuleHandle}, - try_downcast, -}; +use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; macro_rules! impl_handle_val { ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs index a426474..bd12252 100644 --- a/rust/tvm/src/ir/array.rs +++ b/rust/tvm/src/ir/array.rs @@ -1,14 +1,13 @@ -use std::convert::{TryFrom}; +use std::convert::TryFrom; use std::marker::PhantomData; use crate::runtime::object::{ObjectRef, ToObjectRef}; -use tvm_rt::RetValue; use tvm_rt::external_func; +use tvm_rt::RetValue; use anyhow::Result; - #[derive(Clone)] pub struct Array<T: ToObjectRef> { object: ObjectRef, diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index 9315f7c..64252a4 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -30,14 +30,9 @@ //! //! Checkout the `examples` repository for more details. -pub use crate::{ - errors::*, - function::Function, - module::Module, - ndarray::NDArray, -}; +pub use crate::{errors::*, function::Function, module::Module, ndarray::NDArray}; -pub use tvm_rt::{Context, DeviceType, DataType}; +pub use tvm_rt::{Context, DataType, DeviceType}; pub use tvm_rt::context; pub use tvm_rt::errors;