This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/cargo-build by this push:
     new 83ad4d4  Fix
83ad4d4 is described below

commit 83ad4d443350cc499e8af1fc419d3a694cb5e1f7
Author: Jared Roesch <jroe...@octoml.ai>
AuthorDate: Thu Nov 5 18:05:28 2020 -0800

    Fix
---
 python/tvm/relay/backend/graph_runtime_factory.py |  2 +-
 python/tvm/relay/build_module.py                  |  5 ++---
 rust/tvm-rt/src/map.rs                            | 12 ++++++++++++
 rust/tvm-rt/src/module.rs                         | 16 ++++++++++++++++
 rust/tvm-rt/src/to_function.rs                    |  1 +
 src/runtime/module.cc                             |  2 +-
 6 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py
index 4c6ac47..3427a62 100644
--- a/python/tvm/relay/backend/graph_runtime_factory.py
+++ b/python/tvm/relay/backend/graph_runtime_factory.py
@@ -21,7 +21,7 @@ from tvm._ffi.registry import get_global_func
 from tvm.runtime import ndarray
 
 
-class GraphRuntimeFactoryModule(object):
+class GraphRuntimeFactoryModule:
     """Graph runtime factory module.
     This is a module of graph runtime factory
 
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index e93d654..7e32dea 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -187,9 +187,8 @@ class BuildModule(object):
         return ret
 
 @register_func("tvm.relay.build")
-def build1(mod, target=None, target_host=None, params=None, mod_name="default"):
-    import pdb; pdb.set_trace()
-    return build(mod, target, target_host, params, mod_name)
+def _rust_build_module(mod, target=None, target_host=None, params=None, mod_name="default"):
+    return build(mod, target, target_host, params, mod_name).module
 
 def build(mod, target=None, target_host=None, params=None, mod_name="default"):
     """Helper function that builds a Relay function to run on TVM graph
diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs
index 721fb1e..ab44e40 100644
--- a/rust/tvm-rt/src/map.rs
+++ b/rust/tvm-rt/src/map.rs
@@ -109,6 +109,18 @@ where
         let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?;
         oref.downcast()
     }
+
+    pub fn empty() -> Self {
+        Self::from_iter(vec![].into_iter())
+    }
+
+    //(@jroesch): I don't think this is a correct implementation.
+    pub fn null() -> Self {
+        Map {
+            object: ObjectRef::null(),
+            _data: PhantomData,
+        }
+    }
 }
 
 pub struct IntoIter<K, V> {
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
index c0822a5..18347da 100644
--- a/rust/tvm-rt/src/module.rs
+++ b/rust/tvm-rt/src/module.rs
@@ -30,6 +30,8 @@ use tvm_sys::ffi;
 
 use crate::errors::Error;
 use crate::{errors, function::Function};
+use crate::{String as TString};
+use crate::RetValue;
 
 const ENTRY_FUNC: &str = "__tvm_main__";
 
@@ -49,6 +51,9 @@ crate::external! {
 
     #[name("runtime.ModuleLoadFromFile")]
     fn load_from_file(file_name: CString, format: CString) -> Module;
+
+    #[name("runtime.ModuleSaveToFile")]
+    fn save_to_file(module: ffi::TVMModuleHandle, name: TString, fmt: TString);
 }
 
 impl Module {
@@ -110,6 +115,10 @@ impl Module {
         Ok(module)
     }
 
+    pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> {
+        save_to_file(self.handle(), name.into(), fmt.into())
+    }
+
     /// Checks if a target device is enabled for a module.
     pub fn enabled(&self, target: &str) -> bool {
         let target = CString::new(target).unwrap();
@@ -128,3 +137,10 @@ impl Drop for Module {
         check_call!(ffi::TVMModFree(self.handle));
     }
 }
+
+// impl std::convert::TryFrom<RetValue> for Module {
+//     type Error = Error;
+//     fn try_from(ret_value: RetValue) -> Result<Module, Self::Error> {
+//         Ok(Module::new(ret_value.try_into()?))
+//     }
+// }
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index affd81b..c5ede7d 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -255,6 +255,7 @@ impl_typed_and_to_function!(2; A, B);
 impl_typed_and_to_function!(3; A, B, C);
 impl_typed_and_to_function!(4; A, B, C, D);
 impl_typed_and_to_function!(5; A, B, C, D, E);
+impl_typed_and_to_function!(6; A, B, C, D, E, G);
 
 #[cfg(test)]
 mod tests {
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index ac2b60f..af5feab 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -175,7 +175,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
 TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);
 
 TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
-    .set_body_typed([](Module mod, std::string name, std::string fmt) {
+    .set_body_typed([](Module mod, tvm::String name, tvm::String fmt) {
       mod->SaveToFile(name, fmt);
     });
 

Reply via email to