This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 5079c35  [v1.x] Backport Improve environment variable handling in unittests (#18424) (#19173)
5079c35 is described below

commit 5079c35141610b725fda5ce2186de18151166743
Author: Dick Carter <[email protected]>
AuthorDate: Thu Sep 17 21:57:59 2020 -0700

    [v1.x] Backport Improve environment variable handling in unittests (#18424) (#19173)
    
    * Improve environment variable handling in unittests (#18424)
    
    * Add missing python functools import
    
    * Correct teardown import
---
 include/mxnet/c_api_test.h                  |  16 +++
 python/mxnet/test_utils.py                  | 113 +++++++++++------
 python/mxnet/util.py                        |  34 ++++-
 src/c_api/c_api_test.cc                     |  22 ++++
 tests/python/gpu/test_device.py             |  16 ++-
 tests/python/gpu/test_fusion.py             |  13 +-
 tests/python/gpu/test_gluon_gpu.py          |  18 +--
 tests/python/gpu/test_kvstore_gpu.py        |  10 +-
 tests/python/gpu/test_operator_gpu.py       |  42 ++++---
 tests/python/unittest/common.py             |  35 ++++--
 tests/python/unittest/test_autograd.py      |   6 +-
 tests/python/unittest/test_base.py          | 100 +++++++++++----
 tests/python/unittest/test_engine.py        |   4 +-
 tests/python/unittest/test_engine_import.py |  14 +--
 tests/python/unittest/test_executor.py      |  66 +++++-----
 tests/python/unittest/test_gluon.py         |  23 ++--
 tests/python/unittest/test_operator.py      | 184 +++++++++++++---------------
 tests/python/unittest/test_subgraph_op.py   |  28 ++---
 tests/python/unittest/test_symbol.py        |  68 ++++------
 19 files changed, 476 insertions(+), 336 deletions(-)

diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h
index b7ba0ce..df70798 100644
--- a/include/mxnet/c_api_test.h
+++ b/include/mxnet/c_api_test.h
@@ -75,6 +75,22 @@ MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
 MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name);
 
 
+/*!
+ * \brief Get the value of an environment variable as seen by the backend.
+ * \param name The name of the environment variable
+ * \param value Receives the variable's value from getenv(), or NULL if unset
+ */
+MXNET_DLL int MXGetEnv(const char* name,
+                       const char** value);
+
+/*!
+ * \brief Set the value of an environment variable from the backend.
+ * \param name The name of the environment variable
+ * \param value The desired value for `name`, or NULL to unset the variable
+ */
+MXNET_DLL int MXSetEnv(const char* name,
+                       const char* value);
+
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 3e06860..927d857 100755
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -24,6 +24,7 @@ import traceback
 import numbers
 import sys
 import os
+import platform
 import errno
 import logging
 import bz2
@@ -49,7 +50,7 @@ from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from .ndarray import array
 from .symbol import Symbol
 from .symbol.numpy import _Symbol as np_symbol
-from .util import use_np  # pylint: disable=unused-import
+from .util import use_np, getenv, setenv  # pylint: disable=unused-import
 from .runtime import Features
 from .numpy_extension import get_cuda_compute_capability
 
@@ -2035,27 +2036,6 @@ def get_bz2_data(data_dir, data_name, url, 
data_origin_name):
             bz_file.close()
         os.remove(data_origin_name)
 
-def set_env_var(key, val, default_val=""):
-    """Set environment variable
-
-    Parameters
-    ----------
-
-    key : str
-        Env var to set
-    val : str
-        New value assigned to the env var
-    default_val : str, optional
-        Default value returned if the env var doesn't exist
-
-    Returns
-    -------
-    str
-        The value of env var before it is set to the new value
-    """
-    prev_val = os.environ.get(key, default_val)
-    os.environ[key] = val
-    return prev_val
 
 def same_array(array1, array2):
     """Check whether two NDArrays sharing the same memory block
@@ -2080,9 +2060,11 @@ def same_array(array1, array2):
     array1[:] -= 1
     return same(array1.asnumpy(), array2.asnumpy())
 
+
 @contextmanager
 def discard_stderr():
-    """Discards error output of a routine if invoked as:
+    """
+    Discards error output of a routine if invoked as:
 
     with discard_stderr():
         ...
@@ -2471,22 +2453,79 @@ def same_symbol_structure(sym1, sym2):
     return True
 
 
-class EnvManager(object):
-    """Environment variable setter and unsetter via with idiom"""
-    def __init__(self, key, val):
-        self._key = key
-        self._next_val = val
-        self._prev_val = None
+@contextmanager
+def environment(*args):
+    """
+    Environment variable setter and unsetter via `with` idiom.
 
-    def __enter__(self):
-        self._prev_val = os.environ.get(self._key)
-        os.environ[self._key] = self._next_val
+    Takes a specification of env var names and desired values and adds those
+    settings to the environment in advance of running the body of the `with`
+    statement.  The original environment state is restored afterwards, even
+    if exceptions are raised in the `with` body.
 
-    def __exit__(self, ptype, value, trace):
-        if self._prev_val:
-            os.environ[self._key] = self._prev_val
-        else:
-            del os.environ[self._key]
+    Parameters
+    ----------
+    args:
+        if 2 args are passed:
+            name, desired_value of the single env var to update, or
+        if 1 arg is passed:
+            a dict of name:desired_value for the env vars to update.
+        A desired_value of None unsets the variable.
+    """
+
+    # On Linux, env var changes made through python's os.environ are seen
+    # by the backend.  On Windows though, the C runtime gets a snapshot
+    # of the environment that cannot be altered by os.environ.  Here we
+    # check, using a wrapped version of the backend's getenv(), that
+    # the desired env var value is seen by the backend, and otherwise use
+    # a wrapped setenv() to establish that value in the backend.
+
+    # Also on Windows, a set env var can never have the value '', since
+    # the command 'set FOO= ' is used to unset the variable.  As a result,
+    # the wrapped backend getenv() returns None both for unset variables
+    # and for those set to '' through os.environ, so we ignore that
+    # discrepancy.
+    def validate_backend_setting(name, value, can_use_setenv=True):
+        backend_value = getenv(name)
+        if value == backend_value or \
+           (value == '' and backend_value is None and platform.system() == 'Windows'):
+            return
+        if not can_use_setenv:
+            raise RuntimeError('Could not set env var {}={} within C Runtime'.format(name, value))
+        setenv(name, value)
+        validate_backend_setting(name, value, can_use_setenv=False)
+
+    # Core routine to alter environment from a dict of env_var_name, env_var_value pairs
+    def set_environ(env_var_dict):
+        for env_var_name, env_var_value in env_var_dict.items():
+            if env_var_value is None:
+                os.environ.pop(env_var_name, None)
+            else:
+                os.environ[env_var_name] = env_var_value
+            validate_backend_setting(env_var_name, env_var_value)
+
+    # Create env_var name:value dict from the two calling methods of this routine
+    if len(args) == 1 and isinstance(args[0], dict):
+        env_vars = args[0]
+    else:
+        assert len(args) == 2, 'Expecting one dict arg or two args: env var name and value'
+        env_vars = {args[0]: args[1]}
+
+    # Take a snapshot of the existing environment variable state
+    # for those variables to be changed.  get() returns None for unset keys.
+    snapshot = {x: os.environ.get(x) for x in env_vars.keys()}
+
+    # Alter the environment per the env_vars dict and run the wrapped code.
+    # set_environ() is called inside the try so that a failure part-way
+    # through it still rolls back any env vars that were already changed.
+    try:
+        set_environ(env_vars)
+        yield
+    finally:
+        # the backend engines may still be referencing the changed env var state
+        mx.nd.waitall()
+        # reinstate original env_var state per the snapshot taken earlier
+        set_environ(snapshot)
 
 
 def collapse_sum_like(a, shape):
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index 54beeb5..aabd5fe 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -21,7 +21,7 @@ import functools
 import inspect
 import threading
 
-from .base import _LIB, check_call
+from .base import _LIB, check_call, c_str, py_str
 
 
 _np_ufunc_default_kwargs = {
@@ -816,3 +816,35 @@ def get_cuda_compute_capability(ctx):
         raise RuntimeError('cuDeviceComputeCapability failed with error code 
{}: {}'
                            .format(ret, error_str.value.decode()))
     return cc_major.value * 10 + cc_minor.value
+
+
+def getenv(name):
+    """Get the value of an environment variable as seen by the C Runtime.
+
+    Parameters
+    ----------
+    name : str
+        The environment variable name
+
+    Returns
+    -------
+    value : str or None
+        The value of the environment variable, or None if not set
+    """
+    ret = ctypes.c_char_p()
+    check_call(_LIB.MXGetEnv(c_str(name), ctypes.byref(ret)))
+    return None if ret.value is None else py_str(ret.value)
+
+
+def setenv(name, value):
+    """Set (or, with value=None, unset) an environment variable in the C Runtime.
+
+    Parameters
+    ----------
+    name : str
+        The environment variable name
+    value : str or None
+        The desired value of the environment variable, or None to unset it
+    """
+    passed_value = None if value is None else c_str(value)
+    check_call(_LIB.MXSetEnv(c_str(name), passed_value))
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
index de4fb7d..e84b0c0 100644
--- a/src/c_api/c_api_test.cc
+++ b/src/c_api/c_api_test.cc
@@ -106,3 +106,25 @@ int MXRemoveSubgraphPropertyOpNamesV2(const char* 
prop_name) {
   }
   API_END();
 }
+
+int MXGetEnv(const char* name,
+             const char** value) {
+  API_BEGIN();
+  *value = getenv(name);  // nullptr if `name` is not set in the C runtime environment
+  API_END();
+}
+
+int MXSetEnv(const char* name,
+             const char* value) {
+  API_BEGIN();
+#ifdef _WIN32
+  // An empty value string removes `name` from the environment on Windows.
+  auto value_arg = (value == nullptr) ? "" : value;
+  const int ret = _putenv_s(name, value_arg);
+#else
+  const int ret = (value == nullptr) ? unsetenv(name) : setenv(name, value, 1);
+#endif
+  CHECK_EQ(ret, 0) << "Could not update env var " << name
+                   << " in the C runtime (error code " << ret << ")";
+  API_END();
+}
diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py
index cd8145c..8a6fb3a 100644
--- a/tests/python/gpu/test_device.py
+++ b/tests/python/gpu/test_device.py
@@ -20,8 +20,7 @@ import numpy as np
 import unittest
 import os
 import logging
-
-from mxnet.test_utils import EnvManager
+from mxnet.test_utils import environment
 
 shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
 keys = [1,2,3,4,5,6,7]
@@ -51,16 +50,15 @@ def test_device_pushpull():
                 for x in range(n_gpus):
                     assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)
 
-    kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
-    kvstore_usetree_values = ['','1']
-    kvstore_usetree  = 'MXNET_KVSTORE_USETREE'
-    for _ in range(2):
+    kvstore_tree_array_bound_values = [None, '1']
+    kvstore_usetree_values = [None, '1']
+    for y in kvstore_tree_array_bound_values:
         for x in kvstore_usetree_values:
-            with EnvManager(kvstore_usetree, x):
+            with environment({'MXNET_KVSTORE_USETREE': x,
+                              'MXNET_KVSTORE_TREE_ARRAY_BOUND': y}):
                 check_dense_pushpull('local')
                 check_dense_pushpull('device')
-        os.environ[kvstore_tree_array_bound] = '1'
-    del os.environ[kvstore_tree_array_bound]
+
 
 if __name__ == '__main__':
     test_device_pushpull()
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 61fba10..1bbf598 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -15,15 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import sys
 import os
 import random
+import itertools
 import mxnet as mx
 import numpy as np
 from mxnet.test_utils import *
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import with_seed
+from common import setup_module, teardown, with_seed
 
 def check_fused_symbol(sym, **kwargs):
     inputs = sym.list_inputs()
@@ -43,10 +45,10 @@ def check_fused_symbol(sym, **kwargs):
         data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
         for grad_req in ['write', 'add']:
             type_dict = {inp : dtype for inp in inputs}
-            os.environ["MXNET_USE_FUSION"] = "0"
-            orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, 
type_dict=type_dict, **shapes)
-            os.environ["MXNET_USE_FUSION"] = "1"
-            fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, 
type_dict=type_dict, **shapes)
+            with environment('MXNET_USE_FUSION', '0'):
+                orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, 
type_dict=type_dict, **shapes)
+            with environment('MXNET_USE_FUSION', '1'):
+                fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, 
type_dict=type_dict, **shapes)
             fwd_orig = orig_exec.forward(is_train=True, **data)
             out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
             orig_exec.backward(out_grads=out_grads)
@@ -227,6 +229,7 @@ def check_other_ops():
     arr2 = mx.random.uniform(shape=(2,2,2,3))
     check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], 
rhs_axes=[0]), a=arr1, b=arr2)
 
+
 def check_leakyrelu_ops():
     a = mx.sym.Variable('a')
     b = mx.sym.Variable('b')
diff --git a/tests/python/gpu/test_gluon_gpu.py 
b/tests/python/gpu/test_gluon_gpu.py
index 52280bf..60a90c9 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -22,7 +22,7 @@ import tempfile
 import time
 import mxnet as mx
 import multiprocessing as mp
-from mxnet.test_utils import check_consistency, set_default_context, 
assert_almost_equal, rand_ndarray
+from mxnet.test_utils import check_consistency, set_default_context, 
assert_almost_equal, rand_ndarray, environment
 import mxnet.ndarray as nd
 import numpy as np
 import math
@@ -555,9 +555,9 @@ def _test_bulking(test_bulking_func):
         time_per_iteration = mp.Manager().Value('d', 0.0)
 
         if not run_in_spawned_process(test_bulking_func,
-                                      
{'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0],
-                                       
'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1],
-                                       'MXNET_EXEC_BULK_EXEC_TRAIN': 
seg_sizes[2]},
+                                      
{'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': str(seg_sizes[0]),
+                                       
'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': str(seg_sizes[1]),
+                                       'MXNET_EXEC_BULK_EXEC_TRAIN': 
str(seg_sizes[2])},
                                       time_per_iteration):
             # skip test since the python version can't run it properly.  
Warning msg was logged.
             return
@@ -631,15 +631,17 @@ def test_gemms_true_fp16():
     net.cast('float16')
     net.initialize(ctx=ctx)
     net.weight.set_data(weights)
-    ref_results = net(input)
 
-    os.environ["MXNET_FC_TRUE_FP16"] = "1"
-    results_trueFP16 = net(input)
+    with environment('MXNET_FC_TRUE_FP16', '0'):
+        ref_results = net(input)
+
+    with environment('MXNET_FC_TRUE_FP16', '1'):
+        results_trueFP16 = net(input)
+
     atol = 1e-2
     rtol = 1e-2
     assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
                         atol=atol, rtol=rtol)
-    os.environ["MXNET_FC_TRUE_FP16"] = "0"
 
 
 if __name__ == '__main__':
diff --git a/tests/python/gpu/test_kvstore_gpu.py 
b/tests/python/gpu/test_kvstore_gpu.py
index 1dddc58..8473dd3 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -21,7 +21,7 @@ import os
 import mxnet as mx
 import numpy as np
 import unittest
-from mxnet.test_utils import assert_almost_equal, default_context, EnvManager
+from mxnet.test_utils import assert_almost_equal, default_context, environment
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown
@@ -97,11 +97,11 @@ def test_rsp_push_pull():
         check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True)
         check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True)
 
-    envs = ["","1"]
-    key  = "MXNET_KVSTORE_USETREE"
+    envs = [None, '1']
+    key  = 'MXNET_KVSTORE_USETREE'
     for val in envs:
-        with EnvManager(key, val):
-            if val is "1":
+        with environment(key, val):
+            if val == '1':
                 sparse_pull = False
             else:
                 sparse_pull = True
diff --git a/tests/python/gpu/test_operator_gpu.py 
b/tests/python/gpu/test_operator_gpu.py
index bcf906a..5fee473 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -28,6 +28,8 @@ import scipy.sparse as sps
 import mxnet.ndarray.sparse as mxsps
 import itertools
 from mxnet.test_utils import check_consistency, set_default_context, 
assert_almost_equal, assert_allclose
+from mxnet.test_utils import check_symbolic_forward, check_symbolic_backward, 
discard_stderr
+from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, 
same, environment
 from mxnet.base import MXNetError
 from mxnet import autograd
 
@@ -755,12 +757,12 @@ def _conv_with_num_streams(seed):
 @unittest.skip("skipping for now due to severe flakiness")
 @with_seed()
 def test_convolution_multiple_streams():
-    for num_streams in [1, 2]:
+    for num_streams in ['1', '2']:
         for engine in ['NaiveEngine', 'ThreadedEngine', 
'ThreadedEnginePerDevice']:
-            print("Starting engine %s with %d streams." % (engine, 
num_streams), file=sys.stderr)
+            print('Starting engine {} with {} streams.'.format(engine, 
num_streams), file=sys.stderr)
             run_in_spawned_process(_conv_with_num_streams,
                 {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 
'MXNET_ENGINE_TYPE' : engine})
-            print("Finished engine %s with %d streams." % (engine, 
num_streams), file=sys.stderr)
+            print('Finished engine {} with {} streams.'.format(engine, 
num_streams), file=sys.stderr)
 
 
 # This test is designed to expose an issue with cudnn v7.1.4 algo find() when 
invoked with large c.
@@ -2229,22 +2231,22 @@ def test_multi_proposal_op():
 
 # The following 2 functions launch 0-thread kernels, an error that should be 
caught and signaled.
 def kernel_error_check_imperative():
-    os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
-    with mx.np_shape(active=True):
-        a = mx.nd.array([1,2,3],ctx=mx.gpu(0))
-        b = mx.nd.array([],ctx=mx.gpu(0))
-        c = (a / b).asnumpy()
+    with environment('MXNET_ENGINE_TYPE', 'NaiveEngine'):
+        with mx.np_shape(active=True):
+            a = mx.nd.array([1,2,3],ctx=mx.gpu(0))
+            b = mx.nd.array([],ctx=mx.gpu(0))
+            c = (a / b).asnumpy()
 
 def kernel_error_check_symbolic():
-    os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
-    with mx.np_shape(active=True):
-        a = mx.sym.Variable('a')
-        b = mx.sym.Variable('b')
-        c = a / b
-        f = c.bind(mx.gpu(0), { 'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)),
-                                'b':mx.nd.array([],ctx=mx.gpu(0))})
-        f.forward()
-        g = f.outputs[0].asnumpy()
+    with environment('MXNET_ENGINE_TYPE', 'NaiveEngine'):
+        with mx.np_shape(active=True):
+            a = mx.sym.Variable('a')
+            b = mx.sym.Variable('b')
+            c = a / b
+            f = c.bind(mx.gpu(0), {'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)),
+                                   'b':mx.nd.array([],ctx=mx.gpu(0))})
+            f.forward()
+            g = f.outputs[0].asnumpy()
 
 def test_kernel_error_checking():
     # Running tests that may throw exceptions out of worker threads will stop 
CI testing
@@ -2440,9 +2442,9 @@ def test_bulking():
         # Create shared variable to return measured time from test process
         time_per_iteration = mp.Manager().Value('d', 0.0)
         if not run_in_spawned_process(_test_bulking_in_process,
-                                      
{'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0],
-                                       
'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1],
-                                       'MXNET_EXEC_BULK_EXEC_TRAIN' : 
seg_sizes[2]},
+                                      
{'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : str(seg_sizes[0]),
+                                       
'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : str(seg_sizes[1]),
+                                       'MXNET_EXEC_BULK_EXEC_TRAIN' : 
str(seg_sizes[2])},
                                       time_per_iteration):
             # skip test since the python version can't run it properly.  
Warning msg was logged.
             return
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index 8e4e2e3..cbddf0a 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -16,13 +16,14 @@
 # under the License.
 
 from __future__ import print_function
-import sys, os, logging
+import sys, os, logging, functools
 import multiprocessing as mp
 import mxnet as mx
 import numpy as np
 import random
 import shutil
 from mxnet.base import MXNetError
+from mxnet.test_utils import environment
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.append(os.path.join(curr_path, '../common/'))
 sys.path.insert(0, os.path.join(curr_path, '../../../python'))
@@ -208,15 +209,17 @@ def with_seed(seed=None):
                 logger = default_logger()
                 # 'nosetests --logging-level=DEBUG' shows this msg even with 
an ensuing core dump.
                 test_count_msg = '{} of {}: '.format(i+1,test_count) if 
test_count > 1 else ''
-                test_msg = ('{}Setting test np/mx/python random seeds, use 
MXNET_TEST_SEED={}'
-                            ' to reproduce.').format(test_count_msg, 
this_test_seed)
-                logger.log(log_level, test_msg)
+                pre_test_msg = ('{}Setting test np/mx/python random seeds, use 
MXNET_TEST_SEED={}'
+                                ' to reproduce.').format(test_count_msg, 
this_test_seed)
+                on_err_test_msg = ('{}Error seen with seeded test, use 
MXNET_TEST_SEED={}'
+                                ' to reproduce.').format(test_count_msg, 
this_test_seed)
+                logger.log(log_level, pre_test_msg)
                 try:
                     orig_test(*args, **kwargs)
                 except:
                     # With exceptions, repeat test_msg at WARNING level to be 
sure it's seen.
                     if log_level < logging.WARNING:
-                        logger.warning(test_msg)
+                        logger.warning(on_err_test_msg)
                     raise
                 finally:
                     # Provide test-isolation for any test having this decorator
@@ -329,6 +332,20 @@ def with_post_test_cleanup():
             finally:
                 mx.nd.waitall()
                 mx.cpu().empty_cache()
+
+
+def with_environment(*args_):
+    """
+    Decorator factory that runs the decorated test inside `environment(*args_)`.
+    The args are either an env var name and value, or a dict of name:value
+    settings (a value of None unsets the variable).  The original environment
+    state is reinstated afterwards, even if exceptions are raised.
+    """
+    def test_helper(orig_test):
+        @functools.wraps(orig_test)
+        def test_new(*args, **kwargs):
+            with environment(*args_):
+                return orig_test(*args, **kwargs)
         return test_new
     return test_helper
 
@@ -363,16 +380,10 @@ def run_in_spawned_process(func, env, *args):
         return False
     else:
         seed = np.random.randint(0,1024*1024*1024)
-        orig_environ = os.environ.copy()
-        try:
-            for (key, value) in env.items():
-                os.environ[key] = str(value)
+        with environment(env):
             # Prepend seed as first arg
             p = mpctx.Process(target=func, args=(seed,)+args)
             p.start()
             p.join()
             assert p.exitcode == 0, "Non-zero exit code %d from %s()." % 
(p.exitcode, func.__name__)
-        finally:
-            os.environ.clear()
-            os.environ.update(orig_environ)
     return True
diff --git a/tests/python/unittest/test_autograd.py 
b/tests/python/unittest/test_autograd.py
index 61955f0..caff307 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -21,7 +21,7 @@ from mxnet.ndarray import zeros_like
 from mxnet.autograd import *
 from mxnet.test_utils import *
 from common import setup_module, with_seed, teardown
-from mxnet.test_utils import EnvManager
+from mxnet.test_utils import environment
 
 
 def grad_and_loss(func, argnum=None):
@@ -121,7 +121,7 @@ def test_unary_func():
         autograd_assert(x, func=f_square, grad_func=f_square_grad)
     uniform = nd.uniform(shape=(4, 5))
     stypes = ['default', 'row_sparse', 'csr']
-    with EnvManager('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
+    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
         for stype in stypes:
             check_unary_func(uniform.tostype(stype))
 
@@ -140,7 +140,7 @@ def test_binary_func():
     uniform_x = nd.uniform(shape=(4, 5))
     uniform_y = nd.uniform(shape=(4, 5))
     stypes = ['default', 'row_sparse', 'csr']
-    with EnvManager('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
+    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
         for stype_x in stypes:
             for stype_y in stypes:
                 x = uniform_x.tostype(stype_x)
diff --git a/tests/python/unittest/test_base.py 
b/tests/python/unittest/test_base.py
index 3189729..5744dcc 100644
--- a/tests/python/unittest/test_base.py
+++ b/tests/python/unittest/test_base.py
@@ -16,35 +16,93 @@
 # under the License.
 
 import mxnet as mx
+from numpy.testing import assert_equal
 from mxnet.base import data_dir
 from nose.tools import *
+from mxnet.test_utils import environment
+from mxnet.util import getenv
+from common import setup_module, teardown, with_environment
 import os
-import unittest
 import logging
 import os.path as op
 import platform
 
-class MXNetDataDirTest(unittest.TestCase):
-    def setUp(self):
-        self.mxnet_data_dir = os.environ.get('MXNET_HOME')
-        if 'MXNET_HOME' in os.environ:
-            del os.environ['MXNET_HOME']
+def test_environment():
+    name1 = 'MXNET_TEST_ENV_VAR_1'
+    name2 = 'MXNET_TEST_ENV_VAR_2'
 
-    def tearDown(self):
-        if self.mxnet_data_dir:
-            os.environ['MXNET_HOME'] = self.mxnet_data_dir
-        else:
-            if 'MXNET_HOME' in os.environ:
-                del os.environ['MXNET_HOME']
+    # Test that a variable is set in both the python (os.environ) and backend environments
+    with environment(name1, '42'):
+        assert_equal(os.environ.get(name1), '42')
+        assert_equal(getenv(name1), '42')
+
+    # Test dict form of invocation (several variables set at once)
+    env_var_dict = {name1: '1', name2: '2'}
+    with environment(env_var_dict):
+        for key, value in env_var_dict.items():
+            assert_equal(os.environ.get(key), value)
+            assert_equal(getenv(key), value)
+
+    # Restoration of the prior environment is tested in 'test_with_environment()'
+
+@with_environment({'MXNET_TEST_ENV_VAR_1': '10', 'MXNET_TEST_ENV_VAR_2': None})
+def test_with_environment():
+    name1 = 'MXNET_TEST_ENV_VAR_1'
+    name2 = 'MXNET_TEST_ENV_VAR_2'
+    def check_background_values():
+        assert_equal(os.environ.get(name1), '10')
+        assert_equal(getenv(name1), '10')
+        assert_equal(os.environ.get(name2), None)
+        assert_equal(getenv(name2), None)
+
+    check_background_values()
 
-    def test_data_dir(self,):
-        prev_data_dir = data_dir()
-        system = platform.system()
-        if system != 'Windows':
-            self.assertEqual(data_dir(), op.join(op.expanduser('~'), '.mxnet'))
-        os.environ['MXNET_HOME'] = '/tmp/mxnet_data'
-        self.assertEqual(data_dir(), '/tmp/mxnet_data')
-        del os.environ['MXNET_HOME']
-        self.assertEqual(data_dir(), prev_data_dir)
+    # This completes the testing of with_environment(), but since we have
+    # an environment with a couple of known settings, let's use it to test if
+    # 'with environment()' properly restores to these settings in all cases.
 
+    class OnPurposeError(Exception):
+        """A class for exceptions thrown by this test"""
+        pass
 
+    # Enter an environment with one variable set and check it appears
+    # to both python and the backend.  Then, outside the 'with' block,
+    # make sure the background environment is seen, regardless of whether
+    # the 'with' block raised an exception.
+    def test_one_var(name, value, raise_exception=False):
+        try:
+            with environment(name, value):
+                assert_equal(os.environ.get(name), value)
+                assert_equal(getenv(name), value)
+                if raise_exception:
+                    raise OnPurposeError
+        except OnPurposeError:
+            pass
+        finally:
+            check_background_values()
+
+    # Test various combinations of set and unset env vars.
+    # Test that the background setting is restored in the presence of exceptions.
+    for raise_exception in [False, True]:
+        # name1 is initially set in the environment
+        test_one_var(name1, '42', raise_exception)
+        test_one_var(name1, None, raise_exception)
+        # name2 is initially not set in the environment
+        test_one_var(name2, '42', raise_exception)
+        test_one_var(name2, None, raise_exception)
+
+
+def test_data_dir():
+    prev_data_dir = data_dir()
+    system = platform.system()
+    # Test that data_dir() returns the proper default value when MXNET_HOME is not set
+    with environment('MXNET_HOME', None):
+        if system == 'Windows':
+            assert_equal(data_dir(), op.join(os.environ.get('APPDATA'), 'mxnet'))
+        else:
+            assert_equal(data_dir(), op.join(op.expanduser('~'), '.mxnet'))
+    # Test that data_dir() responds to an explicit setting of MXNET_HOME
+    with environment('MXNET_HOME', '/tmp/mxnet_data'):
+        assert_equal(data_dir(), '/tmp/mxnet_data')
+    # Test that this test has not disturbed the MXNET_HOME value existing before the test
+    assert_equal(data_dir(), prev_data_dir)
diff --git a/tests/python/unittest/test_engine.py 
b/tests/python/unittest/test_engine.py
index 61d94dd..ea02551 100644
--- a/tests/python/unittest/test_engine.py
+++ b/tests/python/unittest/test_engine.py
@@ -19,7 +19,7 @@ import nose
 import mxnet as mx
 import os
 import unittest
-from mxnet.test_utils import EnvManager
+from mxnet.test_utils import environment
 
 def test_bulk():
     with mx.engine.bulk(10):
@@ -42,7 +42,7 @@ def test_engine_openmp_after_fork():
     With GOMP the child always has the same number when calling 
omp_get_max_threads, with LLVM OMP
     the child respects the number of max threads set in the parent.
     """
-    with EnvManager('OMP_NUM_THREADS', '42'):
+    with environment('OMP_NUM_THREADS', '42'):
         r, w = os.pipe()
         pid = os.fork()
         if pid:
diff --git a/tests/python/unittest/test_engine_import.py 
b/tests/python/unittest/test_engine_import.py
index 303f3ce..ed56531 100644
--- a/tests/python/unittest/test_engine_import.py
+++ b/tests/python/unittest/test_engine_import.py
@@ -16,6 +16,8 @@
 # under the License.
 
 import os
+from mxnet.test_utils import environment
+import unittest
 
 try:
     reload         # Python 2
@@ -23,17 +25,17 @@ except NameError:  # Python 3
     from importlib import reload
 
 
+@unittest.skip('test needs improving, current use of reload(mxnet) is ineffective')
 def test_engine_import():
+    """Reload mxnet once per MXNET_ENGINE_TYPE value (None leaves the variable unset)."""
     import mxnet
-        
-    engine_types = ['', 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']
+    # Temporarily add an illegal entry (that is not caught) to show how the test needs improving
+    engine_types = [None, 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice', 'BogusEngine']
 
-    for type in engine_types:
-        if type:
-            os.environ['MXNET_ENGINE_TYPE'] = type
-        else:
-            os.environ.pop('MXNET_ENGINE_TYPE', None)
-        reload(mxnet)
+    # The loop variable is deliberately not called `type`, which would shadow the builtin.
+    for engine_type in engine_types:
+        with environment('MXNET_ENGINE_TYPE', engine_type):
+            reload(mxnet)
 
 
 if __name__ == '__main__':
diff --git a/tests/python/unittest/test_executor.py 
b/tests/python/unittest/test_executor.py
index 2bc696f..29fda3b 100644
--- a/tests/python/unittest/test_executor.py
+++ b/tests/python/unittest/test_executor.py
@@ -18,7 +18,7 @@
 import numpy as np
 import mxnet as mx
 from common import setup_module, with_seed, teardown
-from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import assert_almost_equal, environment
 
 
 def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None):
@@ -74,42 +74,24 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None):
 
 @with_seed()
 def test_bind():
-    def check_bind(disable_bulk_exec):
-        if disable_bulk_exec:
-            prev_bulk_inf_val = mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_INFERENCE", "0", "1")
-            prev_bulk_train_val = mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_TRAIN", "0", "1")
-
-        nrepeat = 10
-        maxdim = 4
-        for repeat in range(nrepeat):
-            for dim in range(1, maxdim):
-                check_bind_with_uniform(lambda x, y: x + y,
-                                        lambda g, x, y: (g, g),
-                                        dim)
-                check_bind_with_uniform(lambda x, y: x - y,
-                                        lambda g, x, y: (g, -g),
-                                        dim)
-                check_bind_with_uniform(lambda x, y: x * y,
-                                        lambda g, x, y: (y * g, x * g),
-                                        dim)
-                check_bind_with_uniform(lambda x, y: x / y,
-                                        lambda g, x, y: (g / y, -x * g/ (y**2)),
-                                        dim)
-
-                check_bind_with_uniform(lambda x, y: np.maximum(x, y),
-                                        lambda g, x, y: (g * (x>=y), g * (y>x)),
-                                        dim,
-                                        sf=mx.symbol.maximum)
-                check_bind_with_uniform(lambda x, y: np.minimum(x, y),
-                                        lambda g, x, y: (g * (x<=y), g * (y<x)),
-                                        dim,
-                                        sf=mx.symbol.minimum)
-        if disable_bulk_exec:
-           mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_INFERENCE", prev_bulk_inf_val)
-           mx.test_utils.set_env_var("MXNET_EXEC_BULK_EXEC_TRAIN", prev_bulk_train_val)
-
-    check_bind(True)
-    check_bind(False)
+    """Run check_bind_with_uniform on elementwise binary ops with MXNET_EXEC_BULK_EXEC_{INFERENCE,TRAIN} at '0', then '1'."""
+    # (numpy forward fn, numpy gradient fn, sf argument for check_bind_with_uniform; None is its default)
+    cases = [
+        (lambda x, y: x + y, lambda g, x, y: (g, g), None),
+        (lambda x, y: x - y, lambda g, x, y: (g, -g), None),
+        (lambda x, y: x * y, lambda g, x, y: (y * g, x * g), None),
+        (lambda x, y: x / y, lambda g, x, y: (g / y, -x * g / (y**2)), None),
+        (lambda x, y: np.maximum(x, y), lambda g, x, y: (g * (x >= y), g * (y > x)), mx.symbol.maximum),
+        (lambda x, y: np.minimum(x, y), lambda g, x, y: (g * (x <= y), g * (y < x)), mx.symbol.minimum),
+    ]
+    for bulk_setting in ['0', '1']:
+        bulk_env = {'MXNET_EXEC_BULK_EXEC_INFERENCE': bulk_setting,
+                    'MXNET_EXEC_BULK_EXEC_TRAIN': bulk_setting}
+        with environment(bulk_env):
+            for _ in range(10):
+                for dim in range(1, 4):
+                    for forward, backward, sym_fn in cases:
+                        check_bind_with_uniform(forward, backward, dim, sf=sym_fn)
 
 
 # @roywei: Removing fixed seed as flakiness in this test is fixed
diff --git a/tests/python/unittest/test_gluon.py 
b/tests/python/unittest/test_gluon.py
index 6129c28..98b606d 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -29,7 +29,7 @@ from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from mxnet.test_utils import use_np
 import mxnet.numpy as _mx_np
 from common import (setup_module, with_seed, assertRaises, teardown,
-                    assert_raises_cudnn_not_satisfied)
+                    assert_raises_cudnn_not_satisfied, environment)
 import numpy as np
 from numpy.testing import assert_array_equal
 from nose.tools import raises, assert_raises
@@ -1694,21 +1694,12 @@ def test_zero_grad():
         for type in [testedTypes] + testedTypes:
             _test_multi_reset(np.random.randint(1, 50), type, ctx)
 
-    # Saving value of environment variable, if it was defined
-    envVarKey = 'MXNET_STORAGE_FALLBACK_LOG_VERBOSE'
-    envVarValue = os.environ[envVarKey] if envVarKey in os.environ else None
-    # Changing value of environment variable
-    os.environ[envVarKey] = '0'
-    for type in ['float16', 'float32', 'float64']:
-        for embType in ['float32', 'float64']:
-            for sparse in [True, False]:
-                _test_grad_reset(ctx, dtype=type, sparse=sparse, 
embeddingType=embType)
-
-    # Remove or restore the value of environment variable
-    if envVarValue is None:
-        del os.environ[envVarKey]
-    else:
-        os.environ[envVarKey] = envVarValue
+    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
+        for type in ['float16', 'float32', 'float64']:
+            for embType in ['float32', 'float64']:
+                for sparse in [True, False]:
+                    _test_grad_reset(ctx, dtype=type, sparse=sparse, 
embeddingType=embType)
+
 
 def check_hybrid_static_memory(**kwargs):
     x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index edb3e6a..c4dd68f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -30,7 +30,7 @@ from mxnet.test_utils import *
 from mxnet.operator import *
 from mxnet.base import py_str, MXNetError, _as_list
 from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_not_satisfied, assert_raises_cuda_not_satisfied, 
assertRaises
-from common import run_in_spawned_process
+from common import run_in_spawned_process, with_environment
 from nose.tools import assert_raises, ok_
 import unittest
 import os
@@ -3918,83 +3918,77 @@ def test_norm():
                 np.int32: np.int32, np.int64: np.int64}
     dtype_to_str = {np.float16: 'float16', np.float32: 'float32', np.float64: 
'float64',
                     np.int32: 'int32', np.int64: 'int64'}
