This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit b8dcc35801e7f0f3c696e0d030b3bd0931127785
Author: Jared Roesch <roesch...@gmail.com>
AuthorDate: Thu Oct 22 22:18:58 2020 -0700

    WIP
---
 rust/tvm-macros/src/external.rs      |   2 +-
 rust/tvm-macros/src/lib.rs           |   3 +-
 rust/tvm-macros/src/object.rs        |  23 ++++++
 rust/tvm-rt/src/object/mod.rs        |   9 +--
 rust/tvm-rt/src/object/object_ptr.rs |  16 ++++
 rust/tvm-rt/src/string.rs            |   1 +
 rust/tvm-rt/src/value.rs             |   1 -
 rust/tvm-sys/src/datatype.rs         |   4 +
 rust/tvm/src/ir/module.rs            | 152 ++++++++++++++++++++++++++++++++---
 rust/tvm/src/ir/relay/mod.rs         |  36 +++------
 rust/tvm/src/ir/tir.rs               |  14 ++++
 11 files changed, 218 insertions(+), 43 deletions(-)

diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 44a242c..51a389b 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -21,7 +21,7 @@ use proc_macro_error::abort;
 use quote::quote;
 use syn::parse::{Parse, ParseStream, Result};
 
-use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, 
Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+use syn::{FnArg, Signature, Attribute, token::Semi, Visibility, Generics, 
Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Type};
 
 struct ExternalItem {
     attrs: Vec<Attribute>,
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index 32f2839..e563a57 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -30,7 +30,8 @@ pub fn import_module(input: TokenStream) -> TokenStream {
     import_module::macro_impl(input)
 }
 
-#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
+#[proc_macro_error]
+#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))]
 pub fn macro_impl(input: TokenStream) -> TokenStream {
     // let input = proc_macro2::TokenStream::from(input);
     TokenStream::from(object::macro_impl(input))
diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs
index ff72d6a..7e6a934 100644
--- a/rust/tvm-macros/src/object.rs
+++ b/rust/tvm-macros/src/object.rs
@@ -36,6 +36,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> 
TokenStream {
         .map(attr_to_str)
         .expect("Failed to get type_key");
 
+    let derive = get_attr(&derive_input, "no_derive").map(|_| 
false).unwrap_or(true);
+
     let ref_id = get_attr(&derive_input, "ref_name")
         .map(|a| Ident::new(attr_to_str(a).value().as_str(), 
Span::call_site()))
         .unwrap_or_else(|| {
@@ -185,5 +187,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> 
TokenStream {
 
     expanded.extend(base_tokens);
 
+    if derive {
+        let derives = quote! {
+            impl std::hash::Hash for #ref_id {
+                fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+                    self.0.hash(state)
+                }
+            }
+
+            impl std::cmp::PartialEq for #ref_id {
+                fn eq(&self, other: &Self) -> bool {
+                    self.0 == other.0
+                }
+            }
+
+            impl std::cmp::Eq for #ref_id {}
+        };
+
+
+        expanded.extend(derives);
+    }
+
     TokenStream::from(expanded)
 }
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index e48c017..7e6107d 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -90,12 +90,7 @@ external! {
     #[name("ir.DebugPrint")]
     pub fn debug_print(object: ObjectRef) -> CString;
     #[name("node.StructuralHash")]
-    fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
+    fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64;
     #[name("node.StructuralEqual")]
-    fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, 
map_free_vars: bool) -> ObjectRef;
+    fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, 
map_free_vars: bool) -> bool;
 }
-
-// external! {
-//     #[name("ir.TextPrinter")]
-//     fn as_text(object: ObjectRef) -> CString;
-// }
diff --git a/rust/tvm-rt/src/object/object_ptr.rs 
b/rust/tvm-rt/src/object/object_ptr.rs
index 77254d2..a923506 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -342,6 +342,22 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for 
ObjectPtr<T> {
     }
 }
 
