This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch cargo-build in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit b8dcc35801e7f0f3c696e0d030b3bd0931127785 Author: Jared Roesch <roesch...@gmail.com> AuthorDate: Thu Oct 22 22:18:58 2020 -0700 WIP --- rust/tvm-macros/src/external.rs | 2 +- rust/tvm-macros/src/lib.rs | 3 +- rust/tvm-macros/src/object.rs | 23 ++++++ rust/tvm-rt/src/object/mod.rs | 9 +-- rust/tvm-rt/src/object/object_ptr.rs | 16 ++++ rust/tvm-rt/src/string.rs | 1 + rust/tvm-rt/src/value.rs | 1 - rust/tvm-sys/src/datatype.rs | 4 + rust/tvm/src/ir/module.rs | 152 ++++++++++++++++++++++++++++++++--- rust/tvm/src/ir/relay/mod.rs | 36 +++------ rust/tvm/src/ir/tir.rs | 14 ++++ 11 files changed, 218 insertions(+), 43 deletions(-) diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 44a242c..51a389b 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -21,7 +21,7 @@ use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; -use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; +use syn::{FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Type}; struct ExternalItem { attrs: Vec<Attribute>, diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs index 32f2839..e563a57 100644 --- a/rust/tvm-macros/src/lib.rs +++ b/rust/tvm-macros/src/lib.rs @@ -30,7 +30,8 @@ pub fn import_module(input: TokenStream) -> TokenStream { import_module::macro_impl(input) } -#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +#[proc_macro_error] +#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))] pub fn macro_impl(input: TokenStream) -> TokenStream { // let input = proc_macro2::TokenStream::from(input); TokenStream::from(object::macro_impl(input)) diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index ff72d6a..7e6a934 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -36,6 +36,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { .map(attr_to_str) .expect("Failed to get type_key"); + let derive = get_attr(&derive_input, "no_derive").map(|_| false).unwrap_or(true); + let ref_id = get_attr(&derive_input, "ref_name") .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site())) .unwrap_or_else(|| { @@ -185,5 +187,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { expanded.extend(base_tokens); + if derive { + let derives = quote! { + impl std::hash::Hash for #ref_id { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.0.hash(state) + } + } + + impl std::cmp::PartialEq for #ref_id { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } + } + + impl std::cmp::Eq for #ref_id {} + }; + + + expanded.extend(derives); + } + TokenStream::from(expanded) } diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index e48c017..7e6107d 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -90,12 +90,7 @@ external! { #[name("ir.DebugPrint")] pub fn debug_print(object: ObjectRef) -> CString; #[name("node.StructuralHash")] - fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef; + fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64; #[name("node.StructuralEqual")] - fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef; + fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool; } - -// external! { -// #[name("ir.TextPrinter")] -// fn as_text(object: ObjectRef) -> CString; -// } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 77254d2..a923506 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -342,6 +342,22 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> { } } +impl<T: IsObject> std::hash::Hash for ObjectPtr<T> { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + state.write_i64(super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap()) + } +} + +impl<T: IsObject> PartialEq for ObjectPtr<T> { + fn eq(&self, other: &Self) -> bool { + let lhs = ObjectRef(Some(self.clone().upcast())); + let rhs = ObjectRef(Some(other.clone().upcast())); + super::structural_equal(lhs, rhs, false, false).unwrap() + } +} + +impl<T: IsObject> Eq for ObjectPtr<T> {} + #[cfg(test)] mod tests { use super::{Object, ObjectPtr}; diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 6ff24be..e9a76d2 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -28,6 +28,7 @@ use tvm_macros::Object; #[derive(Object)] #[ref_name = "String"] #[type_key = "runtime.String"] +#[no_derive] pub struct StringObj { base: Object, data: *const u8, diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs index c49944d..b8cd190 100644 --- a/rust/tvm-rt/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -22,7 +22,6 @@ //! `RetValue` is the owned version of `TVMPODValue`. use std::convert::TryFrom; -// use std::ffi::c_void; use crate::{ArgValue, Module, RetValue}; use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs index 8050d93..5f7e0c3 100644 --- a/rust/tvm-sys/src/datatype.rs +++ b/rust/tvm-sys/src/datatype.rs @@ -83,6 +83,10 @@ impl DataType { DataType::new(DL_FLOAT_CODE, bits, lanes) } + pub const fn float32() -> DataType { + Self::float(32, 1) + } + pub const fn uint(bits: u8, lanes: u16) -> DataType { DataType::new(DL_UINT_CODE, bits, lanes) } diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 3b60b0c..db32ce2 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -17,6 +17,7 @@ * under the License. */ use std::io::Result as IOResult; +use std::iter::FromIterator; use std::path::Path; use thiserror::Error; @@ -33,8 +34,9 @@ use super::function::BaseFunc; use super::source_map::SourceMap; use super::{ty::GlobalTypeVar, relay}; -// TODO(@jroesch): define type +use tvm_macros::Object; +// TODO(@jroesch): define type type TypeData = ObjectRef; type GlobalTypeVar = ObjectRef; @@ -64,9 +66,11 @@ external! { fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; #[name("parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; + #[name("ir.IRModule")] + fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule; // Module methods #[name("ir.Module_Add")] - fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); + fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule; #[name("ir.Module_AddDef")] fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); #[name("ir.Module_GetGlobalVar")] @@ -78,15 +82,15 @@ external! { #[name("ir.Module_Lookup_str")] fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc; #[name("ir.Module_GetGlobalTypeVars")] - fn module_get_global_type_vars() -> Array<GlobalTypeVar>; + fn module_get_global_type_vars(module: IRModule) -> Array<GlobalTypeVar>; #[name("ir.Module_ContainGlobalVar")] - fn module_contains_global_var(name: TVMString) -> bool; + fn module_contains_global_var(module: IRModule, name: TVMString) -> bool; #[name("ir.Module_ContainGlobalTypeVar")] - fn module_contains_global_type_var(name: TVMString) -> bool; + fn module_contains_global_type_var(module: IRModule, name: TVMString) -> bool; #[name("ir.Module_LookupDef")] - fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef; + fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData; #[name("ir.Module_LookupDef_str")] - fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef; + fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeData; #[name("ir.Module_LookupTag")] fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor; #[name("ir.Module_FromExpr")] @@ -99,8 +103,12 @@ external! { // Note: we don't expose update here as update is going to be removed. - impl IRModule { + pub fn new<F, T>(funcs: F, types: T) -> Result<IRModule> + where F: IntoIterator<Item=(GlobalVar, BaseFunc)>, T: IntoIterator<Item=(GlobalTypeVar, TypeData)> { + module_new(Map::from_iter(funcs), Map::from_iter(types)) + } + pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule> where N: Into<TVMString>, @@ -119,6 +127,13 @@ impl IRModule { Ok(module) } + pub fn add( + &mut self, + var: GlobalVar, + func: BaseFunc) -> Result<IRModule> { + module_add(self.clone(), var, func, true) + } + pub fn add_def( &mut self, type_name: GlobalTypeVar, @@ -146,10 +161,127 @@ impl IRModule { { module_lookup_str(self.clone(), name.into()) } + + pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> { + module_get_global_type_vars(self.clone()) + } + + pub fn contains_global_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> { + module_contains_global_var(self.clone(), name.into()) + } + + pub fn contains_global_type_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> { + module_contains_global_type_var(self.clone(), name.into()) + } + + pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> { + module_lookup_def(self.clone(), global) + } + + pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> { + module_lookup_def_str(self.clone(), global) + } + + pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> { + module_lookup_tag(self.clone(), tag) + } + + pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> { + module_from_expr(expr, funcs, types) + } + + pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> { + module_import(self.clone(), path.into()) + } + + pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> Result<()> { + module_import_from_std(self.clone(), path.into()) + } } #[cfg(test)] mod tests { - // #[test] - // fn + use std::collections::HashMap; + use super::relay::*; + use super::*; + use super::super::span::Span; + use tvm_rt::IsObjectRef; + + #[test] + fn test_module_add() -> anyhow::Result<()> { + let funcs = HashMap::<GlobalVar, BaseFunc>::new(); + let types = HashMap::<GlobalTypeVar, TypeData>::new(); + let mut module = IRModule::new(funcs, types)?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = Array::from_vec(vec![x.clone()])?; + let func = relay::Function::simple(params, x.upcast()).upcast(); + let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?; + // let lfunc = module.lookup_str("foo")?; + // let lfunc = lfunc.downcast::<relay::Function>()?; + // assert_eq!(lfunc.params.len(), 1); + Ok(()) + } + + #[test] + fn test_module_add_def() { + + } + + #[test] + fn test_get_global_var() { + + } + + #[test] + fn test_get_global_vars() { + + } + + #[test] + fn test_lookup() { + + } + + + // pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> { + // module_get_global_type_vars(self.clone()) + // } + + // pub fn contains_global_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> { + // module_contains_global_var(self.clone(), name.into()) + // } + + // pub fn contains_global_type_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> { + // module_contains_global_type_var(self.clone(), name.into()) + // } + + #[test] + fn test_lookup_def() { + + } + // pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> { + // module_lookup_def(self.clone(), global) + // } + + // pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> { + // module_lookup_def_str(self.clone(), global) + // } + + // pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> { + // module_lookup_tag(self.clone(), tag) + // } + + // pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> { + // module_from_expr(expr, funcs, types) + // } + + + // pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> { + // module_import(self.clone(), path.into()) + // } + + + // pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> Result<()> { + // module_import_from_std(self.clone(), path.into()) + // } } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 530b120..90b7a6a 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -16,11 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - -pub mod attrs; - -use std::hash::Hash; - use crate::runtime::array::Array; use crate::runtime::{object::*, String as TString}; @@ -29,11 +24,15 @@ use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::span::Span; use super::ty::{Type, TypeNode}; +use super::span::Span; use tvm_macros::Object; use tvm_rt::NDArray; pub use super::expr::{GlobalVar, GlobalVarNode}; +pub use crate::runtime::DataType; + +pub mod attrs; #[repr(C)] #[derive(Object)] @@ -58,20 +57,6 @@ impl ExprNode { } } -impl Hash for Expr { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.as_ptr().unwrap().ptr.hash(state) - } -} - -impl PartialEq for Expr { - fn eq(&self, other: &Self) -> bool { - self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr) - } -} - -impl Eq for Expr {} - #[repr(C)] #[derive(Object)] #[ref_name = "Id"] @@ -140,11 +125,11 @@ pub struct VarNode { } impl Var { - pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var { + pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var { let node = VarNode { base: ExprNode::base::<VarNode>(), vid: Id::new(name_hint.into()), - type_annotation, + type_annotation: type_annotation, }; Var(Some(ObjectPtr::new(node))) } @@ -153,8 +138,9 @@ impl Var { &self.vid.0.as_ref().unwrap().name_hint } - pub fn to_expr(self) -> Expr { - unsafe { Expr(std::mem::transmute(self.0)) } + pub fn static_tensor(name_hint: String, sh: Vec<i32>, dtype: DataType) -> Var { + let sh = Array::from_vec(sh.into_iter().map(Into::into).collect()).unwrap(); + Self::new(name_hint, super::ty::TensorType::new(sh, dtype, Span::null()).upcast(), Span::null()) } } @@ -510,6 +496,10 @@ impl Function { }; Function(Some(ObjectPtr::new(node))) } + + pub fn simple(params: Array<Var>, body: Expr) -> Function { + Self::new(params, body, Type::null(), Array::from_vec(vec![]).unwrap()) + } } #[cfg(test)] diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index 22d4e02..f07e854 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -47,6 +47,20 @@ macro_rules! define_node { // TODO(@jroesch): should move up to expr.rs to mirror TVM. define_node!(IntImm, "IntImm", "IntImm"; IntImmNode { value: i64 }); + +impl From<i32> for IntImm { + fn from(i: i32) -> IntImm { + IntImm::new(DataType::int(32, 1), i as i64) + } +} + +impl From<i32> for PrimExpr { + fn from(i: i32) -> PrimExpr { + use crate::runtime::IsObjectRef; + IntImm::from(i).upcast() + } +} + define_node!(Var, "Var", "tir.Var"; VarNode { name_hint: TVMString });