This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch rust-tvm-rt
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit d8d3147e304bdf5336acf22fb7616a9e59c2ad71
Author: Jared Roesch <jroe...@octoml.ai>
AuthorDate: Wed May 6 02:22:08 2020 -0700

    Add tvm-rt
---
 include/tvm/ir/expr.h                              |   5 +-
 python/tvm/runtime/object_generic.py               |   2 +-
 rust/macros/Cargo.toml                             |   4 +-
 rust/macros/src/{lib.rs => import_module.rs}       |  12 +-
 rust/macros/src/lib.rs                             | 124 +-----
 rust/macros/src/object.rs                          | 171 ++++++++
 rust/tvm-rt/.gitignore                             |   7 +
 rust/{macros/Cargo.toml => tvm-rt/.travis.yml}     |  24 +-
 rust/{macros => tvm-rt}/Cargo.toml                 |  30 +-
 rust/tvm-rt/README.md                              | 235 +++++++++++
 rust/{macros => tvm-rt/examples/resnet}/Cargo.toml |  23 +-
 rust/tvm-rt/examples/resnet/README.md              |  45 +++
 rust/tvm-rt/examples/resnet/build.rs               |  42 ++
 rust/tvm-rt/examples/resnet/src/build_resnet.py    | 134 +++++++
 rust/tvm-rt/examples/resnet/src/main.rs            | 160 ++++++++
 rust/tvm-rt/src/context.rs                         |  76 ++++
 rust/tvm-rt/src/errors.rs                          |  45 +++
 rust/tvm-rt/src/function.rs                        | 340 ++++++++++++++++
 rust/tvm-rt/src/lib.rs                             | 124 ++++++
 rust/tvm-rt/src/module.rs                          | 130 +++++++
 rust/tvm-rt/src/ndarray.rs                         | 431 +++++++++++++++++++++
 rust/tvm-rt/src/object/mod.rs                      |  99 +++++
 rust/tvm-rt/src/object/object_ptr.rs               | 283 ++++++++++++++
 rust/tvm-rt/src/string.rs                          |  72 ++++
 rust/tvm-rt/src/to_boxed_fn.rs                     | 222 +++++++++++
 rust/tvm-rt/src/to_function.rs                     | 377 ++++++++++++++++++
 rust/tvm-rt/src/value.rs                           | 166 ++++++++
 rust/tvm-rt/tests/test_ir.rs                       |  36 ++
 src/ir/expr.cc                                     |  11 +-
 src/printer/relay_text_printer.cc                  |  15 +-
 src/relay/transforms/to_cps.cc                     |   2 +-
 src/runtime/object.cc                              |  14 +
 src/runtime/object_internal.h                      |   9 +
 33 files changed, 3294 insertions(+), 176 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index fba35a9..82689bd 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -27,6 +27,7 @@
 #include <tvm/runtime/object.h>
 #include <tvm/node/node.h>
 #include <tvm/node/container.h>
+#include <tvm/runtime/container.h>
 #include <tvm/ir/span.h>
 #include <tvm/ir/type.h>
 #include <string>
@@ -36,6 +37,8 @@
 
 namespace tvm {
 
+using tvm::runtime::String;
+
 /*!
  * \brief Base type of all the expressions.
  * \sa Expr
@@ -189,7 +192,7 @@ class GlobalVar;
 class GlobalVarNode : public RelayExprNode {
  public:
   /*! \brief The name of the variable, this only acts as a hint. */
-  std::string name_hint;
+  String name_hint;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("name_hint", &name_hint);
diff --git a/python/tvm/runtime/object_generic.py 
b/python/tvm/runtime/object_generic.py
index cc21450..8f559ae 100644
--- a/python/tvm/runtime/object_generic.py
+++ b/python/tvm/runtime/object_generic.py
@@ -38,7 +38,7 @@ ObjectTypes = (ObjectBase, NDArrayBase, Module, 
ObjectRValueRef, PyNativeObject)
 
 
 def convert_to_object(value):
-    """Convert a python value to corresponding object type.
+    """Convert a Python value to corresponding object type.
 
     Parameters
     ----------
diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml
index 784b35e..7abc9ae 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/macros/Cargo.toml
@@ -32,5 +32,5 @@ proc-macro = true
 [dependencies]
 goblin = "0.0.24"
 proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
+quote = "^1.0"
+syn = { version = "1.0.17", features = ["full", "extra-traits"] }
diff --git a/rust/macros/src/lib.rs b/rust/macros/src/import_module.rs
similarity index 92%
copy from rust/macros/src/lib.rs
copy to rust/macros/src/import_module.rs
index 9f28c74..6b059ae 100644
--- a/rust/macros/src/lib.rs
+++ b/rust/macros/src/import_module.rs
@@ -16,9 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
-extern crate proc_macro;
-
 use quote::quote;
 use std::{fs::File, io::Read};
 use syn::parse::{Parse, ParseStream, Result};
@@ -37,8 +34,7 @@ impl Parse for ImportModule {
     }
 }
 
-#[proc_macro]
-pub fn import_module(input: proc_macro::TokenStream) -> 
proc_macro::TokenStream {
+pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let import_module_args = syn::parse_macro_input!(input as ImportModule);
 
     let manifest =
@@ -109,11 +105,11 @@ pub fn import_module(input: proc_macro::TokenStream) -> 
proc_macro::TokenStream
     };
 
     let fns = quote! {
-        use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, 
FuncCallError};
+        use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError};
         #extern_fns
 
         #(
-            pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, 
FuncCallError> {
+            pub fn #fn_names(args: &[ArgValue]) -> Result<RetValue, 
FuncCallError> {
                 let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
                    .into_iter()
                    .map(|arg| {
@@ -125,7 +121,7 @@ pub fn import_module(input: proc_macro::TokenStream) -> 
proc_macro::TokenStream
                     ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), 
values.len() as i32)
                 };
                 if exit_code == 0 {
-                    Ok(TVMRetValue::default())
+                    Ok(RetValue::default())
                 } else {
                     
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
                 }
diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs
index 9f28c74..e9ddc25 100644
--- a/rust/macros/src/lib.rs
+++ b/rust/macros/src/lib.rs
@@ -17,121 +17,17 @@
  * under the License.
  */
 
-extern crate proc_macro;
-
-use quote::quote;
-use std::{fs::File, io::Read};
-use syn::parse::{Parse, ParseStream, Result};
-use syn::LitStr;
-
-use std::path::PathBuf;
-
-struct ImportModule {
-    importing_file: LitStr,
-}
-
-impl Parse for ImportModule {
-    fn parse(input: ParseStream) -> Result<Self> {
-        let importing_file: LitStr = input.parse()?;
-        Ok(ImportModule { importing_file })
-    }
-}
+use proc_macro::TokenStream;
+mod import_module;
+mod object;
 
 #[proc_macro]
-pub fn import_module(input: proc_macro::TokenStream) -> 
proc_macro::TokenStream {
-    let import_module_args = syn::parse_macro_input!(input as ImportModule);
-
-    let manifest =
-        std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be 
set by Cargo.");
-
-    let mut path = PathBuf::new();
-    path.push(manifest);
-    path = path.join(import_module_args.importing_file.value());
-
-    let mut fd = File::open(&path)
-        .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", 
path.display()));
-    let mut buffer = Vec::new();
-    fd.read_to_end(&mut buffer).unwrap();
-
-    let fn_names = match goblin::Object::parse(&buffer).unwrap() {
-        goblin::Object::Elf(elf) => elf
-            .syms
-            .iter()
-            .filter_map(|s| {
-                if s.st_type() == 0 || 
goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
-                    return None;
-                }
-                match elf.strtab.get(s.st_name) {
-                    Some(Ok(name)) if name != "" => {
-                        Some(syn::Ident::new(name, 
proc_macro2::Span::call_site()))
-                    }
-                    _ => None,
-                }
-            })
-            .collect::<Vec<_>>(),
-        goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
-            obj.symbols()
-                .filter_map(|s| match s {
-                    Ok((name, ref nlist))
-                        if nlist.is_global()
-                            && nlist.n_sect != 0
-                            && !name.ends_with("tvm_module_ctx") =>
-                    {
-                        Some(syn::Ident::new(
-                            if name.starts_with('_') {
-                                // Mach objects prepend a _ to globals.
-                                &name[1..]
-                            } else {
-                                &name
-                            },
-                            proc_macro2::Span::call_site(),
-                        ))
-                    }
-                    _ => None,
-                })
-                .collect::<Vec<_>>()
-        }
-        _ => panic!("Unsupported object format."),
-    };
-
-    let extern_fns = quote! {
-        mod ext {
-            extern "C" {
-                #(
-                    pub(super) fn #fn_names(
-                        args: *const tvm_runtime::ffi::TVMValue,
-                        type_codes: *const std::os::raw::c_int,
-                        num_args: std::os::raw::c_int
-                    ) -> std::os::raw::c_int;
-                )*
-            }
-        }
-    };
-
-    let fns = quote! {
-        use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, 
FuncCallError};
-        #extern_fns
-
-        #(
-            pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, 
FuncCallError> {
-                let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
-                   .into_iter()
-                   .map(|arg| {
-                       let (val, code) = arg.to_tvm_value();
-                       (val, code as i32)
-                   })
-                   .unzip();
-                let exit_code = unsafe {
-                    ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), 
values.len() as i32)
-                };
-                if exit_code == 0 {
-                    Ok(TVMRetValue::default())
-                } else {
-                    
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
-                }
-            }
-        )*
-    };
+pub fn import_module(input: TokenStream) -> TokenStream {
+    import_module::macro_impl(input) // implementation lives in src/import_module.rs
+}
 
-    proc_macro::TokenStream::from(fns)
+#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
+pub fn derive_object(input: TokenStream) -> TokenStream {
+    // `#[derive(Object)]`: parsing and code generation live in `object::macro_impl`.
+    object::macro_impl(input)
 }
