This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch rust-tvm
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit fc6fac254d02f91dae146c81c618cd17f8bf9d3c
Author: Jared Roesch <jroe...@octoml.ai>
AuthorDate: Sat Jun 6 21:54:58 2020 -0700

    Reworking errors and proc macros
---
 rust/macros/src/lib.rs               |  7 +++++
 rust/tvm-rt/src/errors.rs            |  5 +++-
 rust/tvm-rt/src/function.rs          | 53 ++++++++++--------------------------
 rust/tvm-rt/src/ndarray.rs           | 23 ++++++++++------
 rust/tvm-rt/src/object/mod.rs        |  2 +-
 rust/tvm-rt/src/object/object_ptr.rs |  2 +-
 rust/tvm-rt/src/to_function.rs       |  2 +-
 rust/tvm-rt/src/value.rs             |  6 +---
 rust/tvm/src/ir/array.rs             |  5 ++--
 rust/tvm/src/lib.rs                  |  9 ++----
 10 files changed, 48 insertions(+), 66 deletions(-)

diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs
index e9ddc25..d0ac1ca 100644
--- a/rust/macros/src/lib.rs
+++ b/rust/macros/src/lib.rs
@@ -18,6 +18,8 @@
  */
 
 use proc_macro::TokenStream;
+
+mod external;
 mod import_module;
 mod object;
 
@@ -31,3 +33,8 @@ pub fn macro_impl(input: TokenStream) -> TokenStream {
     // let input = proc_macro2::TokenStream::from(input);
     TokenStream::from(object::macro_impl(input))
 }
+
+#[proc_macro]
+pub fn external(input: TokenStream) -> TokenStream {
+    external::macro_impl(input)
+}
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 41e873f..f081258 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -48,7 +48,10 @@ pub enum NDArrayError {
     #[error("a shape error occurred in the Rust ndarray library")]
     ShapeError(#[from] ndarray::ShapeError),
     #[error("Expected type `{expected}` but found `{actual}`")]
-    DataTypeMismatch { expected: DataType, actual: DataType }
+    DataTypeMismatch {
+        expected: DataType,
+        actual: DataType,
+    },
 }
 
 #[derive(Debug, Error)]
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index 17f5f6e..b0122ff 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -25,6 +25,9 @@
 //!
 //! See the tests and examples repository for more examples.
 
+use anyhow::Result;
+use lazy_static::lazy_static;
+use std::convert::TryFrom;
 use std::{
     collections::BTreeMap,
     ffi::{CStr, CString},
@@ -33,9 +36,6 @@ use std::{
     ptr, slice, str,
     sync::Mutex,
 };
-use std::convert::{TryFrom};
-use anyhow::Result;
-use lazy_static::lazy_static;
 
 pub use tvm_sys::{ffi, ArgValue, RetValue};
 
@@ -194,7 +194,10 @@ impl TryFrom<RetValue> for Function {
     fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> {
         match ret_value {
             RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
-            _ => Err(Error::downcast(format!("{:?}", ret_value), "FunctionHandle"))
+            _ => Err(Error::downcast(
+                format!("{:?}", ret_value),
+                "FunctionHandle",
+            )),
         }
     }
 }
@@ -211,7 +214,10 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
     fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error> {
         match arg_value {
             ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
-            _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")),
+            _ => Err(Error::downcast(
+                format!("{:?}", arg_value),
+                "FunctionHandle",
+            )),
         }
     }
 }
@@ -222,7 +228,10 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
     fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error> {
         match arg_value {
             ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
-            _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")),
+            _ => Err(Error::downcast(
+                format!("{:?}", arg_value),
+                "FunctionHandle",
+            )),
         }
     }
 }
@@ -286,38 +295,6 @@ where
     Ok(())
 }
 
