This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0210ce2  Fix reference cycle memory leak in gluon.Trainer (#18363)
0210ce2 is described below

commit 0210ce2c136afaa0f57666e5e1c659cab353f5f3
Author: Leonard Lausen <[email protected]>
AuthorDate: Wed May 20 15:29:40 2020 -0700

    Fix reference cycle memory leak in gluon.Trainer (#18363)
---
 conftest.py                                 | 48 +++++++++++++++++++++++++++++
 python/mxnet/gluon/parameter.py             | 26 +++++++++++-----
 tests/python/unittest/test_gluon.py         | 37 ++--------------------
 tests/python/unittest/test_gluon_trainer.py |  2 +-
 4 files changed, 70 insertions(+), 43 deletions(-)

diff --git a/conftest.py b/conftest.py
index caabaf9..7a0aa44 100644
--- a/conftest.py
+++ b/conftest.py
@@ -24,6 +24,7 @@ instantiated before lower-scoped fixtures (such as ``function``).
 """
 
 import logging
+import gc
 import os
 import random
 
@@ -229,3 +230,50 @@ def doctest(doctest_namespace):
         logging.warning('Unable to import numpy/mxnet. Skipping conftest.')
     import doctest
     doctest.ELLIPSIS_MARKER = '-etc-'
+
+
+@pytest.fixture(scope='session')
+def mxnet_module():
+    import mxnet
+    return mxnet
+
+
+@pytest.fixture()
+# @pytest.fixture(autouse=True)  # Fix all the bugs and mark this autouse=True
+def check_leak_ndarray(mxnet_module):
+    # Collect garbage prior to running the next test
+    gc.collect()
+    # Enable gc debug mode to check if the test leaks any arrays
+    gc_flags = gc.get_debug()
+    gc.set_debug(gc.DEBUG_SAVEALL)
+
+    # Run the test
+    yield
+
+    # Check for leaked NDArrays
+    gc.collect()
+    gc.set_debug(gc_flags)  # reset gc flags
+
+    seen = set()
+    def has_array(element):
+        try:
+            if element in seen:
+                return False
+            seen.add(element)
+        except TypeError:  # unhashable
+            pass
+
+        if isinstance(element, mxnet_module.nd._internal.NDArrayBase):
+            return True
+        elif hasattr(element, '__dict__'):
+            return any(has_array(x) for x in vars(element))
+        elif isinstance(element, dict):
+            return any(has_array(x) for x in element.items())
+        else:
+            try:
+                return any(has_array(x) for x in element)
+            except (TypeError, KeyError):
+                return False
+
+    assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
+    del gc.garbage[:]
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 06b6150..9600d83 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -25,6 +25,7 @@ __all__ = ['DeferredInitializationError', 'Parameter', 'Constant',
 
 from collections import OrderedDict, defaultdict
 import warnings
+import weakref
 import numpy as np
 
 from ..base import mx_real_t, MXNetError
@@ -201,12 +202,15 @@ class Parameter(object):
     def _set_trainer(self, trainer):
         """ Set the trainer this parameter is associated with. """
         # trainer cannot be replaced for sparse params
-        if self._stype != 'default' and self._trainer and trainer and self._trainer is not trainer:
+        if self._stype != 'default' and self._trainer and trainer and self._trainer() is not trainer:
             raise RuntimeError(
                 "Failed to set the trainer for Parameter '%s' because it was already set. " \
                 "More than one trainers for a %s Parameter is not supported." \
                 %(self.name, self._stype))
-        self._trainer = trainer
+        if trainer is not None:
+            self._trainer = weakref.ref(trainer)
+        else:
+            self._trainer = trainer
 
     def _check_and_get(self, arr_list, ctx):
         if arr_list is not None:
@@ -245,13 +249,14 @@ class Parameter(object):
         # get row sparse params based on row ids
         if not isinstance(row_id, ndarray.NDArray):
             raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id)))
-        if not self._trainer:
+        trainer = self._trainer() if self._trainer else None
+        if not trainer:
             raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
                                "Trainer is created with it."%self.name)
         results = self._check_and_get(arr_list, ctx)
 
         # fetch row sparse params from the trainer
-        self._trainer._row_sparse_pull(self, results, row_id)
+        trainer._row_sparse_pull(self, results, row_id)
         return results
 
     def _load_init(self, data, ctx, cast_dtype=False, dtype_source='current'):
@@ -397,7 +402,11 @@ class Parameter(object):
             # fetch all rows for 'row_sparse' param
             all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
             data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
-            self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
+            trainer = self._trainer() if self._trainer else None
+            if not trainer:
+                raise RuntimeError("Cannot reduce row_sparse data for Parameter '%s' when no " \
+                                   "Trainer is created with it."%self.name)
+            trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
         return data
 
     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
@@ -503,9 +512,10 @@ class Parameter(object):
             return
 
         # if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
-        if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
-            if self not in self._trainer._params_to_init:
-                self._trainer._reset_kvstore()
+        trainer = self._trainer() if self._trainer else None
+        if trainer and trainer._kv_initialized and trainer._update_on_kvstore:
+            if self not in trainer._params_to_init:
+                trainer._reset_kvstore()
 
         for arr in self._check_and_get(self._data, list):
             arr[:] = data
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 98773b2..da13aff 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -99,6 +99,7 @@ def test_parameter_invalid_access():
     assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)
 
 @with_seed()
+@pytest.mark.usefixtures("check_leak_ndarray")
 def test_parameter_dict():
     ctx = mx.cpu(1)
     params0 = gluon.ParameterDict('net_')
@@ -3226,40 +3227,8 @@ def test_reqs_switching_training_inference():
 
     mx.test_utils.assert_almost_equal(grad1, grad2)
 
-def test_no_memory_leak_in_gluon():
-    # Collect all other garbage prior to this test. Otherwise the test may fail
-    # due to unrelated memory leaks.
-    gc.collect()
 
-    gc_flags = gc.get_debug()
-    gc.set_debug(gc.DEBUG_SAVEALL)
+@pytest.mark.usefixtures("check_leak_ndarray")
+def test_no_memory_leak_in_gluon():
     net = mx.gluon.nn.Dense(10, in_units=10)
     net.initialize()
-    del net
-    gc.collect()
-    gc.set_debug(gc_flags)  # reset gc flags
-
-    # Check for leaked NDArrays
-    seen = set()
-    def has_array(element):
-        try:
-            if element in seen:
-                return False
-            seen.add(element)
-        except TypeError:  # unhashable
-            pass
-
-        if isinstance(element, mx.nd._internal.NDArrayBase):
-            return True
-        elif hasattr(element, '__dict__'):
-            return any(has_array(x) for x in vars(element))
-        elif isinstance(element, dict):
-            return any(has_array(x) for x in element.items())
-        else:
-            try:
-                return any(has_array(x) for x in element)
-            except (TypeError, KeyError):
-                return False
-
-    assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
-    del gc.garbage[:]
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 892b2e3..874ab8c 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -37,7 +37,7 @@ def test_multi_trainer():
     x.initialize()
     # test set trainer
     trainer0 = gluon.Trainer([x], 'sgd')
-    assert(x._trainer is trainer0)
+    assert(x._trainer() is trainer0)
     # test unset trainer
     x._set_trainer(None)
     assert(x._trainer is None)

Reply via email to