This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch rust-tvm in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 0c55c39477979c75f2bf3e2e9974d90fde74fa26 Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Mon Jun 8 13:56:28 2020 -0700 Finish removing anyhow and work with new external! macro --- rust/tvm-rt/src/context.rs | 12 ++++++++---- rust/tvm-rt/src/errors.rs | 14 ++++++++------ rust/tvm-rt/src/function.rs | 12 ++++++------ rust/tvm-rt/src/module.rs | 10 +++++----- rust/tvm-rt/src/ndarray.rs | 2 +- rust/tvm-rt/src/to_boxed_fn.rs | 29 ++++++++++++++++------------- rust/tvm-rt/src/to_function.rs | 30 +++++++++++++++--------------- 7 files changed, 59 insertions(+), 50 deletions(-) diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs index 0c01d91..b1bdab5 100644 --- a/rust/tvm-rt/src/context.rs +++ b/rust/tvm-rt/src/context.rs @@ -1,13 +1,17 @@ -pub use tvm_sys::context::*; -use tvm_sys::ffi; use std::os::raw::c_void; use std::ptr; +use crate::errors::Error; + +use tvm_sys::ffi; + +pub use tvm_sys::context::*; + trait ContextExt { /// Checks whether the context exists or not. fn exist(&self) -> bool; - fn sync(&self) -> anyhow::Result<()>; + fn sync(&self) -> Result<(), Error>; fn max_threads_per_block(&self) -> isize; fn warp_size(&self) -> isize; fn max_shared_memory_per_block(&self) -> isize; @@ -44,7 +48,7 @@ impl ContextExt for Context { } /// Synchronize the context stream. - fn sync(&self) -> anyhow::Result<()> { + fn sync(&self) -> Result<(), Error> { check_call!(ffi::TVMSynchronize( self.device_type as i32, self.device_id as i32, diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index 414484d..197c875 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -21,12 +21,6 @@ use crate::DataType; use thiserror::Error; #[derive(Debug, Error)] -#[error("Handle `{name}` is null.")] -pub struct NullHandleError { - pub name: String, -} - -#[derive(Debug, Error)] #[error("Function was not set in `function::Builder`")] pub struct FunctionNotFoundError; @@ -62,6 +56,14 @@ pub enum Error { Null, #[error("failed to load module due to invalid path {0}")] ModuleLoadPath(String), + #[error("failed to convert String into CString due to embedded nul character")] + ToCString(#[from] std::ffi::NulError), + #[error("failed to convert CString into String")] + FromCString(#[from] std::ffi::IntoStringError), + #[error("Handle `{0}` is null.")] + NullHandle(String), + #[error("{0}")] + NDArray(#[from] NDArrayError), } impl Error { diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 4b34bc1..cca918a 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -138,7 +138,7 @@ impl Function { } /// Calls the function that created from `Builder`. - pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue, Error> { + pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> { let num_args = arg_buf.len(); let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); @@ -192,7 +192,7 @@ impl From<Function> for RetValue { impl TryFrom<RetValue> for Function { type Error = Error; - fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> { + fn try_from(ret_value: RetValue) -> Result<Function> { match ret_value { RetValue::FuncHandle(handle) => Ok(Function::new(handle)), _ => Err(Error::downcast( @@ -212,7 +212,7 @@ impl<'a> From<Function> for ArgValue<'a> { impl<'a> TryFrom<ArgValue<'a>> for Function { type Error = Error; - fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error> { + fn try_from(arg_value: ArgValue<'a>) -> Result<Function> { match arg_value { ArgValue::FuncHandle(handle) => Ok(Function::new(handle)), _ => Err(Error::downcast( @@ -226,7 +226,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function { impl<'a> TryFrom<&ArgValue<'a>> for Function { type Error = Error; - fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error> { + fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> { match arg_value { ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)), _ => Err(Error::downcast( @@ -264,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { /// let ret = boxed_fn(10, 20, 30).unwrap(); /// assert_eq!(ret, 60); /// ``` -pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<(), Error> +pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()> where F: ToFunction<I, O>, F: Typed<I, O>, @@ -275,7 +275,7 @@ where /// Register a function with explicit control over whether to override an existing registration or not. /// /// See `register` for more details on how to use the registration API. -pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<(), Error> +pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()> where F: ToFunction<I, O>, F: Typed<I, O>, diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index b8b56f4..c161af5 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -78,9 +78,9 @@ impl Module { )); if !fhandle.is_null() { - return Err(errors::NullHandleError { - name: name.into_string()?.to_string() - }) + return Err(errors::Error::NullHandle( + name.into_string()?.to_string() + )); } Ok(Function::new(fhandle)) @@ -98,13 +98,13 @@ impl Module { .extension() .unwrap_or_else(|| std::ffi::OsStr::new("")) .to_str() - .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display())) + .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))? )?; let cpath = CString::new( path.as_ref() .to_str() - .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display())) + .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))? )?; let module = load_from_file(cpath, ext)?; diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 593154d..9a17502 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -147,7 +147,7 @@ impl NDArray { } /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> anyhow::Result<bool> { + pub fn is_contiguous(&self) -> Result<bool, crate::errors::Error> { Ok(match self.strides() { None => true, Some(strides) => { diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs index d2dde67..12e4351 100644 --- a/rust/tvm-rt/src/to_boxed_fn.rs +++ b/rust/tvm-rt/src/to_boxed_fn.rs @@ -29,9 +29,7 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; use crate::{Module, errors}; -use super::function::Function; - -type Result<T> = std::result::Result<T, errors::Error>; +use super::function::{Function, Result}; pub trait ToBoxedFn { fn to_boxed_fn(func: &'static Function) -> Box<Self>; @@ -39,9 +37,10 @@ pub trait ToBoxedFn { use std::convert::{TryFrom, TryInto}; -impl<O> ToBoxedFn for dyn Fn() -> Result<O> +impl<E, O> ToBoxedFn for dyn Fn() -> Result<O> where - O: TryFrom<RetValue, Error = errors::Error>, + errors::Error: From<E>, + O: TryFrom<RetValue, Error = E>, { fn to_boxed_fn(func: &'static Function) -> Box<Self> { Box::new(move || { @@ -53,10 +52,11 @@ where } } -impl<A, O> ToBoxedFn for dyn Fn(A) -> Result<O> +impl<E, A, O> ToBoxedFn for dyn Fn(A) -> Result<O> where + errors::Error: From<E>, A: Into<ArgValue<'static>>, - O: TryFrom<RetValue, Error = errors::Error>, + O: TryFrom<RetValue, Error = E>, { fn to_boxed_fn(func: &'static Function) -> Box<Self> { Box::new(move |a: A| { @@ -69,11 +69,12 @@ where } } -impl<A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O> +impl<E, A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O> where + errors::Error: From<E>, A: Into<ArgValue<'static>>, B: Into<ArgValue<'static>>, - O: TryFrom<RetValue, Error = errors::Error>, + O: TryFrom<RetValue, Error = E>, { fn to_boxed_fn(func: &'static Function) -> Box<Self> { Box::new(move |a: A, b: B| { @@ -87,12 +88,13 @@ where } } -impl<A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O> +impl<E, A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O> where + errors::Error: From<E>, A: Into<ArgValue<'static>>, B: Into<ArgValue<'static>>, C: Into<ArgValue<'static>>, - O: TryFrom<RetValue, Error = errors::Error>, + O: TryFrom<RetValue, Error = E>, { fn to_boxed_fn(func: &'static Function) -> Box<Self> { Box::new(move |a: A, b: B, c: C| { @@ -107,13 +109,14 @@ where } } -impl<A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O> +impl<E, A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O> where + errors::Error: From<E>, A: Into<ArgValue<'static>>, B: Into<ArgValue<'static>>, C: Into<ArgValue<'static>>, D: Into<ArgValue<'static>>, - O: TryFrom<RetValue, Error = errors::Error>, + O: TryFrom<RetValue, Error = E>, { fn to_boxed_fn(func: &'static Function) -> Box<Self> { Box::new(move |a: A, b: B, c: C, d: D| { diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 0527b0c..9d8065c 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -32,7 +32,7 @@ use std::{ ptr, slice, }; -use super::Function; +use super::{Function, function::Result}; use crate::errors::Error; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -46,20 +46,20 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// /// And the implementation of it to `ToFunction`. pub trait Typed<I, O> { - fn args(i: &[ArgValue<'static>]) -> Result<I, Error>; + fn args(i: &[ArgValue<'static>]) -> Result<I>; fn ret(o: O) -> RetValue; } -impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>> for F +impl<'a, F> Typed<&'a [ArgValue<'static>], Result<RetValue>> for F where - F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue>, + F: Fn(&'a [ArgValue]) -> Result<RetValue>, { - fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>], Error> { + fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>]> { // this is BAD but just hacking for time being Ok(unsafe { std::mem::transmute(args) }) } - fn ret(ret_value: anyhow::Result<RetValue>) -> RetValue { + fn ret(ret_value: Result<RetValue>) -> RetValue { ret_value.unwrap() } } @@ -68,7 +68,7 @@ impl<F, O: Into<RetValue>> Typed<(), O> for F where F: Fn() -> O, { - fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<(), Error> { + fn args(_args: &[ArgValue<'static>]) -> Result<()> { debug_assert!(_args.len() == 0); Ok(()) } @@ -84,7 +84,7 @@ where Error: From<E>, A: TryFrom<ArgValue<'static>, Error = E>, { - fn args(args: &[ArgValue<'static>]) -> Result<(A,), Error> { + fn args(args: &[ArgValue<'static>]) -> Result<(A,)> { debug_assert!(args.len() == 1); let a: A = args[0].clone().try_into()?; Ok((a,)) @@ -102,7 +102,7 @@ where A: TryFrom<ArgValue<'static>, Error = E>, B: TryFrom<ArgValue<'static>, Error = E>, { - fn args(args: &[ArgValue<'static>]) -> Result<(A, B), Error> { + fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> { debug_assert!(args.len() == 2); let a: A = args[0].clone().try_into()?; let b: B = args[1].clone().try_into()?; @@ -122,7 +122,7 @@ where B: TryFrom<ArgValue<'static>, Error = E>, C: TryFrom<ArgValue<'static>, Error = E>, { - fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C), Error> { + fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> { debug_assert!(args.len() == 3); let a: A = args[0].clone().try_into()?; let b: B = args[1].clone().try_into()?; @@ -140,7 +140,7 @@ pub trait ToFunction<I, O>: Sized { fn into_raw(self) -> *mut Self::Handle; - fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error> + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> where Self: Typed<I, O>; @@ -242,7 +242,7 @@ pub trait ToFunction<I, O>: Sized { // } // impl Typed<&[ArgValue<'static>], ()> for RawFunction { -// fn args(i: &[ArgValue<'static>]) -> anyhow::Result<&[ArgValue<'static>]> { +// fn args(i: &[ArgValue<'static>]) -> Result<&[ArgValue<'static>]> { // Ok(i) // } @@ -279,7 +279,7 @@ where Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue, Error> + fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<(), O>, { @@ -302,7 +302,7 @@ macro_rules! to_function_instance { Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error> where F: Typed<($($param,)+), O> { + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<($($param,)+), O> { // Ideally we shouldn't need to clone, probably doesn't really matter. let args = F::args(args)?; let out = unsafe { @@ -338,7 +338,7 @@ mod tests { f.to_function() } - // fn func_args(args: &[ArgValue<'static>]) -> anyhow::Result<RetValue> { + // fn func_args(args: &[ArgValue<'static>]) -> Result<RetValue> { // Ok(10.into()) // }