-    is_windows = sys.platform.startswith('win')
-    for enforce_safe_acc in ["1", "0"]:
-        if is_windows:
-            if enforce_safe_acc == "0":
-                break
-            enforce_safe_acc = "0" if "MXNET_SAFE_ACCUMULATION" not in 
os.environ else os.environ["MXNET_SAFE_ACCUMULATION"]
-        else:
-            os.environ["MXNET_SAFE_ACCUMULATION"] = enforce_safe_acc
-        for order in [1, 2]:
-            for dtype in [np.float16, np.float32, np.float64]:
-                for i in range(in_data_dim):
-                    for out_dtype in ['float32', 'float64']:
-                        backward_dtype = np.float32 if out_dtype == 'float32' 
else np.float64
-                        accumulation_type = acc_type[dtype]
-                        if enforce_safe_acc == "0":
-                            backward_dtype = dtype
-                            out_dtype = dtype_to_str[dtype]
-                            accumulation_type = dtype
-                        skip_backward = 'int' in out_dtype
-                        in_data = np.random.uniform(-1, 1, 
in_shape).astype(accumulation_type)
-                        in_data[abs(in_data) < epsilon] = 2 * epsilon
-                        norm_sym = mx.symbol.norm(data=data, ord=order, 
axis=i, out_dtype=out_dtype, keepdims=True)
-                        npy_out = l1norm(in_data, i) if order is 1 else 
l2norm(in_data, i)
-                        npy_out_backward = np.sign(in_data) if order is 1 else 
in_data/npy_out
-                        check_symbolic_forward(norm_sym, 
[in_data.astype(dtype)], [npy_out.astype(out_dtype)],
-                                               rtol=1e-2 if dtype == 
np.float16 else 1e-3,
-                                               atol=1e-4 if dtype == 
np.float16 else 1e-5, ctx=ctx, dtype=dtype)
-                        if dtype is not np.float16 and not skip_backward:
-                            check_symbolic_backward(norm_sym, 
[in_data.astype(dtype)],
-                                                    
[np.ones(npy_out.shape).astype(out_dtype)],
-                                                    [npy_out_backward], 
rtol=1e-3, atol=1e-5, ctx=ctx,
-                                                    dtype=backward_dtype)
-                        # Disable numeric gradient 
https://github.com/apache/incubator-mxnet/issues/11509
-                        # check gradient
-                        if dtype is not np.float16 and not skip_backward:
-                            check_numeric_gradient(norm_sym, [in_data], 
numeric_eps=epsilon,
-                                                   rtol=1e-1, atol=1e-3, 
dtype=backward_dtype)
-                        if i < in_data_dim-1:
-                            norm_sym = mx.symbol.norm(data=data, ord=order, 
axis=(i, i+1), keepdims=True)
-                            npy_out = l1norm(in_data, (i, i+1)) if order is 1 
else l2norm(in_data, (i, i+1))
+    for enforce_safe_acc in ['1', '0']:
+        with environment('MXNET_SAFE_ACCUMULATION', enforce_safe_acc):
+            for order in [1, 2]:
+                for dtype in [np.float16, np.float32, np.float64]:
+                    for i in range(in_data_dim):
+                        for out_dtype in ['float32', 'float64']:
+                            backward_dtype = np.float32 if out_dtype == 
'float32' else np.float64
+                            accumulation_type = acc_type[dtype]
+                            if enforce_safe_acc == "0":
+                                backward_dtype = dtype
+                                out_dtype = dtype_to_str[dtype]
+                                accumulation_type = dtype
+                            skip_backward = 'int' in out_dtype
+                            in_data = np.random.uniform(-1, 1, 
in_shape).astype(accumulation_type)
+                            in_data[abs(in_data) < epsilon] = 2 * epsilon
+                            norm_sym = mx.symbol.norm(data=data, ord=order, 
axis=i, out_dtype=out_dtype, keepdims=True)
+                            npy_out = l1norm(in_data, i) if order is 1 else 
l2norm(in_data, i)
                             npy_out_backward = np.sign(in_data) if order is 1 
