This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch cargo-build in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/cargo-build by this push: new 83ad4d4 Fix 83ad4d4 is described below commit 83ad4d443350cc499e8af1fc419d3a694cb5e1f7 Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Thu Nov 5 18:05:28 2020 -0800 Fix --- python/tvm/relay/backend/graph_runtime_factory.py | 2 +- python/tvm/relay/build_module.py | 5 ++--- rust/tvm-rt/src/map.rs | 12 ++++++++++++ rust/tvm-rt/src/module.rs | 16 ++++++++++++++++ rust/tvm-rt/src/to_function.rs | 1 + src/runtime/module.cc | 2 +- 6 files changed, 33 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index 4c6ac47..3427a62 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -21,7 +21,7 @@ from tvm._ffi.registry import get_global_func from tvm.runtime import ndarray -class GraphRuntimeFactoryModule(object): +class GraphRuntimeFactoryModule: """Graph runtime factory module. This is a module of graph runtime factory diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index e93d654..7e32dea 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -187,9 +187,8 @@ class BuildModule(object): return ret @register_func("tvm.relay.build") -def build1(mod, target=None, target_host=None, params=None, mod_name="default"): - import pdb; pdb.set_trace() - return build(mod, target, target_host, params, mod_name) +def _rust_build_module(mod, target=None, target_host=None, params=None, mod_name="default"): + return build(mod, target, target_host, params, mod_name).module def build(mod, target=None, target_host=None, params=None, mod_name="default"): """Helper function that builds a Relay function to run on TVM graph diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index 721fb1e..ab44e40 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -109,6 +109,18 @@ where let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?; oref.downcast() } + + pub fn empty() -> Self { + Self::from_iter(vec![].into_iter()) + } + + //(@jroesch): I don't think this is a correct implementation. + pub fn null() -> Self { + Map { + object: ObjectRef::null(), + _data: PhantomData, + } + } } pub struct IntoIter<K, V> { diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index c0822a5..18347da 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -30,6 +30,8 @@ use tvm_sys::ffi; use crate::errors::Error; use crate::{errors, function::Function}; +use crate::{String as TString}; +use crate::RetValue; const ENTRY_FUNC: &str = "__tvm_main__"; @@ -49,6 +51,9 @@ crate::external! { #[name("runtime.ModuleLoadFromFile")] fn load_from_file(file_name: CString, format: CString) -> Module; + + #[name("runtime.ModuleSaveToFile")] + fn save_to_file(module: ffi::TVMModuleHandle, name: TString, fmt: TString); } impl Module { @@ -110,6 +115,10 @@ impl Module { Ok(module) } + pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> { + save_to_file(self.handle(), name.into(), fmt.into()) + } + /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let target = CString::new(target).unwrap(); @@ -128,3 +137,10 @@ impl Drop for Module { check_call!(ffi::TVMModFree(self.handle)); } } + +// impl std::convert::TryFrom<RetValue> for Module { +// type Error = Error; +// fn try_from(ret_value: RetValue) -> Result<Module, Self::Error> { +// Ok(Module::new(ret_value.try_into()?)) +// } +// } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index affd81b..c5ede7d 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -255,6 +255,7 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); +impl_typed_and_to_function!(6; A, B, C, D, E, G); #[cfg(test)] mod tests { diff --git a/src/runtime/module.cc b/src/runtime/module.cc index ac2b60f..af5feab 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -175,7 +175,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") - .set_body_typed([](Module mod, std::string name, std::string fmt) { + .set_body_typed([](Module mod, tvm::String name, tvm::String fmt) { mod->SaveToFile(name, fmt); });