diff --git a/rust/macros/src/object.rs b/rust/macros/src/object.rs
new file mode 100644
index 0000000..96a86dd
--- /dev/null
+++ b/rust/macros/src/object.rs
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use proc_macro::TokenStream;
+use proc_macro2::Span;
+use quote::quote;
+use syn::DeriveInput;
+use syn::Ident;
+
+/// Returns the string literal of the last `#[name = "..."]` attribute in `attrs`,
+/// or `None` when no such attribute is present.
+///
+/// # Panics
+///
+/// Panics (reported by rustc as an error at the derive site) if the last matching
+/// attribute is not of the `name = "string literal"` form.
+fn string_attr(attrs: &[syn::Attribute], name: &str) -> Option<syn::LitStr> {
+    let attr = attrs.iter().rev().find(|attr| attr.path.is_ident(name))?;
+    match attr.parse_meta() {
+        Ok(syn::Meta::NameValue(syn::MetaNameValue {
+            lit: syn::Lit::Str(value),
+            ..
+        })) => Some(value),
+        _ => panic!("expected an attribute of the form `#[{} = \"...\"]`", name),
+    }
+}
+
+/// Implements `#[derive(Object)]` for a payload struct.
+///
+/// The payload must carry `#[type_key = "..."]` (used as `IsObject::TYPE_KEY`) and
+/// `#[ref_name = "..."]` (the name of the generated reference type), and must have a
+/// field named `base` whose `as_object()` provides the `IsObject::as_object` view.
+///
+/// The expansion contains the `IsObject` impl, a cloneable
+/// `pub struct <ref_name>(Option<ObjectPtr<Payload>>)`, and `ToObjectRef`, `Deref`,
+/// and `ArgValue`/`RetValue` conversion impls for it. `tvm_rt`, `anyhow` and
+/// non-prelude `std` items are referenced by full path, so the expansion does not
+/// depend on `use` declarations at the derive site.
+///
+/// # Panics
+///
+/// Panics (reported by rustc as an error at the derive site) if `type_key` or
+/// `ref_name` is missing or is not a string-valued `name = "..."` attribute.
+pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
+    let derive_input = syn::parse_macro_input!(input as DeriveInput);
+    let payload_id = derive_input.ident;
+
+    let type_key = string_attr(&derive_input.attrs, "type_key")
+        .expect("#[derive(Object)] requires a `#[type_key = \"...\"]` attribute");
+    let ref_name = string_attr(&derive_input.attrs, "ref_name")
+        .expect("#[derive(Object)] requires a `#[ref_name = \"...\"]` attribute");
+    let ref_id = Ident::new(&ref_name.value(), Span::call_site());
+    // NOTE(review): `base` is registered as a helper attribute of the derive, but the
+    // parent object is always read from a field literally named `base`.
+    let base = Ident::new("base", Span::call_site());
+
+    let expanded = quote! {
+        unsafe impl tvm_rt::object::IsObject for #payload_id {
+            const TYPE_KEY: &'static str = #type_key;
+
+            fn as_object<'s>(&'s self) -> &'s tvm_rt::object::Object {
+                self.#base.as_object()
+            }
+        }
+
+        #[derive(Clone)]
+        pub struct #ref_id(Option<tvm_rt::object::ObjectPtr<#payload_id>>);
+
+        impl tvm_rt::object::ToObjectRef for #ref_id {
+            fn to_object_ref(&self) -> tvm_rt::object::ObjectRef {
+                tvm_rt::object::ObjectRef(self.0.as_ref().map(|o| o.upcast()))
+            }
+        }
+
+        impl std::ops::Deref for #ref_id {
+            type Target = #payload_id;
+
+            fn deref(&self) -> &Self::Target {
+                self.0.as_ref().expect("dereferenced a null object reference")
+            }
+        }
+
+        impl std::convert::TryFrom<tvm_rt::RetValue> for #ref_id {
+            type Error = ::anyhow::Error;
+
+            fn try_from(ret_val: tvm_rt::RetValue) -> std::result::Result<#ref_id, Self::Error> {
+                use std::convert::TryInto;
+                let oref: tvm_rt::object::ObjectRef = ret_val.try_into()?;
+                // `ok_or_else` builds the error only on the null path.
+                let ptr = oref.0.ok_or_else(|| ::anyhow::anyhow!("null ptr"))?;
+                let ptr = ptr.downcast::<#payload_id>()?;
+                Ok(#ref_id(Some(ptr)))
+            }
+        }
+
+        impl<'a> std::convert::From<#ref_id> for tvm_rt::ArgValue<'a> {
+            fn from(object_ref: #ref_id) -> tvm_rt::ArgValue<'a> {
+                // Move the pointer out instead of cloning it; a null reference
+                // is passed as a null object handle.
+                match object_ref.0 {
+                    None => tvm_rt::ArgValue::ObjectHandle(std::ptr::null_mut()),
+                    Some(value) => value.into(),
+                }
+            }
+        }
+
+        impl<'a> std::convert::From<&#ref_id> for tvm_rt::ArgValue<'a> {
+            fn from(object_ref: &#ref_id) -> tvm_rt::ArgValue<'a> {
+                tvm_rt::ArgValue::<'a>::from(object_ref.clone())
+            }
+        }
+
+        impl<'a> std::convert::TryFrom<tvm_rt::ArgValue<'a>> for #ref_id {
+            type Error = ::anyhow::Error;
+
+            fn try_from(arg_value: tvm_rt::ArgValue<'a>) -> std::result::Result<#ref_id, Self::Error> {
+                use std::convert::TryInto;
+                let optr = arg_value.try_into()?;
+                Ok(#ref_id(Some(optr)))
+            }
+        }
+
+        impl<'a> std::convert::TryFrom<&tvm_rt::ArgValue<'a>> for #ref_id {
+            type Error = ::anyhow::Error;
+
+            fn try_from(arg_value: &tvm_rt::ArgValue<'a>) -> std::result::Result<#ref_id, Self::Error> {
+                use std::convert::TryInto;
+                let optr = arg_value.try_into()?;
+                Ok(#ref_id(Some(optr)))
+            }
+        }
+
+        impl std::convert::From<#ref_id> for tvm_rt::RetValue {
+            fn from(object_ref: #ref_id) -> tvm_rt::RetValue {
+                match object_ref.0 {
+                    None => tvm_rt::RetValue::ObjectHandle(std::ptr::null_mut()),
+                    Some(value) => value.into(),
+                }
+            }
+        }
+    };
+
+    TokenStream::from(expanded)
+}
diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore
new file mode 100644
index 0000000..2430329
--- /dev/null
+++ b/rust/tvm-rt/.gitignore
@@ -0,0 +1,7 @@
+target
+**/*.rs.bk
+Cargo.lock
+/tests/basics/add_*
+/examples/resnet/deploy_*
+/examples/resnet/*.png
+/examples/resnet/synset.*
diff --git a/rust/macros/Cargo.toml b/rust/tvm-rt/.travis.yml
similarity index 67%
copy from rust/macros/Cargo.toml
copy to rust/tvm-rt/.travis.yml
index 784b35e..e963b7c 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/tvm-rt/.travis.yml
@@ -15,22 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-[package]
-name = "tvm-macros"
-version = "0.1.1"
-license = "Apache-2.0"
-description = "Procedural macros of the TVM crate."
-repository = "https://github.com/apache/incubator-tvm";
-readme = "README.md"
-keywords = ["tvm"]
-authors = ["TVM Contributors"]
-edition = "2018"
-
-[lib]
-proc-macro = true
-
-[dependencies]
-goblin = "0.0.24"
-proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
+language: rust
+rust:
+  - nightly
+matrix:
+  fast_finish: true
diff --git a/rust/macros/Cargo.toml b/rust/tvm-rt/Cargo.toml
similarity index 65%
copy from rust/macros/Cargo.toml
copy to rust/tvm-rt/Cargo.toml
index 784b35e..417f256 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/tvm-rt/Cargo.toml
@@ -16,21 +16,29 @@
 # under the License.
 
 [package]
-name = "tvm-macros"
-version = "0.1.1"
+name = "tvm-rt"
+version = "0.1.0"
 license = "Apache-2.0"
-description = "Procedural macros of the TVM crate."
+description = "Rust bindings for the TVM runtime API."
 repository = "https://github.com/apache/incubator-tvm";
+homepage = "https://github.com/apache/incubator-tvm";
 readme = "README.md"
-keywords = ["tvm"]
+keywords = ["rust", "tvm"]
+categories = ["api-bindings", "science"]
 authors = ["TVM Contributors"]
 edition = "2018"
 
-[lib]
-proc-macro = true
-
 [dependencies]
-goblin = "0.0.24"
-proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
+thiserror = "^1.0"
+anyhow = "^1.0"
+lazy_static = "1.1"
+ndarray = "0.12"
+num-traits = "0.2"
+tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] }
+tvm-macros = { version = "0.1", path = "../macros" }
+paste = "0.1"
+mashup = "0.1"
+once_cell = "^1.3.1"
+
+[features]
+blas = ["ndarray/blas"]
diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md
new file mode 100644
index 0000000..fff3b56
--- /dev/null
+++ b/rust/tvm-rt/README.md
@@ -0,0 +1,235 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# TVM Runtime Frontend Support
+
+This crate provides an idiomatic Rust API for the [TVM](https://github.com/apache/incubator-tvm) runtime frontend. Currently it requires **Nightly Rust** and has been tested on `rustc 1.32.0-nightly`.
+
+## What Does This Crate Offer?
+
+Here is the typical workflow:
+
+1. Train your **Deep Learning** model using any major framework such as 
[PyTorch](https://pytorch.org/), [Apache 
MXNet](https://mxnet.incubator.apache.org/) or 
[TensorFlow](https://www.tensorflow.org/)
+2. Use **TVM** to build optimized model artifacts on a supported context such 
as CPU, GPU, OpenCL and specialized accelerators.
+3. Deploy your models using **Rust** :heart:
+
+### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k
+
+Please check out [examples/resnet](examples/resnet) for the complete end-to-end example.
+
+Here's a Python snippet for downloading and building a pretrained Resnet18 via 
Apache MXNet and TVM
+
+```python
+block = get_model('resnet18_v1', pretrained=True)
+
+net, params = relay.frontend.from_mxnet(block, shape_dict)
+# compile the model
+with relay.build_config(opt_level=opt_level):
+    graph, lib, params = relay.build(
+        net, target, params=params)
+# save the model artifacts
+lib.save(os.path.join(target_dir, "deploy_lib.o"))
+cc.create_shared(os.path.join(target_dir, "deploy_lib.so"),
+                [os.path.join(target_dir, "deploy_lib.o")])
+
+with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo:
+    fo.write(graph.json())
+with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo:
+    fo.write(relay.save_param_dict(params))
+```
+
+Now, we use these artifacts to create and run the *Graph Runtime* on our input cat image
+
+![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true)
+
+as demonstrated in the following Rust snippet
+
+```rust
+    let graph = fs::read_to_string("deploy_graph.json")?;
+    // load the built module
+    let lib = Module::load(&Path::new("deploy_lib.so"))?;
+    // get the global TVM graph runtime function
+    let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+    let runtime_create_fn_ret = call_packed!(
+        runtime_create_fn,
+        &graph,
+        &lib,
+        &ctx.device_type,
+        &ctx.device_id
+    )?;
+    // get graph runtime module
+    let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?;
+    // get the registered `load_params` from runtime module
+    let ref load_param_fn = graph_runtime_module
+        .get_function("load_params", false)
+        .unwrap();
+    // parse parameters and convert to ByteArray
+    let params: Vec<u8> = fs::read("deploy_param.params")?;
+    let barr = ByteArray::from(&params);
+    // load the parameters
+    call_packed!(load_param_fn, &barr)?;
+    // get the set_input function
+    let ref set_input_fn = graph_runtime_module
+        .get_function("set_input", false)
+        .unwrap();
+
+    call_packed!(set_input_fn, "data", &input)?;
+    // get `run` function from runtime module
+    let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
+    // execute the run function. Note that it has no argument
+    call_packed!(run_fn,)?;
+    // prepare to get the output
+    let output_shape = &mut [1, 1000];
+    let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+    // get the `get_output` function from runtime module
+    let ref get_output_fn = graph_runtime_module
+        .get_function("get_output", false)
+        .unwrap();
+    // execute the get output function
+    call_packed!(get_output_fn, &0, &output)?;
+    // flatten the output as Vec<f32>
+    let output = output.to_vec::<f32>()?;
+```
+
+and the model correctly predicts the input image as **tiger cat**.
+
+## Installations
+
+Please follow the TVM [installation instructions](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm`, and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
+
+*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH`; this is automatic when they are installed individually into an Anaconda environment.
+
+## Supported TVM Functionalities
+
+### Use TVM to Generate Shared Library
+
+One can use the following Python snippet to generate `add_gpu.so`, which adds two vectors on the GPU.
+
+```python
+import os
+import tvm
+from tvm import te
+from tvm.contrib import cc
+
+def test_add(target_dir):
+    if not tvm.runtime.enabled("cuda"):
+        print("skip {__file__} because cuda is not enabled...".format(__file__=__file__))
+        return
+    n = te.var("n")
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+    s = te.create_schedule(C.op)
+    bx, tx = s[C].split(C.op.axis[0], factor=64)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
+
+    fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
+    fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
+    cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
+            [os.path.join(target_dir, "add_gpu.o")])
+
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) != 2:
+        sys.exit(-1)
+    test_add(sys.argv[1])
+```
+
+### Run the Generated Shared Library
+
+The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
+
+```rust
+extern crate tvm_rt as tvm;
+
+use tvm::*;
+
+fn main() {
+    let shape = &mut [2];
+    let mut data = vec![3f32, 4.0];
+    let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+    arr.copy_from_buffer(data.as_mut_slice());
+    let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+    let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap();
+    let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap();
+    assert!(fadd.enabled("gpu"));
+    fadd.import_module(fadd_dep);
+    fadd.entry();
+    function::Builder::from(&mut fadd)
+        .arg(&arr)
+        .arg(&arr)
+        .set_output(&mut ret)?
+        .invoke()
+        .unwrap();
+
+    assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
+}
+```
+
+**Note:** `rustc` must be instructed to link against the generated `add_gpu.so`, for example with
+`cargo:rustc-link-search=native=add_gpu`.
+
+See the tests and examples custom `build.rs` for more details.
+
+### Convert and Register a Rust Function as a TVM Packed Function
+
+One can use the `register_global_func!` macro to convert and register a Rust
+function of type `fn(&[ArgValue]) -> Result<RetValue>` to a global TVM **packed function**, as follows:
+
+```rust
+#[macro_use]
+extern crate tvm_rt as tvm;
+use std::convert::TryInto;
+use tvm::*;
+
+fn main() {
+    register_global_func! {
+        fn sum(args: &[ArgValue]) -> Result<RetValue> {
+            let mut ret = 0f32;
+            let shape = &mut [2];
+            for arg in args.iter() {
+                let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+                let arg: NDArray = arg.try_into()?;
+                let arr = arg.copy_to_ndarray(e).unwrap();
+                let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap();
+                ret += rnd.scalar_sum();
+            }
+            let ret_val = RetValue::from(&ret);
+            Ok(ret_val)
+        }
+    }
+
+    let shape = &mut [2];
+    let mut data = vec![3f32, 4.0];
+    let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+    arr.copy_from_buffer(data.as_mut_slice());
+    let mut registered = function::Builder::default();
+    let ret: f64 = registered
+        .get_function("sum", true)
+        .arg(&arr)
+        .arg(&arr)
+        .invoke()
+        .unwrap()
+        .try_into()
+        .unwrap();
+
+    assert_eq!(ret, 14f64);
+}
+```
diff --git a/rust/macros/Cargo.toml b/rust/tvm-rt/examples/resnet/Cargo.toml
similarity index 74%
copy from rust/macros/Cargo.toml
copy to rust/tvm-rt/examples/resnet/Cargo.toml
index 784b35e..dbf59f3 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/tvm-rt/examples/resnet/Cargo.toml
@@ -16,21 +16,14 @@
 # under the License.
 
 [package]
-name = "tvm-macros"
-version = "0.1.1"
-license = "Apache-2.0"
-description = "Procedural macros of the TVM crate."
-repository = "https://github.com/apache/incubator-tvm";
-readme = "README.md"
-keywords = ["tvm"]
+name = "resnet"
+version = "0.0.0"
 authors = ["TVM Contributors"]
-edition = "2018"
-
-[lib]
-proc-macro = true
+license = "Apache-2.0"
+build = "build.rs"
 
 [dependencies]
-goblin = "0.0.24"
-proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
+ndarray = "0.12"
+tvm-frontend = { package = "tvm-rt", path = "../../" }  # package at ../../ is tvm-rt
+image = "0.20"
+csv = "1.1"
diff --git a/rust/tvm-rt/examples/resnet/README.md 
b/rust/tvm-rt/examples/resnet/README.md
new file mode 100644
index 0000000..d6e32f7
--- /dev/null
+++ b/rust/tvm-rt/examples/resnet/README.md
@@ -0,0 +1,45 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+## Resnet example
+
+This end-to-end example shows how to:
+* build `Resnet 18` with `tvm` from Python
+* use the provided Rust frontend API to test for an input image
+
+To run the example with pretrained ResNet weights, first `tvm` and `mxnet` must be installed for the Python build. To install mxnet for CPU, run `pip install mxnet`
+and to install `tvm` with `llvm`, follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html).
+
+* **Build the example**: `cargo build`
+
+For the build to succeed, the Rust compiler must be instructed to link against the compiled shared library, for example with
+`println!("cargo:rustc-link-search=native={}", build_path)`. See `build.rs` for more details.
+
+* **Run the example**: `cargo run`
+
+Note: To use pretrained weights, pass `--pretrained` to the Python script invoked in `build.rs`:
+
+```
+let output = Command::new("python3")
+        .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
+        .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR")))
+        .arg(&format!("--pretrained"))
+        .output()
+        .expect("Failed to execute command");
+```
+
+Otherwise, *random weights* are used, so the prediction will be `limpkin, Aramus pictus`!
diff --git a/rust/tvm-rt/examples/resnet/build.rs 
b/rust/tvm-rt/examples/resnet/build.rs
new file mode 100644
index 0000000..b9a3c4c
--- /dev/null
+++ b/rust/tvm-rt/examples/resnet/build.rs
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{path::Path, process::Command};
+
+/// Build script: runs `src/build_resnet.py` to produce the TVM artifacts in the
+/// package directory and tells cargo to search that directory for native libs.
+///
+/// # Panics
+///
+/// Panics if the Python interpreter cannot be started, or if `deploy_lib.o`
+/// is missing afterwards. In the latter case the panic message carries the
+/// last line of the script's stderr.
+fn main() {
+    let manifest_dir = env!("CARGO_MANIFEST_DIR");
+    let output = Command::new("python3")
+        .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
+        .arg(&format!("--build-dir={}", manifest_dir))
+        .output()
+        .expect("Failed to execute command");
+    // `deploy_lib.o` is the marker of a successful Python build. Decode stderr
+    // lossily so that non-UTF-8 output cannot replace the real diagnostic with
+    // an unrelated `FromUtf8Error` panic.
+    assert!(
+        Path::new(manifest_dir).join("deploy_lib.o").exists(),
+        "Could not prepare demo: {}",
+        String::from_utf8_lossy(&output.stderr)
+            .trim()
+            .lines()
+            .last()
+            .unwrap_or("")
+    );
+    println!("cargo:rustc-link-search=native={}", manifest_dir);
+}
diff --git a/rust/tvm-rt/examples/resnet/src/build_resnet.py 
b/rust/tvm-rt/examples/resnet/src/build_resnet.py
new file mode 100644
index 0000000..49c67bf
--- /dev/null
+++ b/rust/tvm-rt/examples/resnet/src/build_resnet.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import csv
+import logging
+from os import path as osp
+import sys
+
+import numpy as np
+
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay import testing
+from tvm.contrib import graph_runtime, cc
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - 
%(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+parser = argparse.ArgumentParser(description='Resnet build example')
+aa = parser.add_argument
+aa('--build-dir', type=str, required=True, help='directory to put the build 
artifacts')
+aa('--pretrained', action='store_true', help='use a pretrained resnet')
+aa('--batch-size', type=int, default=1, help='input image batch size')
+aa('--opt-level', type=int, default=3,
+   help='level of optimization. 0 is unoptimized and 3 is the highest level')
+aa('--target', type=str, default='llvm', help='target context for compilation')
+aa('--image-shape', type=str, default='3,224,224', help='input image 
dimensions')
+aa('--image-name', type=str, default='cat.png', help='name of input image to 
download')
+args = parser.parse_args()
+
+build_dir = args.build_dir
+batch_size = args.batch_size
+opt_level = args.opt_level
+target = tvm.target.create(args.target)
+image_shape = tuple(map(int, args.image_shape.split(",")))
+data_shape = (batch_size,) + image_shape
+
+def build(target_dir):
+    """ Compiles resnet18 with TVM"""
+    deploy_lib = osp.join(target_dir, 'deploy_lib.o')
+    if osp.exists(deploy_lib):
+        return
+
+    if args.pretrained:
+        # needs mxnet installed
+        from mxnet.gluon.model_zoo.vision import get_model
+
+        # if `--pretrained` is enabled, it downloads a pretrained
+        # resnet18 trained on imagenet1k dataset for image classification task
+        block = get_model('resnet18_v1', pretrained=True)
+        net, params = relay.frontend.from_mxnet(block, {"data": data_shape})
+        # we want a probability so add a softmax operator
+        net = relay.Function(net.params, relay.nn.softmax(net.body),
+            None, net.type_params, net.attrs)
+    else:
+        # use random weights from relay.testing
+        net, params = relay.testing.resnet.get_workload(
+            num_layers=18, batch_size=batch_size, image_shape=image_shape)
+
+    # compile the model
+    with relay.build_config(opt_level=opt_level):
+            graph, lib, params = relay.build_module.build(net, target, 
params=params)
+
+    # save the model artifacts
+    lib.save(deploy_lib)
+    cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
+                    [osp.join(target_dir, "deploy_lib.o")])
+
+    with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
+        fo.write(graph)
+
+    with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
+        fo.write(relay.save_param_dict(params))
+
+def download_img_labels():
+    """ Download an image and imagenet1k class labels for test"""
+    from mxnet.gluon.utils import download
+
+    img_name = 'cat.png'
+    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
+                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
+                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
+                      'imagenet1000_clsid_to_human.txt'])
+    synset_name = 'synset.txt'
+    
download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', 
img_name)
+    download(synset_url, synset_name)
+
+    with open(synset_name) as fin:
+        synset = eval(fin.read())
+
+    with open("synset.csv", "w") as fout:
+        w = csv.writer(fout)
+        w.writerows(synset.items())
+
+def test_build(build_dir):
+    """ Sanity check with random input"""
+    graph = open(osp.join(build_dir, "deploy_graph.json")).read()
+    lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so"))
+    params = bytearray(open(osp.join(build_dir,"deploy_param.params"), 
"rb").read())
+    input_data = 
tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
+    ctx = tvm.cpu()
+    module = graph_runtime.create(graph, lib, ctx)
+    module.load_params(params)
+    module.run(data=input_data)
+    out = module.get_output(0).asnumpy()
+
+
+if __name__ == '__main__':
+    # Compile the model and write the deploy_* artifacts into --build-dir
+    # (build() may return early if a previous run's artifacts already exist).
+    logger.info("building the model")
+    build(build_dir)
+    logger.info("build was successful")
+    # Run the artifacts once on random input as a sanity check.
+    logger.info("test the build artifacts")
+    test_build(build_dir)
+    logger.info("test was successful")
+    # The test image and class labels are only fetched when --pretrained is given.
+    if args.pretrained:
+        download_img_labels()
+        logger.info("image and synset downloads are successful")
diff --git a/rust/tvm-rt/examples/resnet/src/main.rs 
b/rust/tvm-rt/examples/resnet/src/main.rs
new file mode 100644
index 0000000..8b74b65
--- /dev/null
+++ b/rust/tvm-rt/examples/resnet/src/main.rs
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate csv;
+extern crate image;
+extern crate ndarray;
+extern crate tvm_frontend as tvm;
+
+use std::{
+    collections::HashMap,
+    convert::TryInto,
+    fs::{self, File},
+    path::Path,
+    str::FromStr,
+};
+
+use image::{FilterType, GenericImageView};
+use ndarray::{Array, ArrayD, Axis};
+
+use tvm::*;
+
+/// Classifies `cat.png` with the ResNet-18 artifacts produced by `build.rs`.
+///
+/// The image is resized and normalized, then fed through a TVM graph runtime
+/// built from `deploy_graph.json`, `deploy_lib.so` and `deploy_param.params`.
+/// The most likely ImageNet class, looked up in `synset.csv`, is printed.
+fn main() {
+    let ctx = TVMContext::cpu(0);
+    let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap();
+    println!("original image dimensions: {:?}", img.dimensions());
+    // for bigger size images, one needs to first resize to 256x256
+    // with `img.resize_exact` method and then `image.crop` to 224x224
+    let img = img.resize(224, 224, FilterType::Nearest).to_rgb();
+    println!("resized image dimensions: {:?}", img.dimensions());
+    // Normalize the RGB channels using the imagenet1k mean and std, producing
+    // an interleaved (height, width, channel) buffer.
+    let mut pixels: Vec<f32> = Vec::with_capacity(224 * 224 * 3);
+    for pixel in img.pixels() {
+        let rgb = pixel.data;
+        pixels.extend_from_slice(&[
+            (rgb[0] as f32 - 123.0) / 58.395, // R
+            (rgb[1] as f32 - 117.0) / 57.12,  // G
+            (rgb[2] as f32 - 104.0) / 57.375, // B
+        ]);
+    }
+
+    let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap();
+    let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn();
+    // make arr shape as [1, 3, 224, 224] acceptable to resnet
+    let arr = arr.insert_axis(Axis(0));
+    // create input tensor from rust's ndarray
+    let input = NDArray::from_rust_ndarray(
+        &arr,
+        TVMContext::cpu(0),
+        DLDataType::from_str("float32").unwrap(),
+    )
+    .unwrap();
+    println!(
+        "input size is {:?}",
+        input.shape().expect("cannot get the input shape")
+    );
+    let graph =
+        fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap();
+    // load the built module
+    let lib = Module::load(&Path::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/deploy_lib.so"
+    )))
+    .unwrap();
+    // get the global TVM graph runtime function
+    let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
+    let runtime_create_fn_ret = call_packed!(
+        runtime_create_fn,
+        graph,
+        &lib,
+        &ctx.device_type,
+        &ctx.device_id
+    )
+    .unwrap();
+    // get graph runtime module
+    let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap();
+    // get the registered `load_params` from runtime module
+    let load_param_fn = &graph_runtime_module
+        .get_function("load_params", false)
+        .unwrap();
+    // parse parameters and convert to ByteArray
+    let params: Vec<u8> =
+        fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap();
+    let barr = ByteArray::from(&params);
+    // load the parameters
+    call_packed!(load_param_fn, &barr).unwrap();
+    // get the set_input function
+    let set_input_fn = &graph_runtime_module
+        .get_function("set_input", false)
+        .unwrap();
+
+    call_packed!(set_input_fn, "data".to_string(), &input).unwrap();
+    // get `run` function from runtime module
+    let run_fn = &graph_runtime_module.get_function("run", false).unwrap();
+    // execute the run function. Note that it has no argument
+    call_packed!(run_fn,).unwrap();
+    // prepare to get the output
+    let output_shape = &mut [1, 1000];
+    let output = NDArray::empty(
+        output_shape,
+        TVMContext::cpu(0),
+        DLDataType::from_str("float32").unwrap(),
+    );
+    // get the `get_output` function from runtime module
+    let get_output_fn = &graph_runtime_module
+        .get_function("get_output", false)
+        .unwrap();
+    // execute the get output function
+    call_packed!(get_output_fn, &0, &output).unwrap();
+    // flatten the output as Vec<f32>
+    let output = output.to_vec::<f32>().unwrap();
+    // Find the (first) maximum entry in the output and its index. Starting
+    // from -inf instead of a -1/0.0 sentinel means a non-positive output
+    // still yields a valid class id.
+    let (argmax, max_prob) = output.iter().copied().enumerate().fold(
+        (0usize, f32::NEG_INFINITY),
+        |best, (i, prob)| if prob > best.1 { (i, prob) } else { best },
+    );
+    let argmax = argmax as i32;
+    // Create a hash map of (class id, class name).
+    // `build_resnet.py` writes `synset.csv` into the same directory as `cat.png`,
+    // as plain `class id, class name` rows with no header row.
+    let file = File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/synset.csv")).unwrap();
+    let mut rdr = csv::ReaderBuilder::new()
+        .has_headers(false)
+        .from_reader(file);
+
+    let synset: HashMap<i32, String> = rdr
+        .records()
+        .map(|result| {
+            let record = result.unwrap();
+            let id: i32 = record[0].parse().unwrap();
+            (id, record[1].to_string())
+        })
+        .collect();
+
+    println!(
+        "input image belongs to the class `{}` with probability {}",
+        synset
+            .get(&argmax)
+            .expect("cannot find the class id for argmax"),
+        max_prob
+    );
+}
diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs
new file mode 100644
index 0000000..bceae5e
--- /dev/null
+++ b/rust/tvm-rt/src/context.rs
@@ -0,0 +1,76 @@
+use tvm_sys::ffi;
+pub use tvm_sys::context::*;
+
+use std::os::raw::c_void;
+use std::ptr;
+
+/// Device queries and operations for a TVM [`Context`].
+///
+/// The attribute accessors are backed by TVM's `runtime.GetDeviceAttr`
+/// global function. The trait is public so that users of the crate can call
+/// these methods on a `Context`.
+pub trait ContextExt {
+    /// Checks whether the context exists or not.
+    fn exist(&self) -> bool;
+    /// Synchronizes the context (see `TVMSynchronize`).
+    fn sync(&self) -> anyhow::Result<()>;
+    /// Maximum number of threads per block (device attribute kind 1).
+    fn max_threads_per_block(&self) -> isize;
+    /// Warp size (device attribute kind 2).
+    fn warp_size(&self) -> isize;
+    /// Maximum shared memory per block (device attribute kind 3).
+    fn max_shared_memory_per_block(&self) -> isize;
+    /// Compute version (device attribute kind 4).
+    fn compute_version(&self) -> isize;
+    /// Device name attribute (device attribute kind 5).
+    ///
+    /// NOTE(review): TVM reports the device name as a string, which this
+    /// integer return type presumably cannot represent — confirm.
+    fn device_name(&self) -> isize;
+    /// Maximum clock rate (device attribute kind 6).
+    fn max_clock_rate(&self) -> isize;
+    /// Number of multiprocessors (device attribute kind 7).
+    fn multi_processor_count(&self) -> isize;
+    /// Maximum thread dimensions (device attribute kind 8).
+    fn max_thread_dimensions(&self) -> isize;
+}
+
+/// Expands to one accessor method per `(name, kind)` pair. Each accessor
+/// queries `runtime.GetDeviceAttr` for this context with the given
+/// attribute `kind` code and returns the result as `isize`.
+///
+/// The original expansion ignored `$attr_kind` and always queried kind `0`
+/// (existence), so every accessor returned the wrong attribute.
+///
+/// # Panics
+///
+/// The generated accessors panic if the attribute query fails.
+macro_rules! impl_device_attrs {
+    ($(($attr_name:ident, $attr_kind:expr));+) => {
+        $(
+            fn $attr_name(&self) -> isize {
+                get_device_attr(self.device_type.0 as i32, self.device_id as i32, $attr_kind)
+                    .expect("should not fail") as isize
+            }
+        )+
+    };
+}
+
+// Binding for TVM's `runtime.GetDeviceAttr` global function. This expands to
+// `pub fn get_device_attr(device_type, device_id, device_kind) -> Result<i32, anyhow::Error>`.
+external_func! {
+    fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32 as "runtime.GetDeviceAttr";
+}
+
+
+impl ContextExt for Context {
+    /// Returns `true` when `runtime.GetDeviceAttr` reports a non-zero value
+    /// for attribute kind 0 (existence) on this device.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the attribute query fails.
+    fn exist(&self) -> bool {
+        let exists = get_device_attr(self.device_type.0 as i32, self.device_id as i32, 0)
+            .expect("should not fail");
+
+        exists != 0
+    }
+
+    /// Synchronize the context stream.
+    ///
+    /// Passes a null stream handle to `TVMSynchronize` (presumably selecting
+    /// the device's default stream — confirm).
+    ///
+    /// # Errors
+    ///
+    /// Returns the TVM runtime's last error message if `TVMSynchronize`
+    /// reports failure, rather than panicking.
+    fn sync(&self) -> anyhow::Result<()> {
+        // SAFETY: FFI call taking two integers and a null stream handle.
+        let status = unsafe {
+            ffi::TVMSynchronize(
+                self.device_type.0 as i32,
+                self.device_id as i32,
+                ptr::null_mut() as *mut c_void,
+            )
+        };
+        if status != 0 {
+            return Err(anyhow::anyhow!("{}", crate::get_last_error()));
+        }
+        Ok(())
+    }
+
+    // Remaining attribute accessors, as `(method name, attribute kind)` pairs.
+    impl_device_attrs!((max_threads_per_block, 1);
+        (warp_size, 2);
+        (max_shared_memory_per_block, 3);
+        (compute_version, 4);
+        (device_name, 5);
+        (max_clock_rate, 6);
+        (multi_processor_count, 7);
+        (max_thread_dimensions, 8));
+}
+
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Synchronizing the CPU context is expected to succeed.
+    #[test]
+    fn sync() {
+        let outcome = Context::cpu(0).sync();
+        assert!(outcome.is_ok())
+    }
+}
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
new file mode 100644
index 0000000..77dbba7
--- /dev/null
+++ b/rust/tvm-rt/src/errors.rs
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+#[error("Cannot convert from an empty array.")]
+pub struct EmptyArrayError;
+
+#[derive(Debug, Error)]
+#[error("Handle `{name}` is null.")]
+pub struct NullHandleError {
+    pub name: String,
+}
+
+#[derive(Debug, Error)]
+#[error("Function was not set in `function::Builder`")]
+pub struct FunctionNotFoundError;
+
+#[derive(Debug, Error)]
+#[error("Expected type `{expected}` but found `{actual}`")]
+pub struct TypeMismatchError {
+    pub expected: String,
+    pub actual: String,
+}
+
+#[derive(Debug, Error)]
+#[error("Missing NDArray shape.")]
+pub struct MissingShapeError;
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
new file mode 100644
index 0000000..739c7a0
--- /dev/null
+++ b/rust/tvm-rt/src/function.rs
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module provides an idiomatic Rust API for creating and working with TVM functions.
+//!
+//! To call an already registered TVM function, look it up with [`Function::get`] and
+//! then call [`Function::invoke`] or convert it with [`Function::to_boxed_fn`].
+//! To register a Rust function as a global TVM packed function, use
+//! [`register`] or [`register_override`].
+//!
+//! See the tests and examples directories for more examples.
+
+use std::{
+    collections::BTreeMap,
+    ffi::{CStr, CString},
+    mem::{self, MaybeUninit},
+    os::raw::{c_char, c_int},
+    ptr, slice, str,
+    sync::Mutex,
+};
+
+use anyhow::{Result};
+use lazy_static::lazy_static;
+
+pub use tvm_sys::{ffi, ArgValue, RetValue};
+
+use super::to_function::{ToFunction, Typed};
+use super::to_boxed_fn::ToBoxedFn;
+
+lazy_static! {
+    static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<String, Option<Function>>> = {
+        let mut out_size = 0 as c_int;
+        let mut names_ptr = ptr::null_mut() as *mut *const c_char;
+        check_call!(ffi::TVMFuncListGlobalNames(
+            &mut out_size as *mut _,
+            &mut names_ptr as *mut _,
+        ));
+        let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as 
usize) };
+
+        let names_list: Vec<String> =
+            names_list
+            .iter()
+            .map(|&p| unsafe { CStr::from_ptr(p).to_str().unwrap().into() })
+            .collect();
+
+        // println!("{:?}", &names_list);
+
+        let names_list = names_list
+            .into_iter()
+            .map(|p| (p, None))
+            .collect();
+
+        Mutex::new(names_list)
+    };
+}
+
+/// Wrapper around TVM function handle which includes `is_global`
+/// indicating whether the function is global or not, and `is_cloned` showing
+/// not to drop a cloned function from Rust side.
+/// The value of these fields can be accessed through their respective methods.
+#[derive(Debug, Hash)]
+pub struct Function {
+    pub(crate) handle: ffi::TVMFunctionHandle,
+    // whether the registered function is global or not.
+    is_global: bool,
+    // whether the function has been cloned from frontend or not.
+    is_cloned: bool,
+}
+
+unsafe impl Send for Function {}
+unsafe impl Sync for Function {}
+
+impl Function {
+    pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
+        Function {
+            handle,
+            is_global: false,
+            is_cloned: false,
+        }
+    }
+
+    /// For a given function, it returns a function by name.
+    pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> {
+        let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
+        globals.get_mut(name.as_ref()).and_then(|maybe_func| {
+            if maybe_func.is_none() {
+                let name = CString::new(name.as_ref()).unwrap();
+                let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
+                check_call!(ffi::TVMFuncGetGlobal(
+                    name.as_ptr() as *const c_char,
+                    &mut handle as *mut _
+                ));
+                maybe_func.replace(Function {
+                    handle,
+                    is_global: true,
+                    is_cloned: false,
+                });
+            }
+
+            unsafe {
+                mem::transmute::<Option<&Function>, Option<&'static 
Function>>(maybe_func.as_ref())
+            }
+        })
+    }
+
+    /// Returns the underlying TVM function handle.
+    pub fn handle(&self) -> ffi::TVMFunctionHandle {
+        self.handle
+    }
+
+    /// Returns `true` if the underlying TVM function is global and `false` 
otherwise.
+    pub fn is_global(&self) -> bool {
+        self.is_global
+    }
+
+    /// Returns `true` if the underlying TVM function has been cloned
+    /// from the frontend and `false` otherwise.
+    pub fn is_cloned(&self) -> bool {
+        self.is_cloned
+    }
+
+    /// Calls the function that created from `Builder`.
+    pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> {
+        let num_args = arg_buf.len();
+        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, 
Vec<ffi::TVMTypeCode>) =
+            arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
+
+        let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() };
+        let mut ret_type_code = 0i32;
+        check_call!(ffi::TVMFuncCall(
+            self.handle,
+            values.as_mut_ptr(),
+            type_codes.as_mut_ptr() as *mut i32,
+            num_args as c_int,
+            &mut ret_val as *mut _,
+            &mut ret_type_code as *mut _
+        ));
+
+        Ok(RetValue::from_tvm_value(ret_val, ret_type_code as u32))
+    }
+
+    pub fn to_boxed_fn<F: ?Sized>(&'static self) -> Box<F> where F: ToBoxedFn {
+        F::to_boxed_fn(self)
+    }
+}
+
+impl Clone for Function {
+    fn clone(&self) -> Function {
+        Self {
+            handle: self.handle,
+            is_global: self.is_global,
+            is_cloned: true,
+        }
+    }
+}
+
+impl Drop for Function {
+    fn drop(&mut self) {
+        if !self.is_global && !self.is_cloned {
+            check_call!(ffi::TVMFuncFree(self.handle));
+        }
+    }
+}
+
+/// Registers a Rust function with signature
+/// `fn(&[ArgValue]) -> Result<RetValue, Error>`
+/// as a **global TVM packed function** from frontend to TVM backend.
+///
+/// Use [`register_global_func`] if overriding an existing global TVM function
+/// is not required.
+///
+/// ## Example
+///
+/// ```
+/// # use tvm_rt::{ArgValue, function, RetValue};
+/// # use tvm_rt::function::Builder;
+/// # use anyhow::Error;
+/// use std::convert::TryInto;
+///
+/// fn sum(args: &[ArgValue]) -> Result<RetValue, Error> {
+///     let mut ret = 0i64;
+///     for arg in args.iter() {
+///         let arg: i64 = arg.try_into()?;
+///         ret += arg;
+///     }
+///     let ret_val = RetValue::from(ret);
+///     Ok(ret_val)
+/// }
+///
+/// function::register(sum, "mysum".to_owned()).unwrap();
+/// let mut registered = Builder::default();
+/// registered.get_function("mysum");
+/// assert!(registered.func.is_some());
+/// let ret: i64 = registered.args(&[10, 20, 
30]).invoke().unwrap().try_into().unwrap();
+/// assert_eq!(ret, 60);
+/// ```
+pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()>
+where
+    F: ToFunction<I, O>,
+    F: Typed<I, O>,
+{
+    register_override(f, name, false)
+}
+
+/// Registers a Rust function with signature
+/// `fn(&[ArgValue]) -> Result<RetValue, Error>`
+/// as a **global TVM packed function** from frontend to TVM backend.
+///
+/// Use [`register_global_func`] if overriding an existing global TVM function
+/// is not required.
+///
+/// ## Example
+///
+/// ```
+/// # use tvm_rt::{ArgValue, function, RetValue};
+/// # use tvm_rt::function::Builder;
+/// # use anyhow::Error;
+/// use std::convert::TryInto;
+///
+/// fn sum(args: &[ArgValue]) -> Result<RetValue, Error> {
+///     let mut ret = 0i64;
+///     for arg in args.iter() {
+///         let arg: i64 = arg.try_into()?;
+///         ret += arg;
+///     }
+///     let ret_val = RetValue::from(ret);
+///     Ok(ret_val)
+/// }
+///
+/// function::register_override(sum, "mysum".to_owned(), false).unwrap();
+/// let mut registered = Builder::default();
+/// registered.get_function("mysum");
+/// assert!(registered.func.is_some());
+/// let ret: i64 = registered.args(&[10, 20, 
30]).invoke().unwrap().try_into().unwrap();
+/// assert_eq!(ret, 60);
+/// ```
+pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: 
bool) -> Result<()>
+where
+    F: ToFunction<I, O>,
+    F: Typed<I, O>,
+{
+    let func = f.to_function();
+    let name = name.into();
+    let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
+    // Not sure about this code
+    let handle = func.handle();
+    globals.insert(name.clone(), Some(func));
+    let name= CString::new(name)?;
+    check_call!(ffi::TVMFuncRegisterGlobal(
+        name.into_raw(),
+        handle,
+        override_ as c_int
+    ));
+
+    Ok(())
+}
+
+/// Declares a typed Rust wrapper around a global TVM packed function.
+///
+/// `external_func! { fn name(arg: Ty, ...) -> Ret as "registry.Name"; }`
+/// expands to two items:
+/// - a lazily initialised static holding the [`Function`] looked up under
+///   `"registry.Name"`;
+/// - `pub fn name(arg: Ty, ...) -> Result<Ret, anyhow::Error>`, which converts
+///   that function with [`Function::to_boxed_fn`] and calls it.
+///
+/// Paths are fully qualified so that the expansion does not depend on which
+/// `Result` or `anyhow` names are in scope at the call site.
+///
+/// # Panics
+///
+/// The generated function panics on first use if `"registry.Name"` is not
+/// registered with TVM.
+#[macro_export]
+macro_rules! external_func {
+    (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => {
+        ::paste::item! {
+            #[allow(non_upper_case_globals)]
+            static [<global_ $name>]: ::once_cell::sync::Lazy<&'static $crate::Function> =
+            ::once_cell::sync::Lazy::new(|| {
+                $crate::Function::get($ext_name)
+                .expect(concat!("unable to load external function ", stringify!($ext_name), " from TVM registry."))
+            });
+        }
+
+        pub fn $name($($arg : $ty),*) -> ::std::result::Result<$ret_type, ::anyhow::Error> {
+            let func_ref: &$crate::Function = ::paste::expr! { &*[<global_ $name>] };
+            let func_ref: ::std::boxed::Box<dyn Fn($($ty),*) -> ::anyhow::Result<$ret_type>> =
+                func_ref.to_boxed_fn();
+            let res: $ret_type = func_ref($($arg),*)?;
+            Ok(res)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::function::{Function};
+
+    static CANARY: &str = "runtime.ModuleLoadFromFile";
+
+    // #[test]
+    // fn list_global_func() {
+    //     assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
+    // }
+
+    #[test]
+    fn get_fn() {
+        assert!(Function::get(CANARY).is_some());
+        assert!(Function::get("does not exists!").is_none());
+    }
+
+    #[test]
+    fn register_and_call_closure0() {
+        use crate::function;
+
+        fn constfn() -> i64 {
+            return 10;
+        }
+
+        function::register_override(constfn, "constfn".to_owned(), 
true).unwrap();
+        let func = Function::get("constfn").unwrap();
+        let func = func.to_boxed_fn::<dyn Fn() -> Result<i32>>();
+        let ret = func().unwrap();
+        assert_eq!(ret, 10);
+    }
+
+    // #[test]
+    // fn register_and_call_closure1() {
+    //     use crate::function::{self};
+
+    //     fn ident(x: i64) -> i64 {
+    //         return x;
+    //     }
+
+    //     function::register_override(ident, "ident".to_owned(), 
false).unwrap();
+    //     let func = Function::get("ident").unwrap();
+    //     let func = func.to_boxed_fn::<dyn Fn(i32) -> Result<i32>>();
+    //     assert_eq!(func(60).unwrap(), 60);
+    // }
+}
diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs
new file mode 100644
index 0000000..e9ae02f
--- /dev/null
+++ b/rust/tvm-rt/src/lib.rs
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for 
deep learning systems.
+//!
+//! This crate provides an idiomatic Rust API for TVM runtime frontend.
+//!
+//! One particular use case is that given optimized deep learning model 
artifacts,
+//! (compiled with TVM) which include a shared library
+//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
+//! in Rust idomatically to create a TVM Graph Runtime and
+//! run the model for some inputs and get the
+//! desired predictions *all in Rust*.
+//!
+//! Checkout the `examples` repository for more details.
+
+extern crate ndarray as rust_ndarray;
+
+pub use crate as tvm_rt;
+
+pub mod object;
+pub mod string;
+
+pub use object::*;
+pub use string::*;
+
+use std::{
+    ffi::{CStr, CString},
+    str,
+};
+
+use anyhow::Error;
+
+pub use crate::{
+    context::{Context, TVMDeviceType},
+    errors::*,
+    function::Function,
+    module::Module,
+    ndarray::NDArray,
+};
+
+pub use function::{ArgValue, RetValue};
+pub use tvm_sys::byte_array::ByteArray;
+pub use tvm_sys::datatype::DataType;
+
+use tvm_sys::ffi;
+
+/// Evaluates a TVM runtime C API call (inside `unsafe`) and panics with the
+/// runtime's last error message if the returned status code is non-zero.
+#[macro_export]
+macro_rules! check_call {
+    ($e:expr) => {{
+        let status = unsafe { $e };
+        if status != 0 {
+            panic!("{}", $crate::get_last_error());
+        }
+    }};
+}
+
+/// Gets the last error message.
+pub fn get_last_error() -> &'static str {
+    unsafe {
+        match CStr::from_ptr(ffi::TVMGetLastError()).to_str() {
+            Ok(s) => s,
+            Err(_) => "Invalid UTF-8 message",
+        }
+    }
+}
+
+pub(crate) fn set_last_error(err: &Error) {
+    let c_string = CString::new(err.to_string()).unwrap();
+    unsafe {
+        ffi::TVMAPISetLastError(c_string.as_ptr());
+    }
+}
+
+#[macro_use]
+pub mod function;
+pub mod context;
+pub mod errors;
+pub mod module;
+pub mod ndarray;
+pub mod to_function;
+pub mod to_boxed_fn;
+pub mod value;
+
+/// Returns the version string of the TVM runtime this crate was built against.
+///
+/// The bindgen-generated `TVM_VERSION` byte constant may carry a trailing NUL
+/// terminator, so everything from the first NUL byte onward is dropped. If
+/// the bytes are not valid UTF-8, `"Invalid UTF-8 string"` is returned.
+pub fn version() -> &'static str {
+    let bytes: &'static [u8] = ffi::TVM_VERSION;
+    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
+    str::from_utf8(&bytes[..end]).unwrap_or("Invalid UTF-8 string")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Printing the version must not panic.
+    #[test]
+    fn print_version() {
+        let v = version();
+        println!("TVM version: {}", v);
+    }
+
+    /// An error set from Rust is reported back by `get_last_error`.
+    #[test]
+    fn set_error() {
+        let expected = errors::EmptyArrayError.to_string();
+        let err: Error = errors::EmptyArrayError.into();
+        set_last_error(&err);
+        assert_eq!(get_last_error().trim(), expected);
+    }
+}
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
new file mode 100644
index 0000000..f9b49d9
--- /dev/null
+++ b/rust/tvm-rt/src/module.rs
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! Provides the [`Module`] type and methods for working with runtime TVM modules.
+
+use std::{
+    ffi::CString,
+    os::raw::{c_char, c_int},
+    path::Path,
+    ptr,
+};
+
+use anyhow::{anyhow, ensure, Error};
+use tvm_sys::ffi;
+
+use crate::{errors, function::Function};
+
+const ENTRY_FUNC: &str = "__tvm_main__";
+
+/// Wrapper around TVM module handle which contains an entry function.
+/// The entry function can be applied to an imported module through 
[`entry_func`].
+///
+/// [`entry_func`]:struct.Module.html#method.entry_func
+#[derive(Debug, Clone)]
+pub struct Module {
+    pub(crate) handle: ffi::TVMModuleHandle,
+    entry_func: Option<Function>,
+}
+
+
+// Binding for the registered global "runtime.RuntimeEnabled"; returns a
+// non-zero value when the named target is enabled.
+external_func! {
+    fn runtime_enabled(target: CString) -> i32 as "runtime.RuntimeEnabled";
+}
+
+// Binding for the registered global "runtime.ModuleLoadFromFile"; takes the
+// file path and the format string (the file extension, see `Module::load`).
+external_func! {
+    fn load_from_file(file_name: CString, format: CString) -> Module as "runtime.ModuleLoadFromFile";
+}
+
+
+impl Module {
+    /// Wraps a raw module handle. The entry function is looked up lazily by
+    /// [`Module::entry`].
+    pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
+        Self {
+            handle,
+            entry_func: None,
+        }
+    }
+
+    /// Returns the module's entry function (`__tvm_main__`).
+    ///
+    /// A successful lookup is cached. A failed lookup yields `None` and is
+    /// retried on the next call.
+    pub fn entry(&mut self) -> Option<&Function> {
+        if self.entry_func.is_none() {
+            self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
+        }
+        self.entry_func.as_ref()
+    }
+
+    /// Gets a function by name from a registered module.
+    ///
+    /// `query_import` is forwarded to `TVMModGetFunction` (presumably it also
+    /// searches imported modules).
+    ///
+    /// # Errors
+    /// Returns an error if `name` contains a NUL byte or no such function exists.
+    ///
+    /// # Panics
+    /// Panics if the runtime reports an error during the lookup.
+    pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
+        let c_name = CString::new(name)?;
+        let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+        check_call!(ffi::TVMModGetFunction(
+            self.handle,
+            c_name.as_ptr() as *const c_char,
+            query_import as c_int,
+            &mut fhandle as *mut _
+        ));
+        ensure!(
+            !fhandle.is_null(),
+            errors::NullHandleError {
+                name: name.to_string()
+            }
+        );
+        Ok(Function::new(fhandle))
+    }
+
+    /// Imports a dependent module such as `.ptx` for gpu.
+    ///
+    /// `dependent_module` is consumed, and its handle is freed when it drops at
+    /// the end of this call. NOTE(review): assumes `TVMModImport` keeps its own
+    /// reference to the dependency.
+    pub fn import_module(&self, dependent_module: Module) {
+        check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
+    }
+
+    /// Loads a module shared library from path.
+    ///
+    /// The file extension (empty if there is none) is passed as the format.
+    ///
+    /// # Errors
+    /// Returns an error if the path is not valid UTF-8, contains a NUL byte, or
+    /// the runtime fails to load the module.
+    pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
+        let path = path.as_ref();
+        let bad_path = || anyhow!("Bad module load path: `{}`.", path.display());
+        let ext = CString::new(
+            path.extension()
+                .unwrap_or_else(|| std::ffi::OsStr::new(""))
+                .to_str()
+                .ok_or_else(bad_path)?,
+        )?;
+        let cpath = CString::new(path.to_str().ok_or_else(bad_path)?)?;
+        Ok(load_from_file(cpath, ext)?)
+    }
+
+    /// Checks if a target device is enabled for a module.
+    ///
+    /// NOTE(review): `self` is unused; this forwards to the global
+    /// `runtime.RuntimeEnabled` function.
+    ///
+    /// # Panics
+    /// Panics if `target` contains a NUL byte or the runtime call fails.
+    pub fn enabled(&self, target: &str) -> bool {
+        let target = CString::new(target).unwrap();
+        let enabled = runtime_enabled(target).unwrap();
+        enabled != 0
+    }
+
+    /// Returns the underlying module handle.
+    pub fn handle(&self) -> ffi::TVMModuleHandle {
+        self.handle
+    }
+}
+
+impl Drop for Module {
+    /// Releases the module handle with `TVMModFree`.
+    ///
+    /// # Panics
+    /// Panics if the runtime reports an error while freeing the handle.
+    fn drop(&mut self) {
+        check_call!(ffi::TVMModFree(self.handle));
+    }
+}
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
new file mode 100644
index 0000000..4653117
--- /dev/null
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -0,0 +1,431 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module implements the [`NDArray`] type for working with *TVM tensors* or
+//! converting from a Rust `ndarray` to a TVM `NDArray`.
+//!
+//! One can create an empty NDArray given the shape, context and dtype using [`empty`].
+//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`].
+//! To copy an NDArray to a different context use [`copy_to_ctx`].
+//!
+//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows:
+//!
+//! # Example
+//!
+//! ```
+//! # use tvm_rt::{NDArray, Context, DataType};
+//! # use ndarray::{Array, ArrayD};
+//! # use std::str::FromStr;
+//! use std::convert::TryFrom;
+//!
+//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+//!     .unwrap()
+//!     .into_dyn(); // Rust's ndarray
+//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap();
+//! assert_eq!(nd.shape(), Some(&mut [2, 2][..]));
+//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+//! assert!(rnd.all_close(&a, 1e-8f32));
+//! ```
+//!
+//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/
+//! [`empty`]:struct.NDArray.html#method.empty
+//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
+//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
+
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
+
+use crate::errors;
+use anyhow::{bail, ensure, Result};
+use num_traits::Num;
+use rust_ndarray::{Array, ArrayD};
+use std::convert::TryInto;
+use std::ffi::c_void;
+use tvm_sys::ffi::DLTensor;
+use tvm_sys::{ffi, ByteArray, Context, DataType};
+
+/// See the [`module-level documentation`](../ndarray/index.html) for more details.
+///
+/// Wrapper around TVM array handle.
+#[derive(Debug)]
+pub enum NDArray {
+    /// A view over a `DLTensor` that this value does not own; it is not freed on drop.
+    Borrowed { handle: ffi::TVMArrayHandle },
+    /// An array handle owned by this value; it is released with `TVMArrayFree` on drop.
+    Owned { handle: *mut c_void },
+}
+
+impl NDArray {
+    pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
+        NDArray::Borrowed { handle }
+    }
+
+    pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
+        NDArray::Owned { handle }
+    }
+
+    pub fn as_dltensor(&self) -> &DLTensor {
+        unsafe {
+            match self {
+                NDArray::Borrowed { ref handle } => 
std::mem::transmute(*handle),
+                NDArray::Owned { ref handle } => std::mem::transmute(*handle),
+            }
+        }
+    }
+
+    pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor {
+        unsafe {
+            match self {
+                NDArray::Borrowed { ref handle } => 
std::mem::transmute(*handle),
+                NDArray::Owned { ref handle } => std::mem::transmute(*handle),
+            }
+        }
+    }
+
+    pub fn is_view(&self) -> bool {
+        if let &NDArray::Borrowed { .. } = self {
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Returns the shape of the NDArray.
+    pub fn shape(&self) -> Option<&mut [usize]> {
+        let arr = self.as_dltensor();
+        if arr.shape.is_null() || arr.data.is_null() {
+            return None;
+        };
+        let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, 
arr.ndim as usize) };
+        Some(slc)
+    }
+
+    /// Returns the total number of entries of the NDArray.
+    pub fn size(&self) -> Option<usize> {
+        self.shape().map(|v| v.iter().product())
+    }
+
+    /// Returns the context which the NDArray was defined.
+    pub fn ctx(&self) -> Context {
+        self.as_dltensor().ctx.into()
+    }
+
+    /// Returns the type of the entries of the NDArray.
+    pub fn dtype(&self) -> DataType {
+        self.as_dltensor().dtype.into()
+    }
+
+    /// Returns the number of dimensions of the NDArray.
+    pub fn ndim(&self) -> usize {
+        self.as_dltensor()
+            .ndim
+            .try_into()
+            .expect("number of dimensions must always be positive")
+    }
+
+    /// Returns the strides of the underlying NDArray.
+    pub fn strides(&self) -> Option<&[usize]> {
+        unsafe {
+            let sz = self.ndim() * mem::size_of::<usize>();
+            let strides_ptr = self.as_dltensor().strides as *const usize;
+            let slc = slice::from_raw_parts(strides_ptr, sz);
+            Some(slc)
+        }
+    }
+
+    /// Shows whether the underlying ndarray is contiguous in memory or not.
+    pub fn is_contiguous(&self) -> Result<bool> {
+        Ok(match self.strides() {
+            None => true,
+            Some(strides) => {
+                // errors::MissingShapeError in case shape is not determined
+                self.shape()
+                    .ok_or(errors::MissingShapeError)?
+                    .iter()
+                    .zip(strides)
+                    .rfold(
+                        (true, 1),
+                        |(is_contig, expected_stride), (shape, stride)| {
+                            (
+                                is_contig && *stride == expected_stride,
+                                expected_stride * (*shape as usize),
+                            )
+                        },
+                    )
+                    .0
+            }
+        })
+    }
+
+    pub fn byte_offset(&self) -> isize {
+        self.as_dltensor().byte_offset as isize
+    }
+
+    /// Flattens the NDArray to a `Vec` of the same type in cpu.
+    ///
+    /// ## Example
+    ///
+    /// ```
+    /// # use tvm_rt::{Context, DataType, NDArray};
+    /// # use std::str::FromStr;
+    /// let mut shape = [4];
+    /// let mut data = vec![1i32, 2, 3, 4];
+    /// let ctx = Context::cpu(0);
+    /// let mut ndarray = NDArray::empty(&mut shape, ctx, 
DataType::from_str("int32").unwrap());
+    /// ndarray.copy_from_buffer(&mut data);
+    /// assert_eq!(ndarray.shape(), Some(&mut shape[..]));
+    /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+    /// ```
+    pub fn to_vec<T>(&self) -> Result<Vec<T>> {
+        ensure!(self.shape().is_some(), errors::EmptyArrayError);
+        let earr = NDArray::empty(
+            self.shape().ok_or(errors::MissingShapeError)?,
+            Context::cpu(0),
+            self.dtype(),
+        );
+        let target = self.copy_to_ndarray(earr)?;
+        let arr = target.as_dltensor();
+        let sz = self.size().ok_or(errors::MissingShapeError)?;
+        let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
+        unsafe {
+            v.as_mut_ptr()
+                .copy_from_nonoverlapping(arr.data as *const T, sz);
+            v.set_len(sz);
+        }
+        Ok(v)
+    }
+
+    /// Converts the NDArray to [`ByteArray`].
+    pub fn to_bytearray(&self) -> Result<ByteArray> {
+        let v = self.to_vec::<u8>()?;
+        Ok(ByteArray::from(v))
+    }
+
+    /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in 
cpu.
+    ///
+    /// ## Example
+    ///
+    /// ```
+    /// # use tvm_rt::{Context, DataType, NDArray};
+    /// # use std::str::FromStr;
+    /// let shape = &mut [2];
+    /// let mut data = vec![1f32, 2.0];
+    /// let ctx = Context::cpu(0);
+    /// let mut ndarray = NDArray::empty(shape, ctx, 
DataType::from_str("int32").unwrap());
+    /// ndarray.copy_from_buffer(&mut data);
+    /// ```
+    ///
+    /// *Note*: if something goes wrong during the copy, it will panic
+    /// from TVM side. See `TVMArrayCopyFromBytes` in 
`include/tvm/runtime/c_runtime_api.h`.
+    pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
+        check_call!(ffi::TVMArrayCopyFromBytes(
+            self.as_raw_dltensor(),
+            data.as_ptr() as *mut _,
+            data.len() * mem::size_of::<T>()
+        ));
+    }
+
+    /// Copies the NDArray to another target NDArray.
+    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
+        if self.dtype() != target.dtype() {
+            bail!(
+                "{}",
+                errors::TypeMismatchError {
+                    expected: self.dtype().to_string(),
+                    actual: target.dtype().to_string(),
+                }
+            );
+        }
+        check_call!(ffi::TVMArrayCopyFromTo(
+            self.as_raw_dltensor(),
+            target.as_raw_dltensor(),
+            ptr::null_mut() as ffi::TVMStreamHandle
+        ));
+        Ok(target)
+    }
+
+    /// Copies the NDArray to a target context.
+    pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray> {
+        let tmp = NDArray::empty(
+            self.shape().ok_or(errors::MissingShapeError)?,
+            *target,
+            self.dtype(),
+        );
+        let copy = self.copy_to_ndarray(tmp)?;
+        Ok(copy)
+    }
+
+    /// Converts a Rust's ndarray to TVM NDArray.
+    pub fn from_rust_ndarray<T: Num32 + Copy>(
+        rnd: &ArrayD<T>,
+        ctx: Context,
+        dtype: DataType,
+    ) -> Result<Self> {
+        let shape = rnd.shape().to_vec();
+        let mut nd = NDArray::empty(&shape, ctx, dtype);
+        let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
+        nd.copy_from_buffer(
+            buf.as_slice_mut()
+                .expect("Array from iter must be contiguous."),
+        );
+        Ok(nd)
+    }
+
+    /// Allocates and creates an empty NDArray given the shape, context and 
dtype.
+    pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray {
+        let mut handle = ptr::null_mut() as ffi::TVMArrayHandle;
+        check_call!(ffi::TVMArrayAlloc(
+            shape.as_ptr() as *const i64,
+            shape.len() as c_int,
+            i32::from(dtype.code) as c_int,
+            i32::from(dtype.bits) as c_int,
+            i32::from(dtype.lanes) as c_int,
+            ctx.device_type.0 as c_int,
+            ctx.device_id as c_int,
+            &mut handle as *mut _,
+        ));
+        NDArray::Borrowed { handle: handle }
+    }
+}
+
+/// Implements `TryFrom<&NDArray>` and `TryFrom<&mut NDArray>` for
+/// `ArrayD<$type>`, where `$type_name` is the dtype string `$type` corresponds to.
+///
+/// The conversion copies the data through the CPU. It fails with
+/// `MissingShapeError` when the array has no shape, and with a type-mismatch
+/// error when the dtype differs from `$type_name`.
+macro_rules! impl_from_ndarray_rustndarray {
+    ($type:ty, $type_name:tt) => {
+        impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
+            type Error = anyhow::Error;
+            fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
+                let shape = nd.shape().ok_or(errors::MissingShapeError)?;
+                let expected = DataType::from_str($type_name)?;
+                // Report a mismatch as an error rather than panicking inside a
+                // fallible conversion.
+                if nd.dtype() != expected {
+                    bail!(
+                        "{}",
+                        errors::TypeMismatchError {
+                            expected: expected.to_string(),
+                            actual: nd.dtype().to_string(),
+                        }
+                    );
+                }
+                Ok(Array::from_shape_vec(&*shape, nd.to_vec::<$type>()?)?)
+            }
+        }
+
+        impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
+            type Error = anyhow::Error;
+            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
+                ArrayD::<$type>::try_from(&*nd)
+            }
+        }
+    };
+}
+
+impl_from_ndarray_rustndarray!(i32, "int");
+impl_from_ndarray_rustndarray!(u32, "uint");
+impl_from_ndarray_rustndarray!(f32, "float");
+
+impl Drop for NDArray {
+    /// Frees owned arrays with `TVMArrayFree` and leaves borrowed views untouched.
+    fn drop(&mut self) {
+        match *self {
+            NDArray::Owned { .. } => check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())),
+            NDArray::Borrowed { .. } => {}
+        }
+    }
+}
+
+mod sealed {
+    /// Private supertrait that prevents [`super::Num32`] from being implemented
+    /// in downstream crates.
+    pub trait Sealed {}
+}
+
+/// A trait for the supported 32-bit numerical types in frontend.
+///
+/// It is sealed, so only the types listed in `impl_num32!` below implement it.
+pub trait Num32: Num + sealed::Sealed {
+    /// Width of the type in bits.
+    const BITS: u8 = 32;
+}
+
+// Implements both the sealing trait and `Num32` for each listed type.
+macro_rules! impl_num32 {
+    ($($type:ty),+) => {
+        $(
+            impl sealed::Sealed for $type {}
+            impl Num32 for $type {}
+        )+
+    };
+}
+
+impl_num32!(i32, u32, f32);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn basics() {
+        let shape = &mut [1, 2, 3];
+        let ctx = Context::cpu(0);
+        let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
+        assert_eq!(ndarray.shape().unwrap(), shape);
+        assert_eq!(
+            ndarray.size().unwrap(),
+            shape.to_vec().into_iter().product()
+        );
+        assert_eq!(ndarray.ndim(), 3);
+        assert!(ndarray.strides().is_none());
+        assert_eq!(ndarray.byte_offset(), 0);
+    }
+
+    #[test]
+    fn copy() {
+        let shape = &mut [4];
+        let mut data = vec![1i32, 2, 3, 4];
+        let ctx = Context::cpu(0);
+        let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
+        assert!(ndarray.to_vec::<i32>().is_ok());
+        ndarray.copy_from_buffer(&mut data);
+        assert_eq!(ndarray.shape().unwrap(), shape);
+        assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+        assert_eq!(ndarray.ndim(), 1);
+        assert!(ndarray.is_contiguous().is_ok());
+        assert_eq!(ndarray.byte_offset(), 0);
+        let shape = vec![4];
+        let e = NDArray::empty(
+            &shape,
+            Context::cpu(0),
+            DataType::from_str("int32").unwrap(),
+        );
+        let nd = ndarray.copy_to_ndarray(e);
+        assert!(nd.is_ok());
+        assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
+    }
+
+    /// Copying between arrays of different dtypes must fail.
+    #[test]
+    #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
+    fn copy_wrong_dtype() {
+        let shape = vec![4];
+        let mut data = vec![1f32, 2., 3., 4.];
+        let ctx = Context::cpu(0);
+        let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
+        nd_float.copy_from_buffer(&mut data);
+        let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
+        nd_float.copy_to_ndarray(empty_int).unwrap();
+    }
+
+    #[test]
+    fn rust_ndarray() {
+        let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+            .unwrap()
+            .into_dyn();
+        let nd =
+            NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap())
+                .unwrap();
+        assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
+        let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+        assert!(rnd.all_close(&a, 1e-8f32));
+    }
+}
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
new file mode 100644
index 0000000..8d8efdf
--- /dev/null
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -0,0 +1,99 @@
+use std::convert::TryFrom;
+use std::convert::TryInto;
+use std::ffi::CString;
+use tvm_sys::{ArgValue, RetValue};
+use crate::external_func;
+
+mod object_ptr;
+
+pub use object_ptr::{IsObject, Object, ObjectPtr};
+
+/// A nullable reference to a TVM runtime object.
+#[derive(Clone)]
+pub struct ObjectRef(pub Option<ObjectPtr<Object>>);
+
+impl ObjectRef {
+    /// Returns a reference that points at no object.
+    pub fn null() -> ObjectRef {
+        Self(None)
+    }
+}
+
+/// Conversion of a value into an [`ObjectRef`].
+pub trait ToObjectRef {
+    fn to_object_ref(&self) -> ObjectRef;
+}
+
+impl ToObjectRef for ObjectRef {
+    /// Returns a clone of this reference.
+    fn to_object_ref(&self) -> ObjectRef {
+        ObjectRef::clone(self)
+    }
+}
+
+// impl<T: ToObjectRef> ToObjectRef for &T {
+//     fn to_object_ref(&self) -> ObjectRef {
+//         (*self).to_object_ref()
+//     }
+// }
+
+impl TryFrom<RetValue> for ObjectRef {
+    type Error = anyhow::Error;
+
+    /// Converts a returned value into an `ObjectRef`.
+    ///
+    /// A null object handle maps to [`ObjectRef::null`]. This mirrors
+    /// `From<ObjectRef> for RetValue`, which encodes a null reference as a null
+    /// handle, so null references round-trip. Other values go through
+    /// `ObjectPtr`'s conversion.
+    fn try_from(ret_val: RetValue) -> Result<ObjectRef, Self::Error> {
+        if let RetValue::ObjectHandle(handle) = &ret_val {
+            if handle.is_null() {
+                return Ok(ObjectRef::null());
+            }
+        }
+        let optr = ret_val.try_into()?;
+        Ok(ObjectRef(Some(optr)))
+    }
+}
+
+impl From<ObjectRef> for RetValue {
+    /// Converts to an object handle; a null reference becomes a null handle.
+    ///
+    /// The inner pointer is cloned (which retains the object) before its raw
+    /// handle is passed on.
+    fn from(object_ref: ObjectRef) -> RetValue {
+        match &object_ref.0 {
+            None => RetValue::ObjectHandle(std::ptr::null_mut()),
+            Some(value) => value.clone().into(),
+        }
+    }
+}
+
+impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef {
+    type Error = anyhow::Error;
+
+    /// Converts an argument into an `ObjectRef`.
+    ///
+    /// A null object handle maps to [`ObjectRef::null`], matching how
+    /// `From<ObjectRef> for ArgValue` encodes a null reference.
+    fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
+        if let ArgValue::ObjectHandle(handle) = &arg_value {
+            if handle.is_null() {
+                return Ok(ObjectRef::null());
+            }
+        }
+        let optr = arg_value.try_into()?;
+        Ok(ObjectRef(Some(optr)))
+    }
+}
+
+impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef {
+    type Error = anyhow::Error;
+
+    /// Converts a borrowed argument by cloning it and using the owned conversion.
+    fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
+        // TODO(@jroesch): remove the clone
+        let value: ArgValue<'a> = arg_value.clone();
+        ObjectRef::try_from(value)
+    }
+}
+
+impl<'a> From<ObjectRef> for ArgValue<'a> {
+    /// Converts to an object handle; a null reference becomes a null handle.
+    ///
+    /// The inner pointer is cloned (which retains the object) before its raw
+    /// handle is passed on.
+    fn from(object_ref: ObjectRef) -> ArgValue<'a> {
+        match &object_ref.0 {
+            None => ArgValue::ObjectHandle(std::ptr::null_mut()),
+            Some(value) => value.clone().into(),
+        }
+    }
+}
+
+impl<'a> From<&ObjectRef> for ArgValue<'a> {
+    /// Clones the reference and converts the clone.
+    fn from(object_ref: &ObjectRef) -> ArgValue<'a> {
+        let oref: ObjectRef = object_ref.clone();
+        ArgValue::<'a>::from(oref)
+    }
+}
+
+// Binding for the registered global "ir.DebugPrinter" (resolved through
+// `external_func!`); returns the printed object as a C string.
+external_func! {
+    fn debug_print(object: ObjectRef) -> CString as "ir.DebugPrinter";
+}
+
+// Binding for the registered global "ir.TextPrinter"; returns the object's
+// text form as a C string.
+external_func! {
+    fn as_text(object: ObjectRef) -> CString as "ir.TextPrinter";
+}
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
new file mode 100644
index 0000000..c716c05
--- /dev/null
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -0,0 +1,283 @@
+use anyhow::Context;
+use std::convert::TryFrom;
+use std::ffi::CString;
+use std::ptr::NonNull;
+use tvm_sys::ffi::{self, /* TVMObjectFree, */ TVMObjectRetain, 
TVMObjectTypeKey2Index};
+use tvm_sys::{ArgValue, RetValue};
+
+// Signature of the deleter stored in each object header; it receives a
+// pointer to the object being destroyed.
+type Deleter<T> = unsafe extern "C" fn(object: *mut T) -> ();
+
+/// Header at the start of every runtime object.
+///
+/// NOTE(review): `#[repr(C)]` with this exact field order presumably mirrors
+/// the C++ runtime's object header (type index, atomic ref count, deleter).
+/// Confirm before changing any field.
+#[derive(Debug)]
+#[repr(C)]
+pub struct Object {
+    /// Runtime type index of the concrete object type.
+    pub type_index: u32,
+    /// Reference count (see the note in `Object::new` about C++ atomics).
+    pub ref_count: i32,
+    /// Function invoked to destroy the object.
+    pub fdeleter: Deleter<Object>,
+}
+
+/// Type-erased deleter stored in an object header. It recovers the concrete
+/// type `T` and forwards to [`IsObject::typed_delete`].
+unsafe extern "C" fn delete<T: IsObject>(object: *mut Object) {
+    // A plain pointer cast is the idiomatic way to reinterpret a raw pointer;
+    // `transmute` is unnecessary here.
+    let typed_object = object as *mut T;
+    T::typed_delete(typed_object);
+}
+
+/// Returns whether runtime type `child_type_index` derives from
+/// `parent_type_index`, as reported by `TVMObjectDerivedFrom`.
+///
+/// # Panics
+/// Panics if the runtime reports an error.
+fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool {
+    let mut is_derived = 0;
+    crate::check_call!(ffi::TVMObjectDerivedFrom(
+        child_type_index,
+        parent_type_index,
+        &mut is_derived
+    ));
+    is_derived != 0
+}
+
+impl Object {
+    /// Creates an object header with the given type index and deleter, and a
+    /// reference count of one.
+    fn new(type_index: u32, deleter: Deleter<Object>) -> Object {
+        Object {
+            type_index,
+            // Note: do not touch this field directly again, this is
+            // a critical section, we write a 1 to the atomic which will now
+            // be managed by the C++ atomics.
+            // In the future we should probably use C atomics.
+            ref_count: 1,
+            fdeleter: deleter,
+        }
+    }
+
+    /// Looks up the runtime type index for `T::TYPE_KEY`.
+    ///
+    /// The base key `"Object"` is special-cased to index 0 without calling the runtime.
+    ///
+    /// # Panics
+    /// Panics if the type key contains a NUL byte or the runtime lookup fails.
+    fn get_type_index<T: IsObject>() -> u32 {
+        let type_key = T::TYPE_KEY;
+        if type_key == "Object" {
+            return 0;
+        }
+        let cstring = CString::new(type_key).expect("type key must not contain null characters");
+        let mut index = 0;
+        // SAFETY: `cstring` is NUL-terminated and outlives the call, and
+        // `&mut index` is a valid out-pointer.
+        if unsafe { TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) } != 0 {
+            panic!("{}", crate::get_last_error());
+        }
+        index
+    }
+
+    /// Builds the header for a new object of type `T`. Its deleter dispatches
+    /// to `T::typed_delete`.
+    pub fn base_object<T: IsObject>() -> Object {
+        let index = Object::get_type_index::<T>();
+        Object::new(index, delete::<T>)
+    }
+}
+
+/// Implemented by types that carry an [`Object`] header and can be managed
+/// through [`ObjectPtr`].
+///
+/// # Safety
+/// NOTE(review): implementors presumably must be `#[repr(C)]` with the
+/// `Object` header as their first field (as `StringObj` in `string.rs` is).
+/// Confirm and document the exact contract.
+pub unsafe trait IsObject {
+    /// Runtime type key used to look up the type index.
+    const TYPE_KEY: &'static str;
+
+    /// Borrows the object header.
+    fn as_object<'s>(&'s self) -> &'s Object;
+
+    /// Called by the type-erased deleter registered in `Object::base_object`.
+    ///
+    /// NOTE(review): the default implementation does nothing, so objects that
+    /// rely on it are never freed.
+    unsafe extern "C" fn typed_delete(_object: *mut Self) {
+        // let object = Box::from_raw(object);
+        // drop(object)
+    }
+}
+
+unsafe impl IsObject for Object {
+    const TYPE_KEY: &'static str = "Object";
+
+    /// The base object is its own header.
+    fn as_object(&self) -> &Object {
+        &*self
+    }
+}
+
+/// A non-null pointer to a runtime object of type `T`.
+///
+/// NOTE(review): there is no `Drop` impl, so references retained by `clone`
+/// are never released.
+#[repr(C)]
+pub struct ObjectPtr<T> {
+    pub ptr: NonNull<T>,
+}
+
+impl ObjectPtr<Object> {
+    /// Wraps a raw object pointer, returning `None` when it is null.
+    fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> {
+        // No debug output here: this runs on every handle conversion.
+        NonNull::new(object_ptr).map(|ptr| ObjectPtr { ptr })
+    }
+}
+
+impl<T> Clone for ObjectPtr<T> {
+    /// Retains the object through `TVMObjectRetain` and returns another
+    /// pointer to it.
+    ///
+    /// # Panics
+    /// Panics if `TVMObjectRetain` reports an error.
+    fn clone(&self) -> Self {
+        use std::ffi::c_void;
+        let raw_ptr = self.ptr.as_ptr() as *mut c_void;
+        // SAFETY: `self.ptr` is non-null; NOTE(review) assumes it points at a
+        // live runtime object.
+        assert_eq!(unsafe { TVMObjectRetain(raw_ptr) }, 0);
+        ObjectPtr { ptr: self.ptr }
+    }
+}
+
+// impl<T> Drop for ObjectPtr<T> {
+//     fn drop(&mut self) {
+//         unsafe {
+//             let raw_ptr = std::mem::transmute(self.ptr);
+//             assert_eq!(TVMObjectFree(raw_ptr), 0)
+//         }
+//     }
+// }
+
+impl<T: IsObject> ObjectPtr<T> {
+    /// Moves `object` to the heap and wraps the leaked allocation.
+    pub fn new(object: T) -> ObjectPtr<T> {
+        let leaked: &mut T = Box::leak(Box::new(object));
+        ObjectPtr {
+            ptr: NonNull::from(leaked),
+        }
+    }
+
+    /// Returns the object's reference count.
+    ///
+    /// This is a plain read, while the C++ side updates the count atomically
+    /// (ABI-compatible atomics are hard), so treat the value as approximate.
+    pub fn count(&self) -> i32 {
+        let header = self.as_object();
+        header.ref_count
+    }
+
+    /// Borrows the object header.
+    fn as_object<'s>(&'s self) -> &'s Object {
+        let typed: &T = unsafe { self.ptr.as_ref() };
+        typed.as_object()
+    }
+
+    /// Reinterprets this as a pointer to the base `Object`, without touching
+    /// the reference count.
+    pub fn upcast(&self) -> ObjectPtr<Object> {
+        let ptr = self.ptr.cast::<Object>();
+        ObjectPtr { ptr }
+    }
+
+    /// Reinterprets this as `ObjectPtr<U>` when the object's runtime type is
+    /// `U` or derives from it.
+    pub fn downcast<U: IsObject>(&self) -> anyhow::Result<ObjectPtr<U>> {
+        let target_index = Object::get_type_index::<U>();
+        let actual_index = self.as_object().type_index;
+
+        // TODO(@jroesch): write tests
+        let compatible =
+            target_index == actual_index || derived_from(actual_index, target_index);
+        if !compatible {
+            return Err(anyhow::anyhow!("failed to downcast to object subtype"));
+        }
+        Ok(ObjectPtr {
+            ptr: self.ptr.cast(),
+        })
+    }
+}
+
+impl<T> std::ops::Deref for ObjectPtr<T> {
+    type Target = T;
+
+    /// Borrows the pointed-to object.
+    fn deref(&self) -> &T {
+        let raw: *const T = self.ptr.as_ptr();
+        unsafe { &*raw }
+    }
+}
+
+impl<T: IsObject> From<ObjectPtr<T>> for RetValue {
+    /// Passes the raw object pointer on as an object handle. The reference
+    /// count is not changed.
+    fn from(object_ptr: ObjectPtr<T>) -> RetValue {
+        use std::ffi::c_void;
+        let void_ptr = object_ptr.ptr.as_ptr() as *mut c_void;
+        RetValue::ObjectHandle(void_ptr)
+    }
+}
+
+impl<T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
+    type Error = anyhow::Error;
+
+    /// Converts an object handle back to a typed pointer.
+    ///
+    /// Fails on a null handle, on a non-object value, or when the runtime type
+    /// is neither `T` nor derived from it.
+    fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error> {
+        match ret_value {
+            RetValue::ObjectHandle(handle) => {
+                let handle = handle as *mut Object;
+                let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?;
+                optr.downcast()
+            }
+            _ => Err(anyhow::anyhow!("unable to convert the result to an Object")),
+        }
+    }
+}
+
+impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
+    /// Passes the raw object pointer on as an object handle. The reference
+    /// count is not changed.
+    fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
+        use std::ffi::c_void;
+        let void_ptr = object_ptr.ptr.as_ptr() as *mut c_void;
+        ArgValue::ObjectHandle(void_ptr)
+    }
+}
+
+impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
+    type Error = anyhow::Error;
+
+    /// Same as the by-reference conversion below.
+    fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
+        ObjectPtr::<T>::try_from(&arg_value)
+    }
+}
+
+impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr<T> {
+    type Error = anyhow::Error;
+
+    /// Converts an object handle argument back to a typed pointer.
+    ///
+    /// Fails on a null handle, on a non-object argument, or when the runtime
+    /// type is neither `T` nor derived from it.
+    fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
+        match arg_value {
+            ArgValue::ObjectHandle(handle) => {
+                // `handle` is a reference to the stored pointer. Copy the
+                // pointer value out: reinterpreting the reference itself would
+                // yield the address of the `ArgValue`, not the object.
+                let handle = *handle as *mut Object;
+                let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?;
+                optr.downcast()
+            }
+            _ => Err(anyhow::anyhow!("unable to convert the result to an Object")),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{Object, ObjectPtr};
+    use anyhow::{ensure, Result};
+    use std::convert::TryInto;
+    use tvm_sys::{ArgValue, RetValue};
+
+    /// Checks that two pointers agree on type index and deleter.
+    fn ensure_same_object(lhs: &ObjectPtr<Object>, rhs: &ObjectPtr<Object>) -> Result<()> {
+        ensure!(lhs.type_index == rhs.type_index, "type indices do not match");
+        ensure!(lhs.fdeleter == rhs.fdeleter, "objects have different deleters");
+        Ok(())
+    }
+
+    #[test]
+    fn test_new_object() -> anyhow::Result<()> {
+        let ptr = ObjectPtr::new(Object::base_object::<Object>());
+        assert_eq!(ptr.count(), 1);
+        Ok(())
+    }
+
+    #[test]
+    fn roundtrip_retvalue() -> Result<()> {
+        let original = ObjectPtr::new(Object::base_object::<Object>());
+        let ret_value: RetValue = original.clone().into();
+        let restored: ObjectPtr<Object> = ret_value.try_into()?;
+        ensure_same_object(&original, &restored)
+    }
+
+    #[test]
+    fn roundtrip_argvalue() -> Result<()> {
+        let original = ObjectPtr::new(Object::base_object::<Object>());
+        let arg_value: ArgValue = original.clone().into();
+        let restored: ObjectPtr<Object> = arg_value.try_into()?;
+        ensure_same_object(&original, &restored)
+    }
+}
diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs
new file mode 100644
index 0000000..ac80625
--- /dev/null
+++ b/rust/tvm-rt/src/string.rs
@@ -0,0 +1,72 @@
+use std::ffi::{CString, NulError};
+use std::os::raw::c_char;
+
+use super::{Object, ObjectPtr, ObjectRef};
+use crate as tvm_rt;
+use tvm_macros::Object;
+
+/// Rust-side definition of the runtime string object registered under the
+/// type key `runtime.String`; the `String` reference type used below is
+/// presumably generated by `#[derive(Object)]` from `ref_name` — confirm.
+///
+/// NOTE(review): `#[repr(C)]` suggests this layout must match the C++ string
+/// object exactly, so field order and types should not be changed.
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "String"]
+#[type_key = "runtime.String"]
+pub struct StringObj {
+    /// Common object header.
+    base: Object,
+    /// Pointer to the string bytes.
+    data: *const c_char,
+    /// Number of bytes at `data`, excluding any NUL terminator.
+    size: u64,
+}
+
+impl String {
+    /// Creates a runtime `String` object holding a copy of `string`.
+    ///
+    /// # Errors
+    ///
+    /// Returns a [`NulError`] if `string` contains an interior NUL byte,
+    /// because the contents are stored as a NUL-terminated C string.
+    pub fn new(string: std::string::String) -> Result<String, NulError> {
+        let cstring = CString::new(string)?;
+        // Byte length, excluding the trailing NUL terminator.
+        let length = cstring.as_bytes().len();
+
+        // NOTE(review): the original author reported this string being
+        // corrupted; the cause has not been identified yet.
+        let string_obj = StringObj {
+            base: Object::base_object::<StringObj>(),
+            // NOTE(review): `into_raw` releases the buffer from Rust's
+            // ownership and nothing in this file reclaims it with
+            // `CString::from_raw`, so it appears to leak when the object is
+            // freed — confirm.
+            data: cstring.into_raw(),
+            size: length as u64,
+        };
+
+        Ok(String(Some(ObjectPtr::new(string_obj))))
+    }
+
+    /// Copies the string contents into an owned, NUL-terminated [`CString`].
+    ///
+    /// # Panics
+    ///
+    /// Panics if this reference is null.
+    ///
+    /// # Errors
+    ///
+    /// Returns a [`NulError`] if the stored bytes contain an interior NUL.
+    pub fn to_cstring(&self) -> Result<std::ffi::CString, NulError> {
+        let obj = self
+            .0
+            .as_ref()
+            .expect("cannot read the contents of a null String reference");
+        // SAFETY: `data` points to `size` bytes owned by the string object,
+        // which is kept alive by `self` for the duration of this borrow.
+        let bytes: &[u8] =
+            unsafe { std::slice::from_raw_parts(obj.data as *const u8, obj.size as usize) };
+        CString::new(bytes)
+    }
+
+    /// Copies the string contents into an owned Rust string.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the bytes contain an interior NUL or are not valid UTF-8.
+    pub fn to_string(&self) -> anyhow::Result<std::string::String> {
+        Ok(self.to_cstring()?.into_string()?)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::String;
+    use crate::object::debug_print;
+    use crate::ToObjectRef;
+    use anyhow::{ensure, Result};
+
+    /// Round-trips a runtime string through the debug printer and checks that
+    /// the printed text contains the original contents.
+    #[test]
+    fn test_string_debug() -> Result<()> {
+        let s = String::new("foo".to_string())?;
+        let object_ref = s.to_object_ref();
+        let printed = debug_print(object_ref)?.into_string()?;
+        ensure!(printed.contains("foo"), "string content is invalid");
+        Ok(())
+    }
+}
diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs
new file mode 100644
index 0000000..7a560b6
--- /dev/null
+++ b/rust/tvm-rt/src/to_boxed_fn.rs
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module converts TVM packed functions into typed Rust closures.
+//!
+//! [`ToBoxedFn`] wraps a registered [`Function`] in a boxed `Fn` whose
+//! arguments are converted into [`ArgValue`]s and whose result is converted
+//! out of the returned [`RetValue`]. [`Builder`] assembles the argument list
+//! for a single call.
+//!
+//! See the tests and examples repository for more examples.
+
+use anyhow::Result;
+
+pub use tvm_sys::{ffi, ArgValue, RetValue};
+
+use crate::{Module};
+
+use super::function::Function;
+
+/// Conversion of a TVM [`Function`] into a typed, boxed Rust closure.
+///
+/// Implemented for `dyn Fn(..) -> Result<O>` with zero to four arguments.
+/// Every argument must convert into an [`ArgValue`], and the output must be
+/// fallibly convertible from the returned [`RetValue`].
+pub trait ToBoxedFn {
+    /// Wraps `func` in a boxed closure that calls it on every invocation.
+    fn to_boxed_fn(func: &'static Function) -> Box<Self>;
+}
+
+use std::convert::{TryFrom, TryInto};
+
+// Generates one `ToBoxedFn` impl for `dyn Fn(params..) -> Result<O>`.
+// Each `(Param, arg)` pair names the argument's type parameter and the
+// closure binding that carries it; arguments are pushed in the listed order.
+macro_rules! to_boxed_fn_instance {
+    ($(($param:ident, $arg:ident),)*) => {
+        impl<E, $($param,)* O> ToBoxedFn for dyn Fn($($param),*) -> Result<O>
+        where
+            E: std::error::Error + Send + Sync + 'static,
+            $($param: Into<ArgValue<'static>>,)*
+            O: TryFrom<RetValue, Error = E>,
+        {
+            fn to_boxed_fn(func: &'static Function) -> Box<Self> {
+                Box::new(move |$($arg: $param),*| {
+                    let mut builder = Builder::default();
+                    builder.func = Some(func);
+                    $(builder.arg($arg.into());)*
+                    let res = builder.invoke()?.try_into()?;
+                    Ok(res)
+                })
+            }
+        }
+    };
+}
+
+to_boxed_fn_instance!();
+to_boxed_fn_instance!((A, a),);
+to_boxed_fn_instance!((A, a), (B, b),);
+to_boxed_fn_instance!((A, a), (B, b), (C, c),);
+to_boxed_fn_instance!((A, a), (B, b), (C, c), (D, d),);
+
+/// Accumulates a function and its arguments, then calls it with `invoke`.
+///
+/// *Note:* Currently TVM functions accept *at most* one return value.
+#[derive(Default)]
+pub struct Builder<'a, 'm> {
+    /// The function to call; `None` until one is set.
+    pub func: Option<&'m Function>,
+    /// Arguments passed to the function, in call order.
+    pub arg_buf: Vec<ArgValue<'a>>,
+    /// Output slot recorded by `set_output`; `invoke` in this file does not read it.
+    pub ret_buf: Option<RetValue>,
+}
+
+impl<'a, 'm> Builder<'a, 'm> {
+    /// Creates a builder from its parts.
+    pub fn new(
+        func: Option<&'m Function>,
+        arg_buf: Vec<ArgValue<'a>>,
+        ret_buf: Option<RetValue>,
+    ) -> Self {
+        Self {
+            func,
+            arg_buf,
+            ret_buf,
+        }
+    }
+
+    /// Looks up the global function `name` and makes it the function to call.
+    ///
+    /// If the lookup yields `None`, the target is cleared and a later
+    /// [`invoke`](Builder::invoke) returns an error.
+    pub fn get_function(&mut self, name: &'m str) -> &mut Self {
+        self.func = Function::get(name);
+        self
+    }
+
+    /// Pushes a [`ArgValue`] into the function argument buffer.
+    pub fn arg<T: 'a>(&mut self, arg: T) -> &mut Self
+    where
+        ArgValue<'a>: From<T>,
+    {
+        self.arg_buf.push(arg.into());
+        self
+    }
+
+    /// Pushes multiple [`ArgValue`]s into the function argument buffer, in
+    /// iteration order.
+    pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
+    where
+        I: IntoIterator<Item = T>,
+        ArgValue<'a>: From<T>,
+    {
+        for arg in args {
+            self.arg(arg);
+        }
+        self
+    }
+
+    /// Sets an output for a function that requires a mutable output to be provided.
+    /// See the `basics` in tests for an example.
+    ///
+    /// The value is stored in `ret_buf`; `invoke` does not currently pass it on.
+    pub fn set_output<T>(&mut self, ret: T) -> &mut Self
+    where
+        RetValue: From<T>,
+    {
+        self.ret_buf = Some(ret.into());
+        self
+    }
+
+    /// Calls the function with the accumulated arguments.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if no function has been set (for example after a
+    /// failed `get_function` lookup), or if the call itself fails.
+    pub fn invoke(self) -> Result<RetValue> {
+        let func = self
+            .func
+            .ok_or_else(|| anyhow::anyhow!("no function set on Builder; cannot invoke"))?;
+        func.invoke(self.arg_buf)
+    }
+}
+
+/// Starts a [`Builder`] targeting `func`, with no arguments pushed yet.
+/// Currently, this is the best way to work with TVM functions.
+impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
+    fn from(func: &'m Function) -> Self {
+        Self {
+            func: Some(func),
+            arg_buf: Vec::new(),
+            ret_buf: None,
+        }
+    }
+}
+
+/// Starts a [`Builder`] targeting the module's entry function, if it has one.
+impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
+    fn from(module: &'m mut Module) -> Self {
+        let entry = module.entry();
+        Self {
+            func: entry,
+            arg_buf: Vec::new(),
+            ret_buf: None,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::function::{self, Function};
+    use anyhow::Result;
+
+    /// Registers a nullary Rust function and calls it back through a typed
+    /// boxed closure.
+    #[test]
+    fn to_boxed_fn0() {
+        fn boxed0() -> i64 {
+            10
+        }
+
+        function::register_override(boxed0, "boxed0".to_owned(), true)
+            .expect("registering boxed0 should succeed");
+        let func = Function::get("boxed0").expect("boxed0 was just registered");
+        let typed_func: Box<dyn Fn() -> Result<i64>> = func.to_boxed_fn();
+        assert_eq!(typed_func().unwrap(), 10);
+    }
+}
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
new file mode 100644
index 0000000..6954650
--- /dev/null
+++ b/rust/tvm-rt/src/to_function.rs
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module converts Rust functions and closures into TVM packed functions.
+//!
+//! [`Typed`] describes how a function's arguments are decoded from
+//! [`ArgValue`]s and how its result is encoded into a [`RetValue`];
+//! [`ToFunction`] uses it to register the function with the TVM runtime as a
+//! [`Function`] backed by a C callback.
+//!
+//! See the tests and examples repository for more examples.
+
+use std::{
+    mem::MaybeUninit,
+    os::raw::{c_int, c_void},
+    ptr, slice,
+};
+
+use anyhow::Result;
+
+pub use tvm_sys::{ffi, ArgValue, RetValue};
+
+use super::Function;
+use std::convert::{TryFrom, TryInto};
+
+/// A trait representing whether the function arguments
+/// and return type can be assigned to a TVM packed function.
+///
+/// Converting a function into a packed function is split across two traits
+/// to improve error reporting: this trait covers converting the inputs and
+/// the output, while `ToFunction` covers wrapping the function itself.
+pub trait Typed<I, O> {
+    /// Decodes the packed arguments into the function's input type `I`.
+    fn args(i: &[ArgValue<'static>]) -> anyhow::Result<I>;
+    /// Encodes the function's output `O` as a packed return value.
+    fn ret(o: O) -> RetValue;
+}
+
+/// Untyped functions that receive the raw packed argument slice and build the
+/// [`RetValue`] themselves.
+impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>> for F
+where
+    F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue>,
+{
+    /// Passes the argument slice through unchanged.
+    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<&'a 
[ArgValue<'static>]> {
+        // this is BAD but just hacking for time being
+        // NOTE(review): the transmute only re-labels the borrow of `args` with
+        // the caller-chosen lifetime `'a`; the result must not outlive `args`,
+        // and nothing here enforces that.
+        Ok(unsafe { std::mem::transmute(args) })
+    }
+
+    /// Unwraps the function's result.
+    // NOTE(review): an `Err` panics here; if this is ever reached from the
+    // `extern "C"` callback, the panic would unwind across the FFI boundary.
+    fn ret(ret_value: anyhow::Result<RetValue>) -> RetValue {
+        ret_value.unwrap()
+    }
+}
+
+/// Nullary functions: no arguments are decoded.
+impl<F, O: Into<RetValue>> Typed<(), O> for F
+where
+    F: Fn() -> O,
+{
+    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<()> {
+        // Report an arity mismatch as an error in every build profile.
+        anyhow::ensure!(args.is_empty(), "expected 0 arguments, got {}", args.len());
+        Ok(())
+    }
+
+    fn ret(o: O) -> RetValue {
+        o.into()
+    }
+}
+
+/// Unary functions: decodes exactly one argument.
+impl<F, A, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,), O> for F
+where
+    F: Fn(A) -> O,
+    E: std::error::Error + Send + Sync + 'static,
+    A: TryFrom<ArgValue<'static>, Error = E>,
+{
+    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,)> {
+        // Report an arity mismatch as an error; the indexing below would
+        // otherwise panic inside the FFI callback that calls this.
+        anyhow::ensure!(args.len() == 1, "expected 1 argument, got {}", args.len());
+        let a: A = args[0].clone().try_into()?;
+        Ok((a,))
+    }
+
+    fn ret(o: O) -> RetValue {
+        o.into()
+    }
+}
+
+/// Binary functions: decodes exactly two arguments.
+impl<F, A, B, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A, B), O> for F
+where
+    F: Fn(A, B) -> O,
+    E: std::error::Error + Send + Sync + 'static,
+    A: TryFrom<ArgValue<'static>, Error = E>,
+    B: TryFrom<ArgValue<'static>, Error = E>,
+{
+    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B)> {
+        // Two arguments are required. A mismatch is reported as an error; the
+        // indexing below would otherwise panic inside the FFI callback.
+        anyhow::ensure!(args.len() == 2, "expected 2 arguments, got {}", args.len());
+        let a: A = args[0].clone().try_into()?;
+        let b: B = args[1].clone().try_into()?;
+        Ok((a, b))
+    }
+
+    fn ret(o: O) -> RetValue {
+        o.into()
+    }
+}
+
+/// Ternary functions: decodes exactly three arguments.
+impl<F, A, B, C, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A, B, C), O> for F
+where
+    F: Fn(A, B, C) -> O,
+    E: std::error::Error + Send + Sync + 'static,
+    A: TryFrom<ArgValue<'static>, Error = E>,
+    B: TryFrom<ArgValue<'static>, Error = E>,
+    C: TryFrom<ArgValue<'static>, Error = E>,
+{
+    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B, C)> {
+        // Three arguments are required. A mismatch is reported as an error;
+        // the indexing below would otherwise panic inside the FFI callback.
+        anyhow::ensure!(args.len() == 3, "expected 3 arguments, got {}", args.len());
+        let a: A = args[0].clone().try_into()?;
+        let b: B = args[1].clone().try_into()?;
+        let c: C = args[2].clone().try_into()?;
+        Ok((a, b, c))
+    }
+
+    fn ret(o: O) -> RetValue {
+        o.into()
+    }
+}
+
+/// Conversion of a Rust function or closure into a TVM packed [`Function`].
+///
+/// `I` is the tuple of argument types and `O` the return type; the argument
+/// and result conversions are supplied by [`Typed`].
+pub trait ToFunction<I, O>: Sized {
+    /// Heap representation of the function, handed to TVM as the callback's
+    /// resource handle.
+    type Handle;
+
+    /// Moves `self` onto the heap and returns an owning raw pointer to it.
+    fn into_raw(self) -> *mut Self::Handle;
+
+    /// Invokes the function behind `handle` on the packed `args`.
+    fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> anyhow::Result<RetValue>
+    where
+        Self: Typed<I, O>;
+
+    /// Releases the resource behind `handle`.
+    fn drop(handle: *mut Self::Handle);
+
+    /// Registers `self` with the TVM runtime as a packed function.
+    ///
+    /// The heap copy of `self` is passed to TVM as the resource handle, and
+    /// `Self::tvm_finalizer` is registered to release it.
+    fn to_function(self) -> Function
+    where
+        Self: Typed<I, O>,
+    {
+        let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+        let resource_handle = self.into_raw();
+        check_call!(ffi::TVMFuncCreateFromCFunc(
+            Some(Self::tvm_callback),
+            resource_handle as *mut _,
+            Some(Self::tvm_finalizer),
+            &mut fhandle as *mut _
+        ));
+        Function::new(fhandle)
+    }
+
+    /// The callback that TVM wraps into the packed function stored in
+    /// `fhandle`.
+    ///
+    /// Decodes the raw arguments into [`ArgValue`]s, runs the Rust function,
+    /// and stores its result in `ret`. Returns 0 on success. If the call fails,
+    /// it records the error with `set_last_error` and returns -1.
+    unsafe extern "C" fn tvm_callback(
+        args: *mut ffi::TVMValue,
+        type_codes: *mut c_int,
+        num_args: c_int,
+        ret: ffi::TVMRetValueHandle,
+        fhandle: *mut c_void,
+    ) -> c_int
+    where
+        Self: Typed<I, O>,
+    {
+        // Presumably `check_call!` wraps its call in `unsafe`, which is
+        // redundant inside this `unsafe fn`.
+        #![allow(unused_unsafe)]
+        let len = num_args as usize;
+        // `from_raw_parts` requires non-null pointers even for empty slices,
+        // so the zero-argument case does not touch the raw pointers.
+        let (args_list, type_codes_list): (&[ffi::TVMValue], &[c_int]) = if len == 0 {
+            (&[], &[])
+        } else {
+            (
+                slice::from_raw_parts(args, len),
+                slice::from_raw_parts(type_codes, len),
+            )
+        };
+        let rust_fn = fhandle as *mut Self::Handle;
+        let mut local_args: Vec<ArgValue> = Vec::with_capacity(len);
+        for (&raw_value, &raw_tcode) in args_list.iter().zip(type_codes_list.iter()) {
+            let mut value = raw_value;
+            let mut tcode = raw_tcode;
+            // Handle-typed arguments are translated in place with
+            // `TVMCbArgToReturn`, presumably so the Rust side can keep them
+            // beyond this callback — confirm against the TVM C API docs.
+            if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int
+                || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
+                || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
+            {
+                check_call!(ffi::TVMCbArgToReturn(
+                    &mut value as *mut _,
+                    &mut tcode as *mut _
+                ));
+            }
+            local_args.push(ArgValue::from_tvm_value(value, tcode as u32));
+        }
+
+        let rv = match Self::call(rust_fn, local_args.as_slice()) {
+            Ok(v) => v,
+            Err(msg) => {
+                crate::set_last_error(&msg);
+                return -1;
+            }
+        };
+
+        let (mut ret_val, ret_tcode) = rv.to_tvm_value();
+        let mut ret_type_code = ret_tcode as c_int;
+        check_call!(ffi::TVMCFuncSetReturn(
+            ret,
+            &mut ret_val as *mut _,
+            &mut ret_type_code as *mut _,
+            1 as c_int
+        ));
+        0
+    }
+
+    /// The finalizer which is invoked when the packed function's
+    /// reference count is zero; releases the resource created by `into_raw`.
+    unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) {
+        Self::drop(fhandle as *mut Self::Handle)
+    }
+}
+
+// /// A wrapper that is used to work around inference issues for bare 
functions.
+// ///
+// /// Used to implement `register_untyped`.
+// pub(self) struct RawFunction {
+//     fn_ptr: for<'a> fn (&'a [ArgValue<'static>]) -> Result<RetValue>
+// }
+
+// impl RawFunction {
+//     fn new(fn_ptr: for<'a> fn (&'a [ArgValue<'static>]) -> 
Result<RetValue>) -> RawFunction {
+//         RawFunction { fn_ptr: fn_ptr }
+//     }
+// }
+
+// impl Typed<&[ArgValue<'static>], ()> for RawFunction {
+//     fn args(i: &[ArgValue<'static>]) -> 
anyhow::Result<&[ArgValue<'static>]> {
+//         Ok(i)
+//     }
+
+//     fn ret(o: O) -> RetValue;
+// }
+
+// impl ToFunction<(), ()> for RawFunction
+// {
+//     type Handle = fn(&[ArgValue<'static>]) -> Result<RetValue>;
+
+//     fn into_raw(self) -> *mut Self::Handle {
+//         self.fn_ptr as *mut Self::Handle
+//     }
+
+//     fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> 
Result<RetValue> {
+//         let handle: Self::Handle = unsafe { std::mem::transmute(handle) };
+//         let r = handle(args);
+//         println!("afters");
+//         r
+//     }
+
+//     // Function's don't need de-allocation because the pointers are into 
the code section of memory.
+//     fn drop(_: *mut Self::Handle) {}
+// }
+
+/// Nullary functions: the packed arguments are ignored.
+impl<O, F> ToFunction<(), O> for F
+where
+    F: Fn() -> O + 'static,
+{
+    type Handle = Box<dyn Fn() -> O + 'static>;
+
+    fn into_raw(self) -> *mut Self::Handle {
+        let ptr: Box<Self::Handle> = Box::new(Box::new(self));
+        Box::into_raw(ptr)
+    }
+
+    fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue>
+    where
+        F: Typed<(), O>,
+    {
+        // SAFETY: `handle` was produced by `into_raw` and is released only by
+        // `drop` from the finalizer.
+        let out = unsafe { (*handle)() };
+        Ok(F::ret(out))
+    }
+
+    fn drop(handle: *mut Self::Handle) {
+        // SAFETY: `handle` came from `Box::into_raw` in `into_raw`; the TVM
+        // finalizer is assumed to call this at most once per handle.
+        std::mem::drop(unsafe { Box::from_raw(handle) });
+    }
+}
+
+// Generates a `ToFunction` impl for functions taking the listed parameters.
+// Each `(Param, index)` pair names a type parameter and its position in the
+// tuple produced by `Typed::args`.
+macro_rules! to_function_instance {
+    ($(($param:ident,$index:tt),)+) => {
+        impl<F, $($param,)+ O> ToFunction<($($param,)+), O> for
+        F where F: Fn($($param,)+) -> O + 'static {
+            type Handle = Box<dyn Fn($($param,)+) -> O + 'static>;
+
+            fn into_raw(self) -> *mut Self::Handle {
+                let ptr: Box<Self::Handle> = Box::new(Box::new(self));
+                Box::into_raw(ptr)
+            }
+
+            fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue>
+            where
+                F: Typed<($($param,)+), O>,
+            {
+                let args = F::args(args)?;
+                // SAFETY: `handle` was produced by `into_raw` and is released
+                // only by `drop` from the finalizer.
+                let out = unsafe { (*handle)($(args.$index),+) };
+                Ok(F::ret(out))
+            }
+
+            fn drop(handle: *mut Self::Handle) {
+                // SAFETY: `handle` came from `Box::into_raw` in `into_raw`; the
+                // TVM finalizer is assumed to call this at most once per handle.
+                std::mem::drop(unsafe { Box::from_raw(handle) });
+            }
+        }
+    }
+}
+
+to_function_instance!((A, 0),);
+to_function_instance!((A, 0), (B, 1),);
+to_function_instance!((A, 0), (B, 1), (C, 2),);
+// NOTE(review): this module has no `Typed` impl for four-argument tuples, so
+// `call`/`to_function` cannot be used for this arity yet.
+to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),);
+
+#[cfg(test)]
+mod tests {
+    use super::{Function, ToFunction, Typed};
+
+    /// Converts any supported Rust function into a packed [`Function`].
+    fn to_packed<F, I, O>(f: F) -> Function
+    where
+        F: ToFunction<I, O> + Typed<I, O>,
+    {
+        f.to_function()
+    }
+
+    #[test]
+    fn test_to_function0() {
+        fn zero() -> i32 {
+            10
+        }
+        to_packed(zero);
+    }
+
+    #[test]
+    fn test_to_function1() {
+        fn one_arg(i: i32) -> i32 {
+            i
+        }
+        to_packed(one_arg);
+    }
+
+    #[test]
+    fn test_to_function2() {
+        fn two_arg(i: i32, j: i32) -> i32 {
+            i + j
+        }
+        to_packed(two_arg);
+    }
+}
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
new file mode 100644
index 0000000..a9355e0
--- /dev/null
+++ b/rust/tvm-rt/src/value.rs
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module implements conversions between [`ArgValue`]/[`RetValue`] and
+//! the runtime's handle-backed types: [`Function`], [`Module`] and [`NDArray`].
+//! `RetValue` is the owned version of `TVMPODValue`.
+
+use std::convert::TryFrom;
+// use std::ffi::c_void;
+
+use crate::{ArgValue, Function, Module, NDArray, RetValue};
+use tvm_sys::{
+    errors::ValueDowncastError,
+    ffi::{TVMFunctionHandle, TVMModuleHandle},
+    try_downcast,
+};
+
+// Implements conversions between a handle-backed runtime type and the packed
+// `ArgValue` / `RetValue` representations.
+//   $type      - the Rust wrapper type (exposes `handle()`)
+//   $variant   - the `ArgValue`/`RetValue` variant that carries its handle
+//   $handle_ty - the raw FFI handle type stored in that variant
+//   $ctor      - builds a `$type` from a raw handle
+macro_rules! impl_handle_val {
+    ($type:ty, $variant:ident, $handle_ty:ty, $ctor:path) => {
+        impl From<$type> for RetValue {
+            fn from(value: $type) -> RetValue {
+                RetValue::$variant(value.handle() as $handle_ty)
+            }
+        }
+
+        impl TryFrom<RetValue> for $type {
+            type Error = ValueDowncastError;
+            fn try_from(val: RetValue) -> Result<$type, Self::Error> {
+                try_downcast!(val -> $type, |RetValue::$variant(handle)| { $ctor(handle) })
+            }
+        }
+
+        impl<'a> From<&'a $type> for ArgValue<'a> {
+            fn from(value: &'a $type) -> Self {
+                ArgValue::$variant(value.handle() as $handle_ty)
+            }
+        }
+
+        impl<'a> From<&'a mut $type> for ArgValue<'a> {
+            fn from(value: &'a mut $type) -> Self {
+                ArgValue::$variant(value.handle() as $handle_ty)
+            }
+        }
+
+        impl<'a> TryFrom<ArgValue<'a>> for $type {
+            type Error = ValueDowncastError;
+            fn try_from(val: ArgValue<'a>) -> Result<$type, Self::Error> {
+                try_downcast!(val -> $type, |ArgValue::$variant(handle)| { $ctor(handle) })
+            }
+        }
+
+        impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type {
+            type Error = ValueDowncastError;
+            fn try_from(val: &'a ArgValue<'v>) -> Result<$type, Self::Error> {
+                try_downcast!(val -> $type, |ArgValue::$variant(handle)| { $ctor(*handle) })
+            }
+        }
+    };
+}
+
+impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
+impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
+
+/// Borrowed arrays are passed as plain array handles, and owned arrays as
+/// NDArray handles.
+impl<'a> From<&'a NDArray> for ArgValue<'a> {
+    fn from(arg: &'a NDArray) -> Self {
+        match *arg {
+            NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle),
+            NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle),
+        }
+    }
+}
+
+/// Same conversion as for a shared reference.
+impl<'a> From<&'a mut NDArray> for ArgValue<'a> {
+    fn from(arg: &'a mut NDArray) -> Self {
+        ArgValue::from(&*arg)
+    }
+}
+
+/// Accepts either an NDArray handle or a plain array handle.
+impl<'a> TryFrom<ArgValue<'a>> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: ArgValue<'a>) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |ArgValue::NDArrayHandle(handle)| { NDArray::from_ndarray_handle(handle) },
+            |ArgValue::ArrayHandle(handle)| { NDArray::new(handle) })
+    }
+}
+
+/// Borrowing variant of the conversion above; the handle is copied out.
+impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: &'a ArgValue<'v>) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |ArgValue::NDArrayHandle(handle)| { NDArray::from_ndarray_handle(*handle) },
+            |ArgValue::ArrayHandle(handle)| { NDArray::new(*handle) })
+    }
+}
+
+/// Converts an owned `NDArray` into a return value carrying its handle.
+///
+/// # Panics
+///
+/// Panics for borrowed arrays, whose conversion is not implemented yet.
+impl From<NDArray> for RetValue {
+    fn from(val: NDArray) -> RetValue {
+        match val {
+            NDArray::Owned { handle } => RetValue::NDArrayHandle(handle),
+            NDArray::Borrowed { .. } => {
+                unimplemented!("converting a borrowed NDArray into a RetValue")
+            }
+        }
+    }
+}
+
+/// Accepts either an NDArray handle or a plain array handle in a return value.
+impl TryFrom<RetValue> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |RetValue::NDArrayHandle(handle)| { NDArray::from_ndarray_handle(handle) },
+            |RetValue::ArrayHandle(handle)| { NDArray::new(handle) })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{convert::TryInto, str::FromStr};
+
+    use crate::{ByteArray, Context, DataType};
+
+    use super::*;
+
+    /// A byte array survives a round trip through `RetValue` unchanged.
+    #[test]
+    fn bytearray() {
+        let bytes = vec![1u8, 2, 3, 4, 5];
+        let original = ByteArray::from(bytes.as_slice());
+        let roundtripped: ByteArray = RetValue::from(original).try_into().unwrap();
+        assert_eq!(roundtripped.data(), bytes.as_slice());
+    }
+
+    /// A data type survives a round trip through `RetValue` unchanged.
+    #[test]
+    fn ty() {
+        let dtype = DataType::from_str("int32").unwrap();
+        let roundtripped: DataType = RetValue::from(dtype).try_into().unwrap();
+        assert_eq!(roundtripped, dtype);
+    }
+
+    /// A context survives a round trip through `RetValue` unchanged.
+    #[test]
+    fn ctx() {
+        let context = Context::from_str("gpu").unwrap();
+        let roundtripped: Context = RetValue::from(context).try_into().unwrap();
+        assert_eq!(roundtripped, context);
+    }
+}
diff --git a/rust/tvm-rt/tests/test_ir.rs b/rust/tvm-rt/tests/test_ir.rs
new file mode 100644
index 0000000..7d9e475
--- /dev/null
+++ b/rust/tvm-rt/tests/test_ir.rs
@@ -0,0 +1,36 @@
+// use std::convert::TryInto;
+// use std::str::FromStr;
+// use tvm_rt::string::String as TString;
+// use tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef};
+// use tvm::{call_packed, DLDataType, Function};
+// use tvm_sys::RetValue;
+
+// #[test]
+// fn test_new_object() -> anyhow::Result<()> {
+//     let object = Object::base_object::<Object>();
+//     let ptr = ObjectPtr::new(object);
+//     assert_eq!(ptr.count(), 1);
+//     Ok(())
+// }
+
+// #[test]
+// fn test_new_string() -> anyhow::Result<()> {
+//     let string = TString::new("hello world!".to_string())?;
+//     Ok(())
+// }
+
+// #[test]
+// fn test_obj_build() -> anyhow::Result<()> {
+//     let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not 
found.");
+
+//     let dt = DLDataType::from_str("int32").expect("Known datatype doesn't 
convert.");
+
+//     let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337)
+//         .expect("foo")
+//         .try_into()
+//         .unwrap();
+
+//     debug_print(&ret_val);
+
+//     Ok(())
+// }
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 7272213..b322388 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -162,7 +162,7 @@ GlobalVar::GlobalVar(std::string name_hint) {
 TVM_REGISTER_NODE_TYPE(GlobalVarNode);
 
 TVM_REGISTER_GLOBAL("ir.GlobalVar")
-.set_body_typed([](std::string name){
+.set_body_typed([](String name){
   return GlobalVar(name);
 });
 
@@ -214,4 +214,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     }
     p->stream << '}';
   });
