This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 09b989d  [Rust][Fix] Memory leak (#8714)
09b989d is described below

commit 09b989deb77cfb40f468c2566d1f40227af44bf7
Author: Jared Roesch <[email protected]>
AuthorDate: Wed Aug 11 05:57:52 2021 -0700

    [Rust][Fix] Memory leak (#8714)
    
    * Fix obvious memory leak in function.rs
    
    * Update object pointer
---
 rust/tvm-rt/src/function.rs          | 89 ++++++++++++++++--------------------
 rust/tvm-rt/src/module.rs            |  2 +-
 rust/tvm-rt/src/ndarray.rs           |  4 +-
 rust/tvm-rt/src/object/object_ptr.rs | 12 -----
 rust/tvm-rt/src/to_function.rs       |  2 +-
 rust/tvm-sys/build.rs                | 10 +++-
 rust/tvm/tests/basics/src/main.rs    |  2 +-
 7 files changed, 52 insertions(+), 69 deletions(-)

diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index aec4a8a..5db665c 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -26,6 +26,7 @@
 //! See the tests and examples repository for more examples.
 
 use std::convert::{TryFrom, TryInto};
+use std::sync::Arc;
 use std::{
     ffi::CString,
     os::raw::{c_char, c_int},
@@ -39,36 +40,43 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
 
 pub type Result<T> = std::result::Result<T, Error>;
 
-/// Wrapper around TVM function handle which includes `is_global`
-/// indicating whether the function is global or not, and `is_cloned` showing
-/// not to drop a cloned function from Rust side.
-/// The value of these fields can be accessed through their respective methods.
 #[derive(Debug, Hash)]
-pub struct Function {
-    pub(crate) handle: ffi::TVMFunctionHandle,
-    // whether the registered function is global or not.
-    is_global: bool,
-    from_rust: bool,
+struct FunctionPtr {
+    handle: ffi::TVMFunctionHandle,
 }
 
-unsafe impl Send for Function {}
-unsafe impl Sync for Function {}
+// NB(@jroesch): I think this is ok, need to double check,
+// if not we should mutex the pointer or move to Rc.
+unsafe impl Send for FunctionPtr {}
+unsafe impl Sync for FunctionPtr {}
+
+impl FunctionPtr {
+    fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
+        FunctionPtr { handle }
+    }
+}
+
+impl Drop for FunctionPtr {
+    fn drop(&mut self) {
+        check_call!(ffi::TVMFuncFree(self.handle));
+    }
+}
+
+/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust.
+#[derive(Debug, Hash)]
+pub struct Function {
+    inner: Arc<FunctionPtr>,
+}
 
 impl Function {
-    pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
+    pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
         Function {
-            handle,
-            is_global: false,
-            from_rust: false,
+            inner: Arc::new(FunctionPtr::from_raw(handle)),
         }
     }
 
     pub unsafe fn null() -> Self {
-        Function {
-            handle: std::ptr::null_mut(),
-            is_global: false,
-            from_rust: false,
-        }
+        Function::from_raw(std::ptr::null_mut())
     }
 
     /// For a given function, it returns a function by name.
@@ -84,11 +92,7 @@ impl Function {
         if handle.is_null() {
             None
         } else {
-            Some(Function {
-                handle,
-                is_global: true,
-                from_rust: false,
-            })
+            Some(Function::from_raw(handle))
         }
     }
 
@@ -103,12 +107,7 @@ impl Function {
 
     /// Returns the underlying TVM function handle.
     pub fn handle(&self) -> ffi::TVMFunctionHandle {
-        self.handle
-    }
-
-    /// Returns `true` if the underlying TVM function is global and `false` otherwise.
-    pub fn is_global(&self) -> bool {
-        self.is_global
+        self.inner.handle
     }
 
     /// Calls the function that created from `Builder`.
@@ -122,7 +121,7 @@ impl Function {
 
         let ret_code = unsafe {
             ffi::TVMFuncCall(
-                self.handle,
+                self.handle(),
                 values.as_mut_ptr() as *mut ffi::TVMValue,
                 type_codes.as_mut_ptr() as *mut c_int,
                 num_args as c_int,
@@ -171,25 +170,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,);
 
 impl Clone for Function {
     fn clone(&self) -> Function {
-        Self {
-            handle: self.handle,
-            is_global: self.is_global,
-            from_rust: true,
+        Function {
+            inner: self.inner.clone(),
         }
     }
 }
 
-// impl Drop for Function {
-//     fn drop(&mut self) {
-//         if !self.is_global && !self.is_cloned {
-//             check_call!(ffi::TVMFuncFree(self.handle));
-//         }
-//     }
-// }
-
 impl From<Function> for RetValue {
     fn from(func: Function) -> RetValue {
-        RetValue::FuncHandle(func.handle)
+        RetValue::FuncHandle(func.handle())
     }
 }
 
@@ -198,7 +187,7 @@ impl TryFrom<RetValue> for Function {
 
     fn try_from(ret_value: RetValue) -> Result<Function> {
         match ret_value {
-            RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
+            RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
             _ => Err(Error::downcast(
                 format!("{:?}", ret_value),
                 "FunctionHandle",
@@ -209,10 +198,10 @@ impl TryFrom<RetValue> for Function {
 
 impl<'a> From<Function> for ArgValue<'a> {
     fn from(func: Function) -> ArgValue<'a> {
-        if func.handle.is_null() {
+        if func.handle().is_null() {
             ArgValue::Null
         } else {
-            ArgValue::FuncHandle(func.handle)
+            ArgValue::FuncHandle(func.handle())
         }
     }
 }
@@ -222,7 +211,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
 
     fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
         match arg_value {
-            ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
+            ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
             _ => Err(Error::downcast(
                 format!("{:?}", arg_value),
                 "FunctionHandle",
@@ -236,7 +225,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
 
     fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
         match arg_value {
-            ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
+            ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)),
             _ => Err(Error::downcast(
                 format!("{:?}", arg_value),
                 "FunctionHandle",
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
index 343f0dc..8d59c2a 100644
--- a/rust/tvm-rt/src/module.rs
+++ b/rust/tvm-rt/src/module.rs
@@ -82,7 +82,7 @@ impl Module {
             return Err(errors::Error::NullHandle(name.into_string()?.to_string()));
         }
 
-        Ok(Function::new(fhandle))
+        Ok(Function::from_raw(fhandle))
     }
 
     /// Imports a dependent module such as `.ptx` for cuda gpu.
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index 0e2d283..08dcfe3 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -61,7 +61,7 @@ use num_traits::Num;
 
 use crate::errors::NDArrayError;
 
-use crate::object::{Object, ObjectPtr};
+use crate::object::{Object, ObjectPtr, ObjectRef};
 
 /// See the [`module-level documentation`](../ndarray/index.html) for more details.
 #[repr(C)]
@@ -73,7 +73,7 @@ pub struct NDArrayContainer {
     // Container Base
     dl_tensor: DLTensor,
     manager_ctx: *mut c_void,
-    // TOOD: shape?
+    shape: ObjectRef,
 }
 
 impl NDArrayContainer {
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index 64fd6a2..a093cf5 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -148,18 +148,6 @@ impl Object {
     }
 }
 
-// impl fmt::Debug for Object {
-//     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-//         let index =
-//             format!("{} // key: {}", self.type_index, "the_key");
-
-//         f.debug_struct("Object")
-//          .field("type_index", &index)
-//          // TODO(@jroesch: do we expose other fields?)
-//          .finish()
-//     }
-// }
-
 /// An unsafe trait which should be implemented for an object
 /// subtype.
 ///
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index c5ede7d..7797d2c 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -74,7 +74,7 @@ pub trait ToFunction<I, O>: Sized {
             &mut fhandle as *mut ffi::TVMFunctionHandle,
         ));
 
-        Function::new(fhandle)
+        Function::from_raw(fhandle)
     }
 
     /// The callback function which is wrapped converted by TVM
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 930ee59..7793f9f 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -19,7 +19,10 @@
 
 extern crate bindgen;
 
-use std::{path::{Path, PathBuf}, str::FromStr};
+use std::{
+    path::{Path, PathBuf},
+    str::FromStr,
+};
 
 use anyhow::{Context, Result};
 use tvm_build::{BuildConfig, CMakeSetting};
@@ -195,7 +198,10 @@ fn find_using_tvm_build() -> Result<TVMInstall> {
     if cfg!(feature = "use-vitis-ai") {
         build_config.settings.use_vitis_ai = Some(true);
     }
-    if cfg!(any(feature = "static-linking", feature = "build-static-runtime")) {
+    if cfg!(any(
+        feature = "static-linking",
+        feature = "build-static-runtime"
+    )) {
         build_config.settings.build_static_runtime = Some(true);
     }
 
diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs
index 2e0f5b5..b7c3036 100644
--- a/rust/tvm/tests/basics/src/main.rs
+++ b/rust/tvm/tests/basics/src/main.rs
@@ -35,7 +35,7 @@ fn main() {
     let mut arr = NDArray::empty(shape, dev, dtype);
     arr.copy_from_buffer(data.as_mut_slice());
     let ret = NDArray::empty(shape, dev, dtype);
-    let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
+    let fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
     if !fadd.enabled(dev_name) {
         return;
     }

Reply via email to