else in_data/npy_out
-                            check_symbolic_forward(norm_sym, [in_data], 
[npy_out.astype(dtype)],
-                                                   rtol=1e-2 if dtype is 
np.float16 else 1e-3,
-                                                   atol=1e-4 if dtype is 
np.float16 else 1e-5, ctx=ctx)
+                            check_symbolic_forward(norm_sym, 
[in_data.astype(dtype)], [npy_out.astype(out_dtype)],
+                                                   rtol=1e-2 if dtype == 
np.float16 else 1e-3,
+                                                   atol=1e-4 if dtype == 
np.float16 else 1e-5, ctx=ctx, dtype=dtype)
                             if dtype is not np.float16 and not skip_backward:
-                                check_symbolic_backward(norm_sym, [in_data],
+                                check_symbolic_backward(norm_sym, 
[in_data.astype(dtype)],
                                                         
[np.ones(npy_out.shape).astype(out_dtype)],
-                                                        
[npy_out_backward.astype(out_dtype)],
-                                                        rtol=1e-3, atol=1e-5, 
ctx=ctx, dtype=backward_dtype)
-                            # check gradient
-                            if dtype is not np.float16 and not skip_backward:
-                                check_numeric_gradient(norm_sym, [in_data], 
numeric_eps=epsilon,
-                                                       rtol=1e-1, atol=1e-3, 
dtype=backward_dtype)
+                                                        [npy_out_backward], 
rtol=1e-3, atol=1e-5, ctx=ctx,
+                                                        dtype=backward_dtype)
+                                # Disable numeric gradient 
https://github.com/apache/incubator-mxnet/issues/11509
+                                # check gradient
+                                if dtype is not np.float16 and not 
skip_backward:
+                                    check_numeric_gradient(norm_sym, 
[in_data], numeric_eps=epsilon,
+                                                   rtol=1e-1, atol=1e-3, 
dtype=backward_dtype)
+                            if i < in_data_dim-1:
+                                norm_sym = mx.symbol.norm(data=data, 
ord=order, axis=(i, i+1), keepdims=True)
+                                npy_out = l1norm(in_data, (i, i+1)) if order 
is 1 else l2norm(in_data, (i, i+1))
+                                npy_out_backward = np.sign(in_data) if order 
is 1 else in_data/npy_out
+                                check_symbolic_forward(norm_sym, [in_data], 
[npy_out.astype(dtype)],
+                                                       rtol=1e-2 if dtype is 
np.float16 else 1e-3,
+                                                       atol=1e-4 if dtype is 
np.float16 else 1e-5, ctx=ctx)
+                                if dtype is not np.float16 and not 
skip_backward:
+                                    check_symbolic_backward(norm_sym, 
[in_data],
+                                                            
[np.ones(npy_out.shape).astype(out_dtype)],
+                                                            
[npy_out_backward.astype(out_dtype)],
+                                                            rtol=1e-3, 
atol=1e-5, ctx=ctx, dtype=backward_dtype)
+                                # check gradient
+                                if dtype is not np.float16 and not 
skip_backward:
+                                    check_numeric_gradient(norm_sym, 
[in_data], numeric_eps=epsilon,
+                                                           rtol=1e-1, 
atol=1e-3, dtype=backward_dtype)
 
 
 def test_layer_norm():
