This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 09b989d [Rust][Fix] Memory leak (#8714)
09b989d is described below
commit 09b989deb77cfb40f468c2566d1f40227af44bf7
Author: Jared Roesch <[email protected]>
AuthorDate: Wed Aug 11 05:57:52 2021 -0700
[Rust][Fix] Memory leak (#8714)
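* Fix memory leak in function.rs: `Function` never freed its TVM handle (its `Drop` impl was commented out); it now wraps the handle in an `Arc<FunctionPtr>` whose `Drop` calls `TVMFuncFree` when the last clone is dropped
* Update object pointer: remove a commented-out `Debug` impl for `Object` in object_ptr.rs; also add a `shape: ObjectRef` field to `NDArrayContainer` in ndarray.rs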
---
rust/tvm-rt/src/function.rs | 89 ++++++++++++++++--------------------
rust/tvm-rt/src/module.rs | 2 +-
rust/tvm-rt/src/ndarray.rs | 4 +-
rust/tvm-rt/src/object/object_ptr.rs | 12 -----
rust/tvm-rt/src/to_function.rs | 2 +-
rust/tvm-sys/build.rs | 10 +++-
rust/tvm/tests/basics/src/main.rs | 2 +-
7 files changed, 52 insertions(+), 69 deletions(-)
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index aec4a8a..5db665c 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -26,6 +26,7 @@
//! See the tests and examples repository for more examples.
use std::convert::{TryFrom, TryInto};
+use std::sync::Arc;
use std::{
ffi::CString,
os::raw::{c_char, c_int},
@@ -39,36 +40,43 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
pub type Result<T> = std::result::Result<T, Error>;
-/// Wrapper around TVM function handle which includes `is_global`
-/// indicating whether the function is global or not, and `is_cloned` showing
-/// not to drop a cloned function from Rust side.
-/// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)]
-pub struct Function {
- pub(crate) handle: ffi::TVMFunctionHandle,
- // whether the registered function is global or not.
- is_global: bool,
- from_rust: bool,
+struct FunctionPtr {
+ handle: ffi::TVMFunctionHandle,
}
-unsafe impl Send for Function {}
-unsafe impl Sync for Function {}
+// NB(@jroesch): I think this is ok, need to double check,
+// if not we should mutex the pointer or move to Rc.
+unsafe impl Send for FunctionPtr {}
+unsafe impl Sync for FunctionPtr {}
+
+impl FunctionPtr {
+ fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
+ FunctionPtr { handle }
+ }
+}
+
+impl Drop for FunctionPtr {
+ fn drop(&mut self) {
+ check_call!(ffi::TVMFuncFree(self.handle));
+ }
+}
+
+/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust.
+#[derive(Debug, Hash)]
+pub struct Function {
+ inner: Arc<FunctionPtr>,
+}
impl Function {
- pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
+ pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
Function {
- handle,
- is_global: false,
- from_rust: false,
+ inner: Arc::new(FunctionPtr::from_raw(handle)),
}
}
pub unsafe fn null() -> Self {
- Function {
- handle: std::ptr::null_mut(),
- is_global: false,
- from_rust: false,
- }
+ Function::from_raw(std::ptr::null_mut())
}
/// For a given function, it returns a function by name.
@@ -84,11 +92,7 @@ impl Function {
if handle.is_null() {
None
} else {
- Some(Function {
- handle,
- is_global: true,
- from_rust: false,
- })
+ Some(Function::from_raw(handle))
}
}
@@ -103,12 +107,7 @@ impl Function {
/// Returns the underlying TVM function handle.
pub fn handle(&self) -> ffi::TVMFunctionHandle {
- self.handle
- }
-
- /// Returns `true` if the underlying TVM function is global and `false` otherwise.
- pub fn is_global(&self) -> bool {
- self.is_global
+ self.inner.handle
}
/// Calls the function that created from `Builder`.
@@ -122,7 +121,7 @@ impl Function {
let ret_code = unsafe {
ffi::TVMFuncCall(
- self.handle,
+ self.handle(),
values.as_mut_ptr() as *mut ffi::TVMValue,
type_codes.as_mut_ptr() as *mut c_int,
num_args as c_int,
@@ -171,25 +170,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,);
impl Clone for Function {
fn clone(&self) -> Function {
- Self {
- handle: self.handle,
- is_global: self.is_global,
- from_rust: true,
+ Function {
+ inner: self.inner.clone(),
}
}
}
-// impl Drop for Function {
-// fn drop(&mut self) {
-// if !self.is_global && !self.is_cloned {
-// check_call!(ffi::TVMFuncFree(self.handle));
-// }
-// }
-// }
-
impl From<Function> for RetValue {
fn from(func: Function) -> RetValue {
- RetValue::FuncHandle(func.handle)
+ RetValue::FuncHandle(func.handle())
}
}
@@ -198,7 +187,7 @@ impl TryFrom<RetValue> for Function {
fn try_from(ret_value: RetValue) -> Result<Function> {
match ret_value {
- RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
+ RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
_ => Err(Error::downcast(
format!("{:?}", ret_value),
"FunctionHandle",
@@ -209,10 +198,10 @@ impl TryFrom<RetValue> for Function {
impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
- if func.handle.is_null() {
+ if func.handle().is_null() {
ArgValue::Null
} else {
- ArgValue::FuncHandle(func.handle)
+ ArgValue::FuncHandle(func.handle())
}
}
}
@@ -222,7 +211,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
match arg_value {
- ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
+ ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
@@ -236,7 +225,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
match arg_value {
- ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
+ ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
index 343f0dc..8d59c2a 100644
--- a/rust/tvm-rt/src/module.rs
+++ b/rust/tvm-rt/src/module.rs
@@ -82,7 +82,7 @@ impl Module {
return
Err(errors::Error::NullHandle(name.into_string()?.to_string()));
}
- Ok(Function::new(fhandle))
+ Ok(Function::from_raw(fhandle))
}
/// Imports a dependent module such as `.ptx` for cuda gpu.
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index 0e2d283..08dcfe3 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -61,7 +61,7 @@ use num_traits::Num;
use crate::errors::NDArrayError;
-use crate::object::{Object, ObjectPtr};
+use crate::object::{Object, ObjectPtr, ObjectRef};
/// See the [`module-level documentation`](../ndarray/index.html) for more details.
#[repr(C)]
@@ -73,7 +73,7 @@ pub struct NDArrayContainer {
// Container Base
dl_tensor: DLTensor,
manager_ctx: *mut c_void,
- // TOOD: shape?
+ shape: ObjectRef,
}
impl NDArrayContainer {
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index 64fd6a2..a093cf5 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -148,18 +148,6 @@ impl Object {
}
}
-// impl fmt::Debug for Object {
-// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-// let index =
-// format!("{} // key: {}", self.type_index, "the_key");
-
-// f.debug_struct("Object")
-// .field("type_index", &index)
-// // TODO(@jroesch: do we expose other fields?)
-// .finish()
-// }
-// }
-
/// An unsafe trait which should be implemented for an object
/// subtype.
///
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index c5ede7d..7797d2c 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -74,7 +74,7 @@ pub trait ToFunction<I, O>: Sized {
&mut fhandle as *mut ffi::TVMFunctionHandle,
));
- Function::new(fhandle)
+ Function::from_raw(fhandle)
}
/// The callback function which is wrapped converted by TVM
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 930ee59..7793f9f 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -19,7 +19,10 @@
extern crate bindgen;
-use std::{path::{Path, PathBuf}, str::FromStr};
+use std::{
+ path::{Path, PathBuf},
+ str::FromStr,
+};
use anyhow::{Context, Result};
use tvm_build::{BuildConfig, CMakeSetting};
@@ -195,7 +198,10 @@ fn find_using_tvm_build() -> Result<TVMInstall> {
if cfg!(feature = "use-vitis-ai") {
build_config.settings.use_vitis_ai = Some(true);
}
- if cfg!(any(feature = "static-linking", feature = "build-static-runtime"))
{
+ if cfg!(any(
+ feature = "static-linking",
+ feature = "build-static-runtime"
+ )) {
build_config.settings.build_static_runtime = Some(true);
}
diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs
index 2e0f5b5..b7c3036 100644
--- a/rust/tvm/tests/basics/src/main.rs
+++ b/rust/tvm/tests/basics/src/main.rs
@@ -35,7 +35,7 @@ fn main() {
let mut arr = NDArray::empty(shape, dev, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let ret = NDArray::empty(shape, dev, dtype);
- let mut fadd = Module::load(&concat!(env!("OUT_DIR"),
"/test_add.so")).unwrap();
+ let fadd = Module::load(&concat!(env!("OUT_DIR"),
"/test_add.so")).unwrap();
if !fadd.enabled(dev_name) {
return;
}