This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch rust-tvm in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit a44a379bb3b3f4fab505dce3520eeb97f230ac23 Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Sat May 30 01:07:46 2020 -0700 Refactor anyhow out of the rt layer --- rust/Cargo.toml | 3 +- rust/macros/src/object.rs | 8 ++--- rust/tvm-rt/src/errors.rs | 36 ++++++++++++++++---- rust/tvm-rt/src/function.rs | 66 +++++++++++++++++++++++++++++++++--- rust/tvm-rt/src/lib.rs | 5 +-- rust/tvm-rt/src/ndarray.rs | 66 +++++++++++++++++++----------------- rust/tvm-rt/src/object/mod.rs | 17 ++++------ rust/tvm-rt/src/object/object_ptr.rs | 30 +++++++++------- rust/tvm-rt/src/to_function.rs | 37 ++++++++++---------- rust/tvm-rt/src/value.rs | 5 ++- rust/tvm/src/ir/array.rs | 55 +++++++++++++++++------------- rust/tvm/src/lib.rs | 10 +----- rust/tvm/src/transform.rs | 2 +- 13 files changed, 211 insertions(+), 129 deletions(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 6d3481b..e107104 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -29,5 +29,6 @@ members = [ "frontend/tests/callback", "frontend/examples/resnet", "tvm-sys", - "tvm-rt" + "tvm-rt", + "tvm", ] diff --git a/rust/macros/src/object.rs b/rust/macros/src/object.rs index 96a86dd..670d326 100644 --- a/rust/macros/src/object.rs +++ b/rust/macros/src/object.rs @@ -89,12 +89,12 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl std::convert::TryFrom<tvm_rt::RetValue> for #ref_id { - type Error = ::anyhow::Error; + type Error = tvm_rt::Error; fn try_from(ret_val: tvm_rt::RetValue) -> Result<#ref_id, Self::Error> { use std::convert::TryInto; let oref: ObjectRef = ret_val.try_into()?; - let ptr = oref.0.ok_or(anyhow::anyhow!("null ptr"))?; + let ptr = oref.0.ok_or(tvm_rt::Error::Null)?; let ptr = ptr.downcast::<#payload_id>()?; Ok(#ref_id(Some(ptr))) } @@ -122,7 +122,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl<'a> std::convert::TryFrom<tvm_rt::ArgValue<'a>> for #ref_id { - type Error = anyhow::Error; + type Error = tvm_rt::Error; fn try_from(arg_value: tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error> { use std::convert::TryInto; @@ -132,7 +132,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl<'a> std::convert::TryFrom<&tvm_rt::ArgValue<'a>> for #ref_id { - type Error = anyhow::Error; + type Error = tvm_rt::Error; fn try_from(arg_value: &tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error> { use std::convert::TryInto; diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index 77dbba7..41e873f 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -17,13 +17,10 @@ * under the License. */ +use crate::DataType; use thiserror::Error; #[derive(Debug, Error)] -#[error("Cannot convert from an empty array.")] -pub struct EmptyArrayError; - -#[derive(Debug, Error)] #[error("Handle `{name}` is null.")] pub struct NullHandleError { pub name: String, @@ -41,5 +38,32 @@ pub struct TypeMismatchError { } #[derive(Debug, Error)] -#[error("Missing NDArray shape.")] -pub struct MissingShapeError; +pub enum NDArrayError { + #[error("Missing NDArray shape.")] + MissingShape, + #[error("Cannot convert from an empty array.")] + EmptyArray, + #[error("Invalid datatype when attempting to convert ndarray.")] + InvalidDatatype(#[from] tvm_sys::datatype::ParseDataTypeError), + #[error("a shape error occurred in the Rust ndarray library")] + ShapeError(#[from] ndarray::ShapeError), + #[error("Expected type `{expected}` but found `{actual}`")] + DataTypeMismatch { expected: DataType, actual: DataType } +} + +#[derive(Debug, Error)] +pub enum Error { + #[error("{0}")] + Downcast(#[from] tvm_sys::errors::ValueDowncastError), + #[error("raw pointer passed across boundary was null")] + Null, +} + +impl Error { + pub fn downcast(actual_type: String, expected_type: &'static str) -> Error { + Self::Downcast(tvm_sys::errors::ValueDowncastError { + actual_type, + expected_type, + }) + } +} diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 2a5f446..17f5f6e 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -33,12 +33,14 @@ use std::{ ptr, slice, str, sync::Mutex, }; - +use std::convert::{TryFrom}; use anyhow::Result; use lazy_static::lazy_static; pub use tvm_sys::{ffi, ArgValue, RetValue}; +use crate::errors::Error; + use super::to_boxed_fn::ToBoxedFn; use super::to_function::{ToFunction, Typed}; @@ -180,6 +182,51 @@ impl Drop for Function { } } +impl From<Function> for RetValue { + fn from(func: Function) -> RetValue { + RetValue::FuncHandle(func.handle) + } +} + +impl TryFrom<RetValue> for Function { + type Error = Error; + + fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> { + match ret_value { + RetValue::FuncHandle(handle) => Ok(Function::new(handle)), + _ => Err(Error::downcast(format!("{:?}", ret_value), "FunctionHandle")) + } + } +} + +impl<'a> From<Function> for ArgValue<'a> { + fn from(func: Function) -> ArgValue<'a> { + ArgValue::FuncHandle(func.handle) + } +} + +impl<'a> TryFrom<ArgValue<'a>> for Function { + type Error = Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error> { + match arg_value { + ArgValue::FuncHandle(handle) => Ok(Function::new(handle)), + _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")), + } + } +} + +impl<'a> TryFrom<&ArgValue<'a>> for Function { + type Error = Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error> { + match arg_value { + ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)), + _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")), + } + } +} + /// Registers a Rust function with an arbitrary type signature in /// the TVM registry. /// @@ -240,8 +287,8 @@ where } #[macro_export] -macro_rules! external_func { - (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { +macro_rules! external_func_impl { + ($name:ident , $($ty_param:tt)* , ( $($arg:ident : $ty:ty),* ), $ret_type:ty, $ext_name:literal) => { ::paste::item! { #[allow(non_upper_case_globals)] static [<global_ $name>]: ::once_cell::sync::Lazy<&'static $crate::Function> = @@ -251,7 +298,7 @@ macro_rules! external_func { }); } - pub fn $name($($arg : $ty),*) -> Result<$ret_type, anyhow::Error> { + pub fn $name<$($ty_param),*>($($arg : $ty),*) -> anyhow::Result<$ret_type> w,* { let func_ref: &$crate::Function = ::paste::expr! { &*[<global_ $name>] }; let func_ref: Box<dyn Fn($($ty),*) -> anyhow::Result<$ret_type>> = func_ref.to_boxed_fn(); let res: $ret_type = func_ref($($arg),*)?; @@ -260,6 +307,17 @@ macro_rules! external_func { } } + +#[macro_export] +macro_rules! external_func { + (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { + $crate::external_func_impl!($name, , ( $($arg : $ty),* ) , $ret_type, $ext_name); + }; + (fn $name:ident < $($ty_param:ident),* > ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { + $crate::external_func_impl!($name, $($ty_param:ident),* , ( $($arg : $ty),* ) , $ret_type, $ext_name); + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 874d4fe..9b64eb6 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -44,8 +44,6 @@ use std::{ str, }; -use anyhow::Error; - pub use crate::{ context::{Context, DeviceType}, errors::*, @@ -57,7 +55,6 @@ pub use crate::{ pub use function::{ArgValue, RetValue}; pub use tvm_sys::byte_array::ByteArray; pub use tvm_sys::datatype::DataType; - use tvm_sys::ffi; // Macro to check the return call to TVM runtime shared library. @@ -80,7 +77,7 @@ pub fn get_last_error() -> &'static str { } } -pub(crate) fn set_last_error(err: &Error) { +pub(crate) fn set_last_error<E: std::error::Error>(err: &E) { let c_string = CString::new(err.to_string()).unwrap(); unsafe { ffi::TVMAPISetLastError(c_string.as_ptr()); diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 0adae8b..f97b3a4 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -48,16 +48,17 @@ //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; - -use crate::errors; -use anyhow::{bail, ensure, Result}; -use ndarray::{Array, ArrayD}; -use num_traits::Num; use std::convert::TryInto; use std::ffi::c_void; + +use crate::errors::NDArrayError; + use tvm_sys::ffi::DLTensor; use tvm_sys::{ffi, ByteArray, Context, DataType}; +use ndarray::{Array, ArrayD}; +use num_traits::Num; + /// See the [`module-level documentation`](../ndarray/index.html) for more details. /// /// Wrapper around TVM array handle. @@ -146,13 +147,13 @@ impl NDArray { } /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> Result<bool> { + pub fn is_contiguous(&self) -> anyhow::Result<bool> { Ok(match self.strides() { None => true, Some(strides) => { - // errors::MissingShapeError in case shape is not determined + // NDArrayError::MissingShape in case shape is not determined self.shape() - .ok_or(errors::MissingShapeError)? + .ok_or(NDArrayError::MissingShape)? .iter() .zip(strides) .rfold( @@ -188,16 +189,16 @@ impl NDArray { /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); /// ``` - pub fn to_vec<T>(&self) -> Result<Vec<T>> { - ensure!(self.shape().is_some(), errors::EmptyArrayError); + pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> { + if self.shape().is_some() { return Err(NDArrayError::EmptyArray); } let earr = NDArray::empty( - self.shape().ok_or(errors::MissingShapeError)?, + self.shape().ok_or(NDArrayError::MissingShape)?, Context::cpu(0), self.dtype(), ); let target = self.copy_to_ndarray(earr)?; let arr = target.as_dltensor(); - let sz = self.size().ok_or(errors::MissingShapeError)?; + let sz = self.size().ok_or(NDArrayError::MissingShape)?; let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>()); unsafe { v.as_mut_ptr() @@ -208,7 +209,7 @@ impl NDArray { } /// Converts the NDArray to [`ByteArray`]. - pub fn to_bytearray(&self) -> Result<ByteArray> { + pub fn to_bytearray(&self) -> Result<ByteArray, NDArrayError> { let v = self.to_vec::<u8>()?; Ok(ByteArray::from(v)) } @@ -238,16 +239,15 @@ impl NDArray { } /// Copies the NDArray to another target NDArray. - pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> { + pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, NDArrayError> { if self.dtype() != target.dtype() { - bail!( - "{}", - errors::TypeMismatchError { - expected: self.dtype().to_string(), - actual: target.dtype().to_string(), - } - ); + return Err( + NDArrayError::DataTypeMismatch { + expected: self.dtype(), + actual: target.dtype() + }); } + check_call!(ffi::TVMArrayCopyFromTo( self.as_raw_dltensor(), target.as_raw_dltensor(), @@ -257,9 +257,9 @@ impl NDArray { } /// Copies the NDArray to a target context. - pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray> { + pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray, NDArrayError> { let tmp = NDArray::empty( - self.shape().ok_or(errors::MissingShapeError)?, + self.shape().ok_or(NDArrayError::MissingShape)?, *target, self.dtype(), ); @@ -272,7 +272,7 @@ impl NDArray { rnd: &ArrayD<T>, ctx: Context, dtype: DataType, - ) -> Result<Self> { + ) -> Result<Self, NDArrayError> { let shape = rnd.shape().to_vec(); let mut nd = NDArray::empty(&shape, ctx, dtype); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); @@ -304,24 +304,26 @@ impl NDArray { macro_rules! impl_from_ndarray_rustndarray { ($type:ty, $type_name:tt) => { impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { - type Error = anyhow::Error; - fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> { - ensure!(nd.shape().is_some(), errors::MissingShapeError); + type Error = NDArrayError; + + fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> { + if nd.shape().is_some() { return Err(NDArrayError::MissingShape); } assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( - &*nd.shape().ok_or(errors::MissingShapeError)?, + &*nd.shape().ok_or(NDArrayError::MissingShape)?, nd.to_vec::<$type>()?, )?) } } impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { - type Error = anyhow::Error; - fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> { - ensure!(nd.shape().is_some(), errors::MissingShapeError); + type Error = NDArrayError; + + fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> { + if nd.shape().is_some() { return Err(NDArrayError::MissingShape) }; assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( - &*nd.shape().ok_or(errors::MissingShapeError)?, + &*nd.shape().ok_or(NDArrayError::MissingShape)?, nd.to_vec::<$type>()?, )?) } diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 2ff9a1f..32da18e 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -1,7 +1,10 @@ -use crate::external_func; use std::convert::TryFrom; use std::convert::TryInto; use std::ffi::CString; + +use crate::external_func; +use crate::errors::Error; + use tvm_sys::{ArgValue, RetValue}; mod object_ptr; @@ -27,14 +30,8 @@ impl ToObjectRef for ObjectRef { } } -// impl<T: ToObjectRef> ToObjectRef for &T { -// fn to_object_ref(&self) -> ObjectRef { -// (*self).to_object_ref() -// } -// } - impl TryFrom<RetValue> for ObjectRef { - type Error = anyhow::Error; + type Error = Error; fn try_from(ret_val: RetValue) -> Result<ObjectRef, Self::Error> { let optr = ret_val.try_into()?; @@ -54,7 +51,7 @@ impl From<ObjectRef> for RetValue { } impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef { - type Error = anyhow::Error; + type Error = Error; fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error> { let optr = arg_value.try_into()?; @@ -63,7 +60,7 @@ impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef { } impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef { - type Error = anyhow::Error; + type Error = Error; fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectRef, Self::Error> { // TODO(@jroesch): remove the clone diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index c716c05..8e91878 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -1,10 +1,12 @@ -use anyhow::Context; use std::convert::TryFrom; use std::ffi::CString; use std::ptr::NonNull; + use tvm_sys::ffi::{self, /* TVMObjectFree, */ TVMObjectRetain, TVMObjectTypeKey2Index}; use tvm_sys::{ArgValue, RetValue}; +use crate::errors::Error; + type Deleter<T> = unsafe extern "C" fn(object: *mut T) -> (); #[derive(Debug)] @@ -27,6 +29,7 @@ fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { parent_type_index, &mut is_derived )); + if is_derived == 0 { false } else { @@ -96,7 +99,6 @@ pub struct ObjectPtr<T> { impl ObjectPtr<Object> { fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> { - println!("{:?}", object_ptr); let non_null = NonNull::new(object_ptr); non_null.map(|ptr| ObjectPtr { ptr }) } @@ -144,7 +146,7 @@ impl<T: IsObject> ObjectPtr<T> { } } - pub fn downcast<U: IsObject>(&self) -> anyhow::Result<ObjectPtr<U>> { + pub fn downcast<U: IsObject>(&self) -> Result<ObjectPtr<U>, Error> { let child_index = Object::get_type_index::<U>(); let object_index = self.as_object().type_index; @@ -160,7 +162,7 @@ impl<T: IsObject> ObjectPtr<T> { ptr: self.ptr.cast(), }) } else { - Err(anyhow::anyhow!("failed to downcast to object subtype")) + Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) } } } @@ -183,16 +185,16 @@ impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue { } impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> { - type Error = anyhow::Error; + type Error = Error; fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error> { match ret_value { RetValue::ObjectHandle(handle) => { let handle: *mut Object = unsafe { std::mem::transmute(handle) }; - let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; optr.downcast() } - _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")) } } } @@ -207,29 +209,31 @@ impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> { } impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> { - type Error = anyhow::Error; + type Error = Error; + fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> { match arg_value { ArgValue::ObjectHandle(handle) => { let handle = unsafe { std::mem::transmute(handle) }; - let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; optr.downcast() } - _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } } } impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr<T> { - type Error = anyhow::Error; + type Error = Error; + fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> { match arg_value { ArgValue::ObjectHandle(handle) => { let handle = unsafe { std::mem::transmute(handle) }; - let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; optr.downcast() } - _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 504ce3e..dac37c8 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -25,19 +25,18 @@ //! //! See the tests and examples repository for more examples. +use std::convert::{TryFrom, TryInto}; use std::{ mem::MaybeUninit, os::raw::{c_int, c_void}, ptr, slice, }; -use anyhow::Result; +use crate::errors::Error; +use super::Function; pub use tvm_sys::{ffi, ArgValue, RetValue}; -use super::Function; -use std::convert::{TryFrom, TryInto}; - /// A trait representing whether the function arguments /// and return type can be assigned to a TVM packed function. /// @@ -47,7 +46,7 @@ use std::convert::{TryFrom, TryInto}; /// /// And the implementation of it to `ToFunction`. pub trait Typed<I, O> { - fn args(i: &[ArgValue<'static>]) -> anyhow::Result<I>; + fn args(i: &[ArgValue<'static>]) -> Result<I, Error>; fn ret(o: O) -> RetValue; } @@ -55,7 +54,7 @@ impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>> for F where F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue>, { - fn args(args: &[ArgValue<'static>]) -> anyhow::Result<&'a [ArgValue<'static>]> { + fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>], Error> { // this is BAD but just hacking for time being Ok(unsafe { std::mem::transmute(args) }) } @@ -69,7 +68,7 @@ impl<F, O: Into<RetValue>> Typed<(), O> for F where F: Fn() -> O, { - fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<()> { + fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<(), Error> { debug_assert!(_args.len() == 0); Ok(()) } @@ -79,13 +78,13 @@ where } } -impl<F, A, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,), O> for F +impl<F, A, O: Into<RetValue>, E> Typed<(A,), O> for F where F: Fn(A) -> O, - E: std::error::Error + Send + Sync + 'static, + Error: From<E>, A: TryFrom<ArgValue<'static>, Error = E>, { - fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,)> { + fn args(args: &[ArgValue<'static>]) -> Result<(A,), Error> { debug_assert!(args.len() == 1); let a: A = args[0].clone().try_into()?; Ok((a,)) @@ -96,14 +95,14 @@ where } } -impl<F, A, B, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A, B), O> for F +impl<F, A, B, O: Into<RetValue>, E> Typed<(A, B), O> for F where F: Fn(A, B) -> O, - E: std::error::Error + Send + Sync + 'static, + Error: From<E>, A: TryFrom<ArgValue<'static>, Error = E>, B: TryFrom<ArgValue<'static>, Error = E>, { - fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B)> { + fn args(args: &[ArgValue<'static>]) -> Result<(A, B), Error> { debug_assert!(args.len() == 2); let a: A = args[0].clone().try_into()?; let b: B = args[1].clone().try_into()?; @@ -115,15 +114,15 @@ where } } -impl<F, A, B, C, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A, B, C), O> for F +impl<F, A, B, C, O: Into<RetValue>, E> Typed<(A, B, C), O> for F where F: Fn(A, B, C) -> O, - E: std::error::Error + Send + Sync + 'static, + Error: From<E>, A: TryFrom<ArgValue<'static>, Error = E>, B: TryFrom<ArgValue<'static>, Error = E>, C: TryFrom<ArgValue<'static>, Error = E>, { - fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B, C)> { + fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C), Error> { debug_assert!(args.len() == 3); let a: A = args[0].clone().try_into()?; let b: B = args[1].clone().try_into()?; @@ -141,7 +140,7 @@ pub trait ToFunction<I, O>: Sized { fn into_raw(self) -> *mut Self::Handle; - fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> anyhow::Result<RetValue> + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error> where Self: Typed<I, O>; @@ -280,7 +279,7 @@ where Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue> + fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue, Error> where F: Typed<(), O>, { @@ -303,7 +302,7 @@ macro_rules! to_function_instance { Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<($($param,)+), O> { + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error> where F: Typed<($($param,)+), O> { // Ideally we shouldn't need to clone, probably doesn't really matter. let args = F::args(args)?; let out = unsafe { diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs index a9355e0..d9436b1 100644 --- a/rust/tvm-rt/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -24,10 +24,10 @@ use std::convert::TryFrom; // use std::ffi::c_void; -use crate::{ArgValue, Function, Module, NDArray, RetValue}; +use crate::{ArgValue, Module, NDArray, RetValue}; use tvm_sys::{ errors::ValueDowncastError, - ffi::{TVMFunctionHandle, TVMModuleHandle}, + ffi::{TVMModuleHandle}, try_downcast, }; @@ -74,7 +74,6 @@ macro_rules! impl_handle_val { }; } -impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new); impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); impl<'a> From<&'a NDArray> for ArgValue<'a> { diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs index f371497..a426474 100644 --- a/rust/tvm/src/ir/array.rs +++ b/rust/tvm/src/ir/array.rs @@ -1,46 +1,55 @@ -use crate::runtime::function::Builder; -use crate::runtime::object::{ObjectRef, ToObjectRef}; -use std::convert::{TryFrom, TryInto}; +use std::convert::{TryFrom}; use std::marker::PhantomData; -use tvm_sys::TVMRetValue; + +use crate::runtime::object::{ObjectRef, ToObjectRef}; + +use tvm_rt::RetValue; +use tvm_rt::external_func; use anyhow::Result; + #[derive(Clone)] pub struct Array<T: ToObjectRef> { object: ObjectRef, _data: PhantomData<T>, } +external_func! { + fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef as "ir.DebugPrint"; +} + impl<T: ToObjectRef> Array<T> { pub fn from_vec(data: Vec<T>) -> Result<Array<T>> { - let iter = data.iter().map(|element| element.to_object_ref()); + unimplemented!() + // let iter = data.iter().map(|element| element.to_object_ref()); - let array_data = Builder::default() - .get_function("node.Array") - .args(iter) - .invoke()? - .try_into()?; + // let array_data = Builder::default() + // .get_function("node.Array") + // .args(iter) + // .invoke()? + // .try_into()?; - Ok(Array { - object: array_data, - _data: PhantomData, - }) + // Ok(Array { + // object: array_data, + // _data: PhantomData, + // }) } pub fn get(&self, index: isize) -> Result<T> where - T: TryFrom<TVMRetValue, Error = anyhow::Error>, + T: TryFrom<RetValue, Error = anyhow::Error>, { - // TODO(@jroesch): why do we used a signed index here? - let element: T = Builder::default() - .get_function("node.ArrayGetItem") - .arg(self.object.clone()) - .arg(index) - .invoke()? - .try_into()?; + unimplemented!() + // // TODO(@jroesch): why do we used a signed index here? + // let element: T = Builder::default() + // .get_function("node.ArrayGetItem") + // .arg(self.object.clone()) + // .arg(index) + // .invoke()? + // .try_into()?; - Ok(element) + // Ok(element) } } // mod array_api { diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index b7cf796..9315f7c 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -31,21 +31,13 @@ //! Checkout the `examples` repository for more details. pub use crate::{ - context::{TVMContext, TVMDeviceType}, errors::*, function::Function, module::Module, ndarray::NDArray, }; -// TODO: refactor -pub use tvm_sys::{ - errors as common_errors, - ffi::{self, DLDataType, TVMByteArray}, - packed_func::{TVMArgValue, TVMRetValue}, -}; - -pub type DataType = DLDataType; +pub use tvm_rt::{Context, DeviceType, DataType}; pub use tvm_rt::context; pub use tvm_rt::errors; diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index 3657d3b..a89ab87 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -37,5 +37,5 @@ impl PassInfo { } external_func! { - fn create_func_pass(func: &Function, pass_info: PassInfo) -> Pass as "relay._transform.MakeFunctionPass"; + fn create_func_pass(func: Function, pass_info: PassInfo) -> Pass as "relay._transform.MakeFunctionPass"; }