-    for enforce_safe_acc in ["1", "0"]:
-        os.environ["MXNET_SAFE_ACCUMULATION"] = enforce_safe_acc
-        for dtype, forward_check_eps, backward_check_eps in zip([np.float16, np.float32, np.float64],
-                                                                [1E-2, 1E-3, 1E-4],
-                                                                [1E-2, 1E-3, 1E-4]):
-            if dtype != np.float16:
-                in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10), (128 * 32, 512)], [True, True, False]
-            else:
-                in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10)], [True, True]  # large input + fp16 does not pass the forward check
-            for in_shape, finite_grad_check in zip(in_shape_l, finite_grad_check_l):
-                for axis in range(-len(in_shape), len(in_shape)):
-                    for eps in [1E-2, 1E-3]:
-                        if dtype == np.float16:
-                            npy_grad_check = False
-                        else:
-                            npy_grad_check = True
-                        check_layer_normalization(in_shape, axis, eps, dtype=dtype,
-                                                  forward_check_eps=forward_check_eps,
-                                                  backward_check_eps=backward_check_eps,
-                                                  npy_grad_check=npy_grad_check,
-                                                  finite_grad_check=finite_grad_check)
+    """Run check_layer_normalization across dtypes/shapes/axes/eps with MXNET_SAFE_ACCUMULATION '1' then '0'."""
+    # (dtype, forward_check_eps, backward_check_eps) triples.
+    dtype_tolerances = [(np.float16, 1E-2, 1E-2), (np.float32, 1E-3, 1E-3), (np.float64, 1E-4, 1E-4)]
+    for safe_acc in ["1", "0"]:
+        with environment('MXNET_SAFE_ACCUMULATION', safe_acc):
+            for dtype, fwd_eps, bwd_eps in dtype_tolerances:
+                is_fp16 = dtype == np.float16
+                # (in_shape, finite_grad_check) pairs; the large input is skipped for fp16
+                # because large input + fp16 does not pass the forward check.
+                shape_cases = [((10, 6, 5), True), ((10, 10), True)]
+                if not is_fp16:
+                    shape_cases.append(((128 * 32, 512), False))
+                for in_shape, finite_grad_check in shape_cases:
+                    for axis in range(-len(in_shape), len(in_shape)):
+                        for eps in [1E-2, 1E-3]:
+                            # The numpy gradient check (npy_grad_check) is disabled for float16.
+                            check_layer_normalization(in_shape, axis, eps, dtype=dtype,
+                                                      forward_check_eps=fwd_eps,
+                                                      backward_check_eps=bwd_eps,
+                                                      npy_grad_check=not is_fp16,
+                                                      finite_grad_check=finite_grad_check)
 
 
 # Numpy Implementation of Sequence Ops