-#[macro_export]
-macro_rules! external_func_impl {
-    ($name:ident , $($ty_param:tt)* , ( $($arg:ident : $ty:ty),* ), $ret_type:ty, $ext_name:literal) => {
-        ::paste::item! {
-            #[allow(non_upper_case_globals)]
-            static [<global_ $name>]: ::once_cell::sync::Lazy<&'static $crate::Function> =
-            ::once_cell::sync::Lazy::new(|| {
-                $crate::Function::get($ext_name)
-                .expect(concat!("unable to load external function", stringify!($ext_name), "from TVM registry."))
-            });
-        }
-
-        pub fn $name<$($ty_param),*>($($arg : $ty),*) -> anyhow::Result<$ret_type> w,* {
-            let func_ref: &$crate::Function = ::paste::expr! { &*[<global_ $name>] };
-            let func_ref: Box<dyn Fn($($ty),*) -> anyhow::Result<$ret_type>> = func_ref.to_boxed_fn();
-            let res: $ret_type = func_ref($($arg),*)?;
-            Ok(res)
-        }
-    }
-}
-
-
-#[macro_export]
-macro_rules! external_func {
-    (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => {
-        $crate::external_func_impl!($name, , ( $($arg : $ty),* ) , $ret_type, $ext_name);
-    };
-    (fn $name:ident < $($ty_param:ident),* > ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => {
-        $crate::external_func_impl!($name, $($ty_param:ident),* , ( $($arg : $ty),* ) , $ret_type, $ext_name);
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index f97b3a4..593154d 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -47,9 +47,9 @@
 //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
 //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
 
-use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
 use std::convert::TryInto;
 use std::ffi::c_void;
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
 
 use crate::errors::NDArrayError;
 
@@ -190,7 +190,9 @@ impl NDArray {
     /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
     /// ```
     pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> {
-        if self.shape().is_some() { return Err(NDArrayError::EmptyArray); }
+        if self.shape().is_none() {
+            return Err(NDArrayError::EmptyArray);
+        }
         let earr = NDArray::empty(
             self.shape().ok_or(NDArrayError::MissingShape)?,
             Context::cpu(0),
@@ -241,11 +243,10 @@ impl NDArray {
     /// Copies the NDArray to another target NDArray.
     pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, NDArrayError> {
         if self.dtype() != target.dtype() {
-            return Err(
-                NDArrayError::DataTypeMismatch {
-                    expected: self.dtype(),
-                    actual: target.dtype()
-                });
+            return Err(NDArrayError::DataTypeMismatch {
+                expected: self.dtype(),
+                actual: target.dtype(),
+            });
         }
 
         check_call!(ffi::TVMArrayCopyFromTo(
@@ -307,7 +308,9 @@ macro_rules! impl_from_ndarray_rustndarray {
             type Error = NDArrayError;
 
             fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> {
-                if nd.shape().is_some() { return Err(NDArrayError::MissingShape); }
+                if nd.shape().is_none() {
+                    return Err(NDArrayError::MissingShape);
+                }
                 assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch");
                 Ok(Array::from_shape_vec(
                     &*nd.shape().ok_or(NDArrayError::MissingShape)?,
@@ -320,7 +323,9 @@ macro_rules! impl_from_ndarray_rustndarray {
             type Error = NDArrayError;
 
             fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> {
-                if nd.shape().is_some() { return Err(NDArrayError::MissingShape) };
+                if nd.shape().is_none() {
+                    return Err(NDArrayError::MissingShape);
+                }
                 assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch");
                 Ok(Array::from_shape_vec(
                     &*nd.shape().ok_or(NDArrayError::MissingShape)?,
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 32da18e..9dcf836 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -2,8 +2,8 @@ use std::convert::TryFrom;
 use std::convert::TryInto;
 use std::ffi::CString;
 
-use crate::external_func;
 use crate::errors::Error;
+use crate::external_func;
 
 use tvm_sys::{ArgValue, RetValue};
 
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index 8e91878..ead37e3 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -194,7 +194,7 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
                 let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
                 optr.downcast()
             }
-            _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle"))
+            _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")),
         }
     }
 }
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index dac37c8..0527b0c 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -32,8 +32,8 @@ use std::{
     ptr, slice,
 };
 
-use crate::errors::Error;
 use super::Function;
+use crate::errors::Error;
 
 pub use tvm_sys::{ffi, ArgValue, RetValue};
 
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
index d9436b1..1812c0c 100644
--- a/rust/tvm-rt/src/value.rs
+++ b/rust/tvm-rt/src/value.rs
@@ -25,11 +25,7 @@ use std::convert::TryFrom;
 // use std::ffi::c_void;
 
 use crate::{ArgValue, Module, NDArray, RetValue};
-use tvm_sys::{
-    errors::ValueDowncastError,
-    ffi::{TVMModuleHandle},
-    try_downcast,
-};
+use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
 
 macro_rules! impl_handle_val {
     ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs
index a426474..bd12252 100644
--- a/rust/tvm/src/ir/array.rs
+++ b/rust/tvm/src/ir/array.rs
@@ -1,14 +1,13 @@
-use std::convert::{TryFrom};
+use std::convert::TryFrom;
 use std::marker::PhantomData;
 
 use crate::runtime::object::{ObjectRef, ToObjectRef};
 
-use tvm_rt::RetValue;
 use tvm_rt::external_func;
+use tvm_rt::RetValue;
 
 use anyhow::Result;
 
-
 #[derive(Clone)]
 pub struct Array<T: ToObjectRef> {
     object: ObjectRef,
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index 9315f7c..64252a4 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -30,14 +30,9 @@
 //!
 //! Checkout the `examples` repository for more details.
 
-pub use crate::{
-    errors::*,
-    function::Function,
-    module::Module,
-    ndarray::NDArray,
-};
+pub use crate::{errors::*, function::Function, module::Module, ndarray::NDArray};
 
-pub use tvm_rt::{Context, DeviceType, DataType};
+pub use tvm_rt::{Context, DataType, DeviceType};
 
 pub use tvm_rt::context;
 pub use tvm_rt::errors;

Reply via email to