+/// Hashes the pointee through the `node.StructuralHash` packed function.
+///
+/// `map_free_vars` is passed as `false`, the same setting the `PartialEq`
+/// impl below passes to `node.StructuralEqual`, so that both impls agree.
+///
+/// # Panics
+///
+/// Panics if the call to `node.StructuralHash` fails.
+impl<T: IsObject> std::hash::Hash for ObjectPtr<T> {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        // Only `&self` is available, but the `ObjectRef` wrapper needs an
+        // owned pointer, hence the clone.
+        let object = ObjectRef(Some(self.clone().upcast()));
+        let hash = super::structural_hash(object, false)
+            .expect("node.StructuralHash failed while hashing an ObjectPtr");
+        state.write_i64(hash)
+    }
+}
+
+/// Compares two pointers through the `node.StructuralEqual` packed function,
+/// with `assert_mode` and `map_free_vars` both set to `false`.
+///
+/// # Panics
+///
+/// Panics if the call to `node.StructuralEqual` fails.
+impl<T: IsObject> PartialEq for ObjectPtr<T> {
+    fn eq(&self, other: &Self) -> bool {
+        let lhs = ObjectRef(Some(self.clone().upcast()));
+        let rhs = ObjectRef(Some(other.clone().upcast()));
+        super::structural_equal(lhs, rhs, false, false)
+            .expect("node.StructuralEqual failed while comparing ObjectPtrs")
+    }
+}
+
+impl<T: IsObject> Eq for ObjectPtr<T> {}
+
 #[cfg(test)]
 mod tests {
     use super::{Object, ObjectPtr};
diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs
index 6ff24be..e9a76d2 100644
--- a/rust/tvm-rt/src/string.rs
+++ b/rust/tvm-rt/src/string.rs
@@ -28,6 +28,7 @@ use tvm_macros::Object;
 #[derive(Object)]
 #[ref_name = "String"]
 #[type_key = "runtime.String"]
+#[no_derive]
 pub struct StringObj {
     base: Object,
     data: *const u8,
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
index c49944d..b8cd190 100644
--- a/rust/tvm-rt/src/value.rs
+++ b/rust/tvm-rt/src/value.rs
@@ -22,7 +22,6 @@
 //! `RetValue` is the owned version of `TVMPODValue`.
 
 use std::convert::TryFrom;
-// use std::ffi::c_void;
 
 use crate::{ArgValue, Module, RetValue};
 use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs
index 8050d93..5f7e0c3 100644
--- a/rust/tvm-sys/src/datatype.rs
+++ b/rust/tvm-sys/src/datatype.rs
@@ -83,6 +83,10 @@ impl DataType {
         DataType::new(DL_FLOAT_CODE, bits, lanes)
     }
 
+    /// Shorthand for `DataType::float(32, 1)`: 32 bits, one lane.
+    pub const fn float32() -> DataType {
+        DataType::float(32, 1)
+    }
+
     pub const fn uint(bits: u8, lanes: u16) -> DataType {
         DataType::new(DL_UINT_CODE, bits, lanes)
     }
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 3b60b0c..db32ce2 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -17,6 +17,7 @@
  * under the License.
  */
 use std::io::Result as IOResult;
+use std::iter::FromIterator;
 use std::path::Path;
 
 use thiserror::Error;
@@ -33,8 +34,9 @@ use super::function::BaseFunc;
 use super::source_map::SourceMap;
 use super::{ty::GlobalTypeVar, relay};
 
-// TODO(@jroesch): define type
+use tvm_macros::Object;
 
+// TODO(@jroesch): define type
 type TypeData = ObjectRef;
 type GlobalTypeVar = ObjectRef;
 
@@ -64,9 +66,11 @@ external! {
     fn parse_module(file_name: TVMString, source: TVMString) -> IRModule;
     #[name("parser.ParseExpr")]
     fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
+    #[name("ir.IRModule")]
+    fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, 
TypeData>) -> IRModule;
     // Module methods
     #[name("ir.Module_Add")]
-    fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, 
update: bool) -> ();
+    fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, 
update: bool) -> IRModule;
     #[name("ir.Module_AddDef")]
     fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: 