@@ -5401,6 +5395,7 @@ def test_softmax_with_large_inputs():
     softmax_forward(mx.nd.array([[[[3.4e38,3.4e38]]]]), np.array([1.0,1.0]))
 
 @with_seed()
+@with_environment('MXNET_SAFE_ACCUMULATION', '1')
 def test_softmax_dtype():
     def check_dtypes_almost_equal(op_name,
                                   atol, rtol,
@@ -5420,27 +5415,22 @@ def test_softmax_dtype():
         ref_softmax.backward()
         assert_almost_equal(dtype_input.grad, ref_input.grad, rtol=grad_rtol, 
atol=grad_atol)
 
-    import sys
-    is_windows = sys.platform.startswith('win')
-    enforce_safe_acc = os.environ.get("MXNET_SAFE_ACCUMULATION", "0")
-    if not is_windows or enforce_safe_acc == "1":
-        os.environ["MXNET_SAFE_ACCUMULATION"] = "1"
-        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 
'float16', 'float32')
-        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 
'float16', 'float32', 'float32')
-        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 
'float32', 'float64')
-        check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 
'float32', 'float64', 'float64')
-        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 
'float16', 'float32')
-        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 
'float16', 'float32', 'float32')
-        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 
'float32', 'float64')
-        check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 
'float32', 'float64', 'float64')
-        check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
-                                  'float16', 'float32')
-        check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
-                                  'float16', 'float32', 'float32')
-        check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
-                                  'float32', 'float64')
-        check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
-                                  'float32', 'float64', 'float64')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 
'float32')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 
'float32', 'float32')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 
'float64')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 
'float64', 'float64')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 
'float32')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 
'float32', 'float32')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 
'float64')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 
'float64', 'float64')
+    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                              'float16', 'float32')
+    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                              'float16', 'float32', 'float32')
+    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                              'float32', 'float64')
+    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                              'float32', 'float64', 'float64')
 
 
 @with_seed()