+
+TVM_REGISTER_GLOBAL("ir.DebugPrinter")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  // Render the first argument with the ObjectRef stream operator and return
+  // the resulting text.
+  const ObjectRef node = args[0];
+  std::stringstream os;
+  os << node;
+  *ret = os.str();
+});
+
 }  // namespace tvm
diff --git a/src/printer/relay_text_printer.cc 
b/src/printer/relay_text_printer.cc
index bda997a..fc9546a 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -193,8 +193,7 @@ class RelayTextPrinter :
     case kTypeData:
       return Doc::Text("TypeData");
     default:
-      LOG(ERROR) << "Unknown Kind";
-      throw;
+      CHECK(false) << "Unknown Kind";
     }
   }
   /*!
@@ -479,7 +478,8 @@ class RelayTextPrinter :
   }
 
   Doc VisitExpr_(const GlobalVarNode* op) final {
-    return Doc::Text('@' + op->name_hint);
+    std::string name_hint = op->name_hint;
+    return Doc::Text('@' + name_hint);
   }
 
   Doc VisitExpr_(const OpNode* op) final {
@@ -939,4 +939,13 @@ TVM_REGISTER_GLOBAL("ir.PrettyPrint")
 
 TVM_REGISTER_GLOBAL("ir.AsText")
 .set_body_typed(AsText);
+
+TVM_REGISTER_GLOBAL("ir.TextPrinter")
+.set_body_typed([](ObjectRef node) {
+  // Return the text form of `node`. Nothing is written to stdout; callers
+  // decide what to do with the text.
+  return AsText(node, /*show_meta_data=*/false, /*annotate=*/nullptr);
+});
+
 }  // namespace tvm
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index e6c8392..65ee57f 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -164,7 +164,7 @@ Function ToCPS(const Function& f,
         // only look unfold non-external calls.
         BaseFunc base_func = m->Lookup(gv);
         if (auto* n = base_func.as<FunctionNode>()) {
-          auto cps_gv = GlobalVar(gv->name_hint + "_cps");
+          auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps");
           cm->insert({gv, cps_gv});
           m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
         } else {
diff --git a/src/runtime/object.cc b/src/runtime/object.cc
index 0301200..5496159 100644
--- a/src/runtime/object.cc
+++ b/src/runtime/object.cc
@@ -244,12 +244,26 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* 
out_tindex) {
   API_END();
 }
 
+/*!
+ * \brief C API entry point that increments the reference count of `obj`
+ *  via ObjectInternal::ObjectRetain (which ignores a null handle).
+ * \return 0 on success; errors are reported through the API_BEGIN/API_END guard.
+ */
+int TVMObjectRetain(TVMObjectHandle obj) {
+  API_BEGIN();
+  tvm::runtime::ObjectInternal::ObjectRetain(obj);
+  API_END();
+}
+
 int TVMObjectFree(TVMObjectHandle obj) {
   API_BEGIN();
   tvm::runtime::ObjectInternal::ObjectFree(obj);
   API_END();
 }
 
+/*!
+ * \brief Query whether `child_type_index` derives from `parent_type_index`.
+ * \param is_derived Receives 1 if it does, 0 otherwise.
+ */
+int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index,
+                         int* is_derived) {
+  API_BEGIN();
+  const bool derived =
+      tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index);
+  *is_derived = derived ? 1 : 0;
+  API_END();
+}
+
 int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
   API_BEGIN();
   out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(
diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h
index 7955130..ab48802 100644
--- a/src/runtime/object_internal.h
+++ b/src/runtime/object_internal.h
@@ -38,6 +38,15 @@ namespace runtime {
 class ObjectInternal {
  public:
   /*!
+   * \brief Retain an object handle.
+   */
+  static void ObjectRetain(TVMObjectHandle obj) {
+    // A null handle is a no-op.
+    if (obj == nullptr) return;
+    static_cast<Object*>(obj)->IncRef();
+  }
+
+  /*!
    * \brief Free an object handle.
    */
   static void ObjectFree(TVMObjectHandle obj) {

Reply via email to