TypeData, update: bool) -> ();
     #[name("ir.Module_GetGlobalVar")]
@@ -78,15 +82,15 @@ external! {
     #[name("ir.Module_Lookup_str")]
     fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc;
     #[name("ir.Module_GetGlobalTypeVars")]
-    fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
+    fn module_get_global_type_vars(module: IRModule) -> Array<GlobalTypeVar>;
     #[name("ir.Module_ContainGlobalVar")]
-    fn module_contains_global_var(name: TVMString) -> bool;
+    fn module_contains_global_var(module: IRModule, name: TVMString) -> bool;
     #[name("ir.Module_ContainGlobalTypeVar")]
-    fn module_contains_global_type_var(name: TVMString) -> bool;
+    fn module_contains_global_type_var(module: IRModule, name: TVMString) -> 
bool;
     #[name("ir.Module_LookupDef")]
-    fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+    fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData;
     #[name("ir.Module_LookupDef_str")]
-    fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> 
TypeDef;
+    fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> 
TypeData;
     #[name("ir.Module_LookupTag")]
     fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
     #[name("ir.Module_FromExpr")]
@@ -99,8 +103,12 @@ external! {
 
 // Note: we don't expose update here as update is going to be removed.
 
-
 impl IRModule {
+    /// Builds a module from `(GlobalVar, BaseFunc)` pairs and
+    /// `(GlobalTypeVar, TypeData)` pairs.
+    ///
+    /// Both iterators are collected into `Map`s and passed to the
+    /// `ir.IRModule` packed function; its result is returned unchanged.
+    pub fn new<F, T>(funcs: F, types: T) -> Result<IRModule>
+    where
+        F: IntoIterator<Item = (GlobalVar, BaseFunc)>,
+        T: IntoIterator<Item = (GlobalTypeVar, TypeData)>,
+    {
+        let func_map = Map::from_iter(funcs);
+        let type_map = Map::from_iter(types);
+        module_new(func_map, type_map)
+    }
+
     pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
     where
         N: Into<TVMString>,
@@ -119,6 +127,13 @@ impl IRModule {
         Ok(module)
     }
 
+    /// Adds `func` to this module under the global variable `var`.
+    ///
+    /// Calls `ir.Module_Add` with `update = true` and returns the `IRModule`
+    /// that call yields.
+    pub fn add(&mut self, var: GlobalVar, func: BaseFunc) -> Result<IRModule> {
+        let module = self.clone();
+        module_add(module, var, func, true)
+    }
+
     pub fn add_def(
         &mut self,
         type_name: GlobalTypeVar,
@@ -146,10 +161,127 @@ impl IRModule {
     {
         module_lookup_str(self.clone(), name.into())
     }
+
+    /// Returns the result of `ir.Module_GetGlobalTypeVars` for this module.
+    pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> {
+        let module = self.clone();
+        module_get_global_type_vars(module)
+    }
+
+    /// Passes `name` to `ir.Module_ContainGlobalVar` for this module and
+    /// returns its boolean answer.
+    pub fn contains_global_var<S>(&self, name: S) -> Result<bool>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_contains_global_var(module, name.into())
+    }
+
+    /// Passes `name` to `ir.Module_ContainGlobalTypeVar` for this module and
+    /// returns its boolean answer.
+    pub fn contains_global_type_var<S>(&self, name: S) -> Result<bool>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_contains_global_type_var(module, name.into())
+    }
+
+    /// Looks up the `TypeData` registered for `global` via
+    /// `ir.Module_LookupDef`.
+    pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> {
+        let module = self.clone();
+        module_lookup_def(module, global)
+    }
+
+    /// Looks up the `TypeData` for `global` via `ir.Module_LookupDef_str`.
+    ///
+    // NOTE(review): despite the `_str` suffix this takes a `GlobalTypeVar`,
+    // mirroring the extern declaration, whereas the sibling `lookup_str`
+    // takes a string name. Confirm which argument type the packed function
+    // actually expects.
+    pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> {
+        let module = self.clone();
+        module_lookup_def_str(module, global)
+    }
+
+    /// Looks up the `relay::Constructor` for `tag` via `ir.Module_LookupTag`.
+    pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
+        let module = self.clone();
+        module_lookup_tag(module, tag)
+    }
+
+    /// Builds a module from `expr` plus existing function and type
+    /// definitions by forwarding all three arguments to `module_from_expr`.
+    pub fn from_expr(
+        expr: relay::Expr,
+        funcs: Map<GlobalVar, BaseFunc>,
+        types: Map<GlobalTypeVar, TypeData>,
+    ) -> Result<IRModule> {
+        module_from_expr(expr, funcs, types)
+    }
+
+    /// Forwards this module and `path` to `module_import`.
+    pub fn import<S>(&mut self, path: S) -> Result<()>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_import(module, path.into())
+    }
+
+    /// Forwards this module and `path` to `module_import_from_std`.
+    pub fn import_from_std<S>(&mut self, path: S) -> Result<()>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_import_from_std(module, path.into())
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    // #[test]
-    // fn
+    use std::collections::HashMap;
+    use super::relay::*;
+    use super::*;
+    use super::super::span::Span;
+    use tvm_rt::IsObjectRef;
+
+    #[test]
+    fn test_module_add() -> anyhow::Result<()> {
+        let funcs = HashMap::<GlobalVar, BaseFunc>::new();
+        let types =  HashMap::<GlobalTypeVar, TypeData>::new();
+        let mut module = IRModule::new(funcs, types)?;
+        let x = Var::static_tensor("x".into(), vec![1, 1], 
DataType::float32());
+        let params = Array::from_vec(vec![x.clone()])?;
+        let func = relay::Function::simple(params, x.upcast()).upcast();
+        let module = module.add(GlobalVar::new("foo".into(), Span::null()), 
func)?;
+        // let lfunc = module.lookup_str("foo")?;
+        // let lfunc = lfunc.downcast::<relay::Function>()?;
+        // assert_eq!(lfunc.params.len(), 1);
+        Ok(())
+    }
+
+    #[test]
+    fn test_module_add_def() {
+
+    }
+
+    #[test]
+    fn test_get_global_var() {
+
+    }
+
+    #[test]
+    fn test_get_global_vars() {
+
+    }
+
+    #[test]
+    fn test_lookup() {
+
+    }
+
+
+    // pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> {
+    //     module_get_global_type_vars(self.clone())
+    // }
+
+    // pub fn contains_global_var<S: Into<TVMString>>(&self, name: S) -> 
Result<bool> {
+    //     module_contains_global_var(self.clone(), name.into())
+    // }
+
+    // pub fn contains_global_type_var<S: Into<TVMString>>(&self, name: S) -> 
Result<bool> {
+    //     module_contains_global_type_var(self.clone(), name.into())
+    // }
+
+    #[test]
+    fn test_lookup_def() {
+
+    }
+    // pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> {
+    //     module_lookup_def(self.clone(), global)
+    // }
+
+    // pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> 
{
+    //     module_lookup_def_str(self.clone(), global)
+    // }
+
+    // pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
+    //     module_lookup_tag(self.clone(), tag)
+    // }
+
+    // pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, 
types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
+    //     module_from_expr(expr, funcs, types)
+    // }
+
+
+    // pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
+    //     module_import(self.clone(), path.into())
+    // }
+
+
+    // pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> 
Result<()> {
+    //     module_import_from_std(self.clone(), path.into())
+    // }
 }
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index 530b120..90b7a6a 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -16,11 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
-pub mod attrs;
-
-use std::hash::Hash;
-
 use crate::runtime::array::Array;
 use crate::runtime::{object::*, String as TString};
 
@@ -29,11 +24,15 @@ use super::expr::BaseExprNode;
 use super::function::BaseFuncNode;
 use super::span::Span;
 use super::ty::{Type, TypeNode};
+use super::span::Span;
 
 use tvm_macros::Object;
 use tvm_rt::NDArray;
 
 pub use super::expr::{GlobalVar, GlobalVarNode};
+pub use crate::runtime::DataType;
+
+pub mod attrs;
 
 #[repr(C)]
 #[derive(Object)]
@@ -58,20 +57,6 @@ impl ExprNode {
     }
 }
 
-impl Hash for Expr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.as_ptr().unwrap().ptr.hash(state)
-    }
-}
-
-impl PartialEq for Expr {
-    fn eq(&self, other: &Self) -> bool {
-        self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr)
-    }
-}
-
-impl Eq for Expr {}
-
 #[repr(C)]
 #[derive(Object)]
 #[ref_name = "Id"]