@@ -6481,12 +6471,15 @@ def _gemm_test_helper(dtype, grad_check, rtol_fw = None, atol_fw = None,
 @with_seed()
 def test_gemm():
+    """Run _gemm_test_helper in float64, then in float32 with MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION
+    set to '0' and, only when the default context is a GPU, to '1'."""
     _gemm_test_helper(np.float64, True)
-    os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "0"
-    _gemm_test_helper(np.float32, True)
-    if default_context().device_type == 'gpu':
-        os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"
-        _gemm_test_helper(np.float32, True)
-        os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "0"
+    for allow_conversion in ('0', '1'):
+        # The '1' setting is exercised only on a GPU default context.
+        if allow_conversion == '1' and default_context().device_type != 'gpu':
+            continue
+        with environment('MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION', allow_conversion):
+            _gemm_test_helper(np.float32, True)
+
 
 # Helper functions for test_laop
 
diff --git a/tests/python/unittest/test_subgraph_op.py 
b/tests/python/unittest/test_subgraph_op.py
index 79a104f..ae46d25 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -21,7 +21,7 @@ import mxnet as mx
 from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, 
c_str, mx_real_t
 from mxnet.symbol import Symbol
 import numpy as np
-from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import assert_almost_equal, environment
 from mxnet import gluon
 from mxnet.gluon import nn
 from mxnet import nd
@@ -152,8 +152,8 @@ def check_subgraph_exe2(sym, subgraph_backend, op_names):
     and compare results of the partitioned sym and the original sym."""
     def get_executor(sym, subgraph_backend=None, op_names=None, 
original_exec=None):
         if subgraph_backend is not None:
-            os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
-            
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
+            with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
+                
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
                                                          
c_str_array(op_names)))
         exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
         input_names = sym.list_inputs()
@@ -166,12 +166,13 @@ def check_subgraph_exe2(sym, subgraph_backend, op_names):
                 exe.aux_dict[name][:] = 
mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
                     if original_exec is None else original_exec.aux_dict[name]
         exe.forward()
-        if subgraph_backend is not None:
-            
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
-            del os.environ['MXNET_SUBGRAPH_BACKEND']
         return exe
     original_exec = get_executor(sym)
-    partitioned_exec = get_executor(sym, subgraph_backend, op_names, 
original_exec)
+    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
+        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
+                                                     c_str_array(op_names)))
+        partitioned_exec = get_executor(sym, subgraph_backend, op_names, 
original_exec)
+        
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
     outputs1 = original_exec.outputs
     outputs2 = partitioned_exec.outputs
     assert len(outputs1) == len(outputs2)
@@ -209,10 +210,6 @@ def check_subgraph_exe4(sym, subgraph_backend, op_names):
     """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph 
partitioning in bind
     and compare results of the partitioned sym and the original sym."""
     def get_executor(sym, subgraph_backend=None, op_names=None, 
original_exec=None):
-        if subgraph_backend is not None:
-            os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
-            
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
-                                                         
c_str_array(op_names)))
         arg_shapes, _, aux_shapes = sym.infer_shape()
         if subgraph_backend is None:
             arg_array = [mx.nd.random.uniform(shape=shape) for shape in 
arg_shapes]
@@ -225,13 +222,14 @@ def check_subgraph_exe4(sym, subgraph_backend, op_names):
                        aux_states=aux_array if subgraph_backend is None else 
original_exec.aux_arrays,
                        grad_req='null')
         exe.forward()
-        if subgraph_backend is not None:
-            
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
-            del os.environ['MXNET_SUBGRAPH_BACKEND']
         return exe
 
     original_exec = get_executor(sym)
-    partitioned_exec = get_executor(sym, subgraph_backend, op_names, 
original_exec)
+    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
+        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), 
mx_uint(len(op_names)),
+                                                     c_str_array(op_names)))
+        partitioned_exec = get_executor(sym, subgraph_backend, op_names, 
original_exec)
+        
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
     outputs1 = original_exec.outputs
     outputs2 = partitioned_exec.outputs
     assert len(outputs1) == len(outputs2)
diff --git a/tests/python/unittest/test_symbol.py 
b/tests/python/unittest/test_symbol.py
index 793b920..a54bcf3 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -25,7 +25,7 @@ import mxnet as mx
 import numpy as np
 from common import assertRaises, models, TemporaryDirectory
 from mxnet.base import NotImplementedForSymbol
-from mxnet.test_utils import discard_stderr, rand_shape_nd, use_np
+from mxnet.test_utils import discard_stderr, rand_shape_nd, use_np, environment
 from mxnet.util import np_shape
 import pickle as pkl
 
@@ -389,15 +389,6 @@ def test_gen_atomic_symbol_multiple_outputs():
 
 
 def test_eliminate_common_expr():
-    if not sys.platform.startswith('linux'):
-        logging.info("Bypass the CSE test on non-Linux OS as setting env 
variables during test does not work on Windows")
-        return
-    def set_back_env_var(var_name, old_env_var):
-        if old_env_var is None:
-            os.environ.pop(var_name)
-        else:
-            os.environ[var_name] = old_env_var
-
     # helper function to test a single model
     def check_cse_on_symbol(sym, expected_savings, check_data, **kwargs):
         inputs = sym.list_inputs()
@@ -410,41 +401,36 @@ def test_eliminate_common_expr():
                 'float32' : 1e-7,
                 'float64' : 1e-7,
                 }
-        env_var_name = 'MXNET_ELIMINATE_COMMON_EXPR'
-        old_env_var = os.environ.get(env_var_name, None)
-        try:
-            for dtype in ['float16', 'float32', 'float64']:
-                data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
-                for grad_req in ['write', 'add']:
-                    type_dict = {inp : dtype for inp in inputs}
-                    os.environ[env_var_name] = '0'
+        for dtype in ['float16', 'float32', 'float64']:
+            data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
+            for grad_req in ['write', 'add']:
+                type_dict = {inp : dtype for inp in inputs}
+                with environment({'MXNET_ELIMINATE_COMMON_EXPR': '0'}):
                     orig_exec = sym.simple_bind(ctx=mx.cpu(0), 
grad_req=grad_req,
                                                 type_dict=type_dict, **shapes)
-                    os.environ[env_var_name] = '1'
+                with environment({'MXNET_ELIMINATE_COMMON_EXPR': '1'}):
                     cse_exec = sym.simple_bind(ctx=mx.cpu(0), 
grad_req=grad_req,
                                                type_dict=type_dict, **shapes)
-                    fwd_orig = orig_exec.forward(is_train=True, **data)
-                    out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
-                    orig_exec.backward(out_grads=out_grads)
-                    fwd_cse = cse_exec.forward(is_train=True, **data)
-                    cse_exec.backward(out_grads=out_grads)
-                    if check_data:
-                        for orig, cse in zip(fwd_orig, fwd_cse):
-                            np.testing.assert_allclose(orig.asnumpy(), 
cse.asnumpy(),
-                                                       rtol=rtol[dtype], 
atol=atol[dtype])
-                        for orig, cse in zip(orig_exec.grad_arrays, 
cse_exec.grad_arrays):
-                            if orig is None and cse is None:
-                                continue
-                            assert orig is not None
-                            assert cse is not None
-                            np.testing.assert_allclose(orig.asnumpy(), 
cse.asnumpy(),
-                                                       rtol=rtol[dtype], 
atol=atol[dtype])
-                    orig_sym_internals = 
orig_exec.get_optimized_symbol().get_internals()
-                    cse_sym_internals = 
cse_exec.get_optimized_symbol().get_internals()
-                    # test that the graph has been simplified as expected
-                    assert (len(cse_sym_internals) + expected_savings) == 
len(orig_sym_internals)
-        finally:
-            set_back_env_var(env_var_name, old_env_var)
+                fwd_orig = orig_exec.forward(is_train=True, **data)
+                out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
+                orig_exec.backward(out_grads=out_grads)
+                fwd_cse = cse_exec.forward(is_train=True, **data)
+                cse_exec.backward(out_grads=out_grads)
+                if check_data:
+                    for orig, cse in zip(fwd_orig, fwd_cse):
+                        np.testing.assert_allclose(orig.asnumpy(), 
cse.asnumpy(),
+                                                   rtol=rtol[dtype], 
atol=atol[dtype])
+                    for orig, cse in zip(orig_exec.grad_arrays, 
cse_exec.grad_arrays):
+                        if orig is None and cse is None:
+                            continue
+                        assert orig is not None
+                        assert cse is not None
+                        np.testing.assert_allclose(orig.asnumpy(), 
cse.asnumpy(),
+                                                   rtol=rtol[dtype], 
atol=atol[dtype])
+                orig_sym_internals = 
orig_exec.get_optimized_symbol().get_internals()
+                cse_sym_internals = 
cse_exec.get_optimized_symbol().get_internals()
+                # test that the graph has been simplified as expected
+                assert (len(cse_sym_internals) + expected_savings) == 
len(orig_sym_internals)
 
     a = mx.sym.Variable('a')
     b = mx.sym.Variable('b')

Reply via email to