This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0210ce2 Fix reference cycle memory leak in gluon.Trainer (#18363)
0210ce2 is described below
commit 0210ce2c136afaa0f57666e5e1c659cab353f5f3
Author: Leonard Lausen <[email protected]>
AuthorDate: Wed May 20 15:29:40 2020 -0700
Fix reference cycle memory leak in gluon.Trainer (#18363)
---
conftest.py | 48 +++++++++++++++++++++++++++++
python/mxnet/gluon/parameter.py | 26 +++++++++++-----
tests/python/unittest/test_gluon.py | 37 ++--------------------
tests/python/unittest/test_gluon_trainer.py | 2 +-
4 files changed, 70 insertions(+), 43 deletions(-)
diff --git a/conftest.py b/conftest.py
index caabaf9..7a0aa44 100644
--- a/conftest.py
+++ b/conftest.py
@@ -24,6 +24,7 @@ instantiated before lower-scoped fixtures (such as ``function``).
"""
import logging
+import gc
import os
import random
@@ -229,3 +230,50 @@ def doctest(doctest_namespace):
logging.warning('Unable to import numpy/mxnet. Skipping conftest.')
import doctest
doctest.ELLIPSIS_MARKER = '-etc-'
+
+
+@pytest.fixture(scope='session')
+def mxnet_module():
+ import mxnet
+ return mxnet
+
+
+@pytest.fixture()
+# @pytest.fixture(autouse=True) # Fix all the bugs and mark this autouse=True
+def check_leak_ndarray(mxnet_module):
+ # Collect garbage prior to running the next test
+ gc.collect()
+ # Enable gc debug mode to check if the test leaks any arrays
+ gc_flags = gc.get_debug()
+ gc.set_debug(gc.DEBUG_SAVEALL)
+
+ # Run the test
+ yield
+
+ # Check for leaked NDArrays
+ gc.collect()
+ gc.set_debug(gc_flags) # reset gc flags
+
+ seen = set()
+ def has_array(element):
+ try:
+ if element in seen:
+ return False
+ seen.add(element)
+ except TypeError: # unhashable
+ pass
+
+ if isinstance(element, mxnet_module.nd._internal.NDArrayBase):
+ return True
+ elif hasattr(element, '__dict__'):
+ return any(has_array(x) for x in vars(element))
+ elif isinstance(element, dict):
+ return any(has_array(x) for x in element.items())
+ else:
+ try:
+ return any(has_array(x) for x in element)
+ except (TypeError, KeyError):
+ return False
+
+ assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
+ del gc.garbage[:]
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 06b6150..9600d83 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -25,6 +25,7 @@ __all__ = ['DeferredInitializationError', 'Parameter', 'Constant',
from collections import OrderedDict, defaultdict
import warnings
+import weakref
import numpy as np
from ..base import mx_real_t, MXNetError
@@ -201,12 +202,15 @@ class Parameter(object):
def _set_trainer(self, trainer):
""" Set the trainer this parameter is associated with. """
# trainer cannot be replaced for sparse params
- if self._stype != 'default' and self._trainer and trainer and self._trainer is not trainer:
+ if self._stype != 'default' and self._trainer and trainer and self._trainer() is not trainer:
raise RuntimeError(
"Failed to set the trainer for Parameter '%s' because it was already set. " \
"More than one trainers for a %s Parameter is not supported." \
%(self.name, self._stype))
- self._trainer = trainer
+ if trainer is not None:
+ self._trainer = weakref.ref(trainer)
+ else:
+ self._trainer = trainer
def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
@@ -245,13 +249,14 @@ class Parameter(object):
# get row sparse params based on row ids
if not isinstance(row_id, ndarray.NDArray):
raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id)))
- if not self._trainer:
+ trainer = self._trainer() if self._trainer else None
+ if not trainer:
raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
"Trainer is created with it."%self.name)
results = self._check_and_get(arr_list, ctx)
# fetch row sparse params from the trainer
- self._trainer._row_sparse_pull(self, results, row_id)
+ trainer._row_sparse_pull(self, results, row_id)
return results
def _load_init(self, data, ctx, cast_dtype=False, dtype_source='current'):
@@ -397,7 +402,11 @@ class Parameter(object):
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
- self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
+ trainer = self._trainer() if self._trainer else None
+ if not trainer:
+ raise RuntimeError("Cannot reduce row_sparse data for Parameter '%s' when no " \
+ "Trainer is created with it."%self.name)
+ trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
return data
def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
@@ -503,9 +512,10 @@ class Parameter(object):
return
# if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
- if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
- if self not in self._trainer._params_to_init:
- self._trainer._reset_kvstore()
+ trainer = self._trainer() if self._trainer else None
+ if trainer and trainer._kv_initialized and trainer._update_on_kvstore:
+ if self not in trainer._params_to_init:
+ trainer._reset_kvstore()
for arr in self._check_and_get(self._data, list):
arr[:] = data
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 98773b2..da13aff 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -99,6 +99,7 @@ def test_parameter_invalid_access():
assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)
@with_seed()
+@pytest.mark.usefixtures("check_leak_ndarray")
def test_parameter_dict():
ctx = mx.cpu(1)
params0 = gluon.ParameterDict('net_')
@@ -3226,40 +3227,8 @@ def test_reqs_switching_training_inference():
mx.test_utils.assert_almost_equal(grad1, grad2)
-def test_no_memory_leak_in_gluon():
- # Collect all other garbage prior to this test. Otherwise the test may fail
- # due to unrelated memory leaks.
- gc.collect()
- gc_flags = gc.get_debug()
- gc.set_debug(gc.DEBUG_SAVEALL)
+@pytest.mark.usefixtures("check_leak_ndarray")
+def test_no_memory_leak_in_gluon():
net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize()
- del net
- gc.collect()
- gc.set_debug(gc_flags) # reset gc flags
-
- # Check for leaked NDArrays
- seen = set()
- def has_array(element):
- try:
- if element in seen:
- return False
- seen.add(element)
- except TypeError: # unhashable
- pass
-
- if isinstance(element, mx.nd._internal.NDArrayBase):
- return True
- elif hasattr(element, '__dict__'):
- return any(has_array(x) for x in vars(element))
- elif isinstance(element, dict):
- return any(has_array(x) for x in element.items())
- else:
- try:
- return any(has_array(x) for x in element)
- except (TypeError, KeyError):
- return False
-
- assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
- del gc.garbage[:]
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 892b2e3..874ab8c 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -37,7 +37,7 @@ def test_multi_trainer():
x.initialize()
# test set trainer
trainer0 = gluon.Trainer([x], 'sgd')
- assert(x._trainer is trainer0)
+ assert(x._trainer() is trainer0)
# test unset trainer
x._set_trainer(None)
assert(x._trainer is None)