@@ -140,11 +125,11 @@ pub struct VarNode {
 }
 
 impl Var {
-    pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var {
+    /// Creates a variable whose id is built from `name_hint` and whose type
+    /// annotation is `type_annotation`.
+    ///
+    /// `_span` is accepted but not stored on the node.
+    pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var {
         let node = VarNode {
             base: ExprNode::base::<VarNode>(),
             vid: Id::new(name_hint.into()),
             type_annotation,
         };
         Var(Some(ObjectPtr::new(node)))
     }
@@ -153,8 +138,9 @@ impl Var {
         &self.vid.0.as_ref().unwrap().name_hint
     }
 
-    pub fn to_expr(self) -> Expr {
-        unsafe { Expr(std::mem::transmute(self.0)) }
+    /// Creates a variable annotated with a `TensorType` built from the shape
+    /// `sh` and the element type `dtype`.
+    ///
+    /// Both the tensor type and the variable are given `Span::null()`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `Array::from_vec` fails to build the shape array.
+    pub fn static_tensor(name_hint: String, sh: Vec<i32>, dtype: DataType) -> Var {
+        let shape = Array::from_vec(sh.into_iter().map(Into::into).collect())
+            .expect("failed to build the shape array for a static tensor");
+        let tensor_ty: Type = super::ty::TensorType::new(shape, dtype, Span::null()).upcast();
+        Self::new(name_hint, tensor_ty, Span::null())
     }
 }
 
@@ -510,6 +496,10 @@ impl Function {
         };
         Function(Some(ObjectPtr::new(node)))
     }
+
+    /// Creates a function from `params` and `body` alone.
+    ///
+    /// The remaining two arguments of `Function::new` are filled with
+    /// `Type::null()` and an empty array.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `Array::from_vec` fails to build the empty array.
+    pub fn simple(params: Array<Var>, body: Expr) -> Function {
+        let empty = Array::from_vec(vec![])
+            .expect("failed to build an empty array for Function::simple");
+        Self::new(params, body, Type::null(), empty)
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs
index 22d4e02..f07e854 100644
--- a/rust/tvm/src/ir/tir.rs
+++ b/rust/tvm/src/ir/tir.rs
@@ -47,6 +47,20 @@ macro_rules! define_node {
 // TODO(@jroesch): should move up to expr.rs to mirror TVM.
 define_node!(IntImm, "IntImm", "IntImm";
              IntImmNode { value: i64 });
+
+/// Converts an `i32` into an `IntImm` whose data type is `DataType::int(32, 1)`.
+impl From<i32> for IntImm {
+    fn from(i: i32) -> IntImm {
+        // `i64::from` is a lossless widening; unlike an `as` cast it would
+        // stop compiling if the source type ever became wider than `i64`.
+        IntImm::new(DataType::int(32, 1), i64::from(i))
+    }
+}
+
+/// Converts an `i32` into a `PrimExpr` by building an `IntImm` and upcasting it.
+impl From<i32> for PrimExpr {
+    fn from(i: i32) -> PrimExpr {
+        use crate::runtime::IsObjectRef;
+        let imm: IntImm = i.into();
+        imm.upcast()
+    }
+}
+
 define_node!(Var, "Var", "tir.Var";
              VarNode { name_hint: TVMString });
 

Reply via email to