This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b7a6d9  Re-enable the test_gpu_memory_profiler_gluon test case (#18704)
5b7a6d9 is described below

commit 5b7a6d979b3bcf43416be25a6e47e0fd150daa54
Author: Bojian Zheng <[email protected]>
AuthorDate: Tue Sep 8 03:49:08 2020 -0400

    Re-enable the test_gpu_memory_profiler_gluon test case (#18704)
    
    * Re-enable the test_gpu_memory_profiler_gluon test case
    
    * Change the naming of head gradients
---
 config/linux_gpu.cmake                |   1 +
 python/mxnet/gluon/block.py           |   4 +-
 python/mxnet/gluon/parameter.py       |   2 +-
 python/mxnet/profiler.py              |   6 +-
 python/mxnet/symbol/symbol.py         |  24 +++---
 src/imperative/imperative.cc          |  14 +++-
 src/profiler/storage_profiler.cc      |  11 ++-
 tests/python/gpu/test_profiler_gpu.py | 145 ++++++++++++++++++++++------------
 8 files changed, 137 insertions(+), 70 deletions(-)

diff --git a/config/linux_gpu.cmake b/config/linux_gpu.cmake
index c75d294..50932d8 100644
--- a/config/linux_gpu.cmake
+++ b/config/linux_gpu.cmake
@@ -129,4 +129,5 @@ set(USE_INT64_TENSOR_SIZE OFF CACHE BOOL "Use int64_t to represent the total num
 # Other GPU features
 set(USE_NCCL "Use NVidia NCCL with CUDA" OFF)
 set(NCCL_ROOT "" CACHE BOOL "NCCL install path. Supports autodetection.")
+set(USE_NVML OFF CACHE BOOL "Build with NVML support")
 set(USE_NVTX ON CACHE BOOL "Build with NVTX support")
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 0a1c758..3178c20 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -947,8 +947,8 @@ class HybridBlock(Block):
                     flatten_inputs.append(None)
             grouped_inputs = _regroup(flatten_inputs, self._in_format)
 
-            params = {i: j.var() for i, j in self._reg_params.items()}
             with _block_scope(self):
+                params = {i: j.var() for i, j in self._reg_params.items()}
                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  
# pylint: disable=no-value-for-parameter
             out, self._out_format = _flatten(out, "output")
 
@@ -1447,8 +1447,8 @@ class HybridBlock(Block):
 
                 return self.hybrid_forward(ndarray, x, *args, **params)
 
-        params = {i: j.var() for i, j in self._reg_params.items()}
         with _block_scope(self):
+            params = {i: j.var() for i, j in self._reg_params.items()}
             return self.hybrid_forward(symbol, x, *args, **params)
 
     def hybrid_forward(self, F, x, *args, **kwargs):
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 2f1f115..68e860b 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -644,7 +644,7 @@ class Parameter(object):
         """Returns a symbol representing this parameter."""
         if self._var is None:
             if self._var_name is None:  # _var_name is set manually in 
SymbolBlock.import
-                self._var_name = self._uuid
+                self._var_name = self._uuid.replace('-', '_') + '_' + 
self._name
 
             self._var = symbol.var(self._var_name, shape=self.shape, 
dtype=self.dtype,
                                    lr_mult=self.lr_mult, wd_mult=self.wd_mult,
diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py
index 1b9583e..78a7dfc 100644
--- a/python/mxnet/profiler.py
+++ b/python/mxnet/profiler.py
@@ -504,7 +504,7 @@ class Marker(object):
 
 
 @contextlib.contextmanager
-def scope(name='<unk>:', append_mode=False):
+def scope(name='<unk>:', append_mode=True):
     """Assign the profiler scope for the GPU memory profiler.
 
     It is implicitly invoked when the Gluon API is used.
@@ -516,7 +516,9 @@ def scope(name='<unk>:', append_mode=False):
 
     """
     name = name + ":" if not name.endswith(":") else name
-    token = _current_scope.set(_current_scope.get() + name if append_mode else 
name)
+    if append_mode and _current_scope.get() != "<unk>:":
+        name = _current_scope.get() + name
+    token = _current_scope.set(name)
     # Invoke the C API to propagate the profiler scope information to the
     # C++ backend.
     check_call(_LIB.MXSetProfilerScope(c_str(name)))
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 2eebfdf..ddc4a1e 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -43,7 +43,8 @@ from . import _internal
 from . import op
 from ._internal import SymbolBase, _set_symbol_class
 from ..util import is_np_shape
-from ..profiler import _current_scope as _profiler_scope
+from ..profiler import scope as _profiler_scope
+from ..profiler import _current_scope as _current_profiler_scope
 
 __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
            "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
@@ -1782,15 +1783,16 @@ class Symbol(SymbolBase):
                     index = aux_names.index(name)
                     aux_states[index] = aux_states[index].totype(stype)
 
-        if grad_req == 'null':
-            args_grad = None
-        elif isinstance(grad_req, dict):
-            args_grad = {}
-            for i, name in enumerate(arg_names):
-                if grad_req[name] != 'null':
-                    args_grad[name] = args[i].copy()
-        else:
-            args_grad = [x.copy() for x in args]
+        with _profiler_scope("symbol:arg_grad:"):
+            if grad_req == 'null':
+                args_grad = None
+            elif isinstance(grad_req, dict):
+                args_grad = {}
+                for i, name in enumerate(arg_names):
+                    if grad_req[name] != 'null':
+                        args_grad[name] = args[i].copy()
+            else:
+                args_grad = [x.copy() for x in args]
         return Executor(self, ctx, args, args_grad, grad_req, aux_states)
 
     def _bind(self, ctx, args, args_grad=None, grad_req='write',
@@ -2728,7 +2730,7 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
     if profiler_scope is not None:
         attr['__profiler_scope__'] = profiler_scope
     else:
-        attr['__profiler_scope__'] = _profiler_scope.get()
+        attr['__profiler_scope__'] = _current_profiler_scope.get()
     for k, v in kwargs.items():
         if k.startswith('__') and k.endswith('__'):
             attr[k] = str(v)
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 8702d93..e2028b0 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -233,7 +233,7 @@ void Imperative::RecordOp(
 
   nnvm::ObjectPtr node = nnvm::Node::Create();
   node->attrs = std::move(attrs);
-  node->attrs.name = "node_" + std::to_string(node_count_++);
+  node_count_ += 1;
   AGInfo& info = AGInfo::Create(node);
   info.state = state;
   info.ctx = outputs[0]->ctx();
@@ -322,7 +322,7 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs,
   }
   node->attrs = std::move(attrs);
   // Need to support NameManager in imperative API to better name 
node->attrs.name
-  node->attrs.name = "node_" + std::to_string(node_count_++);
+  node_count_ += 1;
 
   for (uint32_t i = 0; i < outputs.size(); ++i) {
     outputs[i]->deferredcompute_entry_ = nnvm::NodeEntry{node, i, 0};
@@ -598,6 +598,16 @@ std::vector<NDArray*> Imperative::Backward(
     }
   }
 
+  for (size_t nid = num_forward_nodes;
+       nid < idx.num_nodes(); ++nid) {
+    const nnvm::NodeAttrs& attrs = idx[nid].source->attrs;
+    for (size_t oid = 0; oid < idx[nid].source->num_outputs(); ++oid) {
+      size_t eid = idx.entry_id(nid, oid);
+      arrays[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs),
+                                     attrs.name);
+    }
+  }  // for (nid ∈ [num_forward_nodes, idx.num_nodes()))
+
   if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) {
     common::LogMemoryPlan(graph);
   }
diff --git a/src/profiler/storage_profiler.cc b/src/profiler/storage_profiler.cc
index 5bbfa59..b0025a9 100644
--- a/src/profiler/storage_profiler.cc
+++ b/src/profiler/storage_profiler.cc
@@ -23,6 +23,7 @@
 #endif  // MXNET_USE_NVML
 #include <fstream>
 #include <map>
+#include <regex>
 #include <unordered_map>
 #include <vector>
 #include "./profiler.h"
@@ -61,11 +62,17 @@ void GpuDeviceStorageProfiler::DumpProfile() const {
   std::multimap<std::string, AllocEntryDumpFmt> gpu_mem_ordered_alloc_entries;
   // map the GPU device ID to the total amount of allocations
   std::unordered_map<int, size_t> gpu_dev_id_total_alloc_map;
+  std::regex gluon_param_regex("([0-9a-fA-F]{8})_([0-9a-fA-F]{4})_"
+                               "([0-9a-fA-F]{4})_([0-9a-fA-F]{4})_"
+                               "([0-9a-fA-F]{12})_([^ ]*)");
+
   for (const std::pair<void *const, AllocEntry>& alloc_entry :
        gpu_mem_alloc_entries_) {
+    std::string alloc_entry_name
+        = std::regex_replace(alloc_entry.second.name, gluon_param_regex, "$6");
     gpu_mem_ordered_alloc_entries.emplace(
-        alloc_entry.second.profiler_scope +
-        alloc_entry.second.name, AllocEntryDumpFmt{
+        alloc_entry.second.profiler_scope + alloc_entry_name,
+        AllocEntryDumpFmt{
           alloc_entry.second.requested_size,
           alloc_entry.second.dev_id,
           alloc_entry.second.actual_size,
diff --git a/tests/python/gpu/test_profiler_gpu.py b/tests/python/gpu/test_profiler_gpu.py
index 89eb425..05bec4a 100644
--- a/tests/python/gpu/test_profiler_gpu.py
+++ b/tests/python/gpu/test_profiler_gpu.py
@@ -15,25 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import csv
 import os
 import sys
 
+import numpy as np
 import mxnet as mx
 mx.test_utils.set_default_context(mx.gpu(0))
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-# We import all tests from ../unittest/test_profiler.py
-# They will be detected by test framework, as long as the current file has a 
different filename
-from test_profiler import *
+from mxnet import profiler
+from mxnet.gluon import nn
+from mxnet.gluon.block import _block_scope
+from test_profiler import enable_profiler
 
-# Test seen to crash pytest worker during development of 
https://github.com/apache/incubator-mxnet/pull/18694
-del test_aggregate_duplication
 
 def test_gpu_memory_profiler_symbolic():
-    iter_num = 5
-
-    enable_profiler('test_profiler.json', False, False)
+    enable_profiler('test_profiler.json')
     profiler.set_state('run')
 
     with profiler.scope("tensordot"):
@@ -41,18 +40,19 @@ def test_gpu_memory_profiler_symbolic():
         B = mx.sym.Variable('B')
         C = mx.symbol.dot(A, B, name='dot')
 
-    executor = C._simple_bind(mx.gpu(), 'write', A=(4096, 4096), B=(4096, 
4096))
+    executor = C._simple_bind(mx.gpu(), 'write', A=(1024, 2048), B=(2048, 
4096))
 
-    a = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))
-    b = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))
+    with profiler.scope("init"):
+        a = mx.random.uniform(-1.0, 1.0, shape=(1024, 2048))
+        b = mx.random.uniform(-1.0, 1.0, shape=(2048, 4096))
 
     a.copyto(executor.arg_dict['A'])
     b.copyto(executor.arg_dict['B'])
 
-    for i in range(iter_num):
-        executor.forward()
-        c = executor.outputs[0]
-        mx.nd.waitall()
+    executor.forward()
+    executor.backward()
+    c = executor.outputs[0]
+    mx.nd.waitall()
     profiler.set_state('stop')
     profiler.dump(True)
 
@@ -62,41 +62,53 @@ def test_gpu_memory_profiler_symbolic():
             {'Attribute Name' : 'tensordot:in_arg:B',
              'Requested Size' : str(4 * b.size)},
             {'Attribute Name' : 'tensordot:dot',
-             'Requested Size' : str(4 * c.size)}]
+             'Requested Size' : str(4 * c.size)},
+            {'Attribute Name' : 'init:_random_uniform',
+             'Requested Size' : str(4 * a.size)},
+            {'Attribute Name' : 'init:_random_uniform',
+             'Requested Size' : str(4 * b.size)}]
 
     # Sample gpu_memory_profile.csv:
     # "Attribute Name","Requested Size","Device","Actual Size","Reuse?"
-    # "<unk>:_zeros","67108864","0","67108864","0"
-    # "<unk>:_zeros","67108864","0","67108864","0"
-    # "tensordot:dot","67108864","0","67108864","1"
-    # "tensordot:dot","67108864","0","67108864","1"
-    # "tensordot:in_arg:A","67108864","0","67108864","0"
-    # "tensordot:in_arg:B","67108864","0","67108864","0"
-    # "nvml_amend","1074790400","0","1074790400","0"
+    # init:_random_uniform,33554432,0,33554432,1
+    # init:_random_uniform,8388608,0,8388608,1
+    # resource:temp_space (sample_op.h +365),8,0,4096,0
+    # symbol:arg_grad:unknown,8388608,0,8388608,0
+    # symbol:arg_grad:unknown,33554432,0,33554432,0
+    # tensordot:dot,16777216,0,16777216,0
+    # tensordot:dot_backward,33554432,0,33554432,0
+    # tensordot:dot_backward,8388608,0,8388608,0
+    # tensordot:dot_head_grad,16777216,0,16777216,0
+    # tensordot:in_arg:A,8388608,0,8388608,0
+    # tensordot:in_arg:B,33554432,0,33554432,0
 
     with open('gpu_memory_profile-pid_%d.csv' % (os.getpid()), mode='r') as 
csv_file:
         csv_reader = csv.DictReader(csv_file)
+        # TODO: Remove this print statement later on.
+        for row in csv_reader:
+            print(",".join(list(row.values())))
         for expected_alloc_entry in expected_alloc_entries:
             csv_file.seek(0)
             entry_found = False
             for row in csv_reader:
-                if row['Attribute Name'] == expected_alloc_entry['Attribute 
Name']:
-                    assert row['Requested Size'] == 
expected_alloc_entry['Requested Size'], \
-                           "requested size={} is not equal to the expected 
size={}" \
-                           .format(row['Requested Size'],
-                                   expected_alloc_entry['Requested Size'])
+                if row['Attribute Name'] == expected_alloc_entry['Attribute 
Name'] and \
+                   row['Requested Size'] == expected_alloc_entry['Requested 
Size']:
                     entry_found = True
                     break
             assert entry_found, \
-                   "Entry for attr_name={} has not been found" \
-                   .format(expected_alloc_entry['Attribute Name'])
+                    "Entry for (attr_name={}, alloc_size={}) has not been 
found" \
+                    .format(expected_alloc_entry['Attribute Name'],
+                            expected_alloc_entry['Requested Size'])
+        # Make sure that there is no unknown allocation entry.
+        csv_file.seek(0)
+        for row in csv_reader:
+            if row['Attribute Name'] == "<unk>:unknown" or \
+               row['Attribute Name'] == "<unk>:":
+                assert False, "Unknown allocation entry has been encountered"
 
 
-@pytest.mark.skipif(is_cd_run(), reason="flaky test - open issue #18564")
-@pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/18564')
 def test_gpu_memory_profiler_gluon():
-    enable_profiler(profile_filename='test_profiler.json',
-                    run=True, continuous_dump=True)
+    enable_profiler(profile_filename='test_profiler.json')
     profiler.set_state('run')
 
     model = nn.HybridSequential()
@@ -117,31 +129,64 @@ def test_gpu_memory_profiler_gluon():
     profiler.set_state('stop')
     profiler.dump(True)
 
+    # Sample gpu_memory_profile.csv:
+    # "Attribute Name","Requested Size","Device","Actual Size","Reuse?"
+    # <unk>:in_arg:data,640,0,4096,0
+    # 
hybridsequential:activation0:hybridsequential_activation0_fwd,2048,0,4096,0
+    # 
hybridsequential:activation0:hybridsequential_activation0_fwd_backward,8192,0,8192,0
+    # 
hybridsequential:activation0:hybridsequential_activation0_fwd_head_grad,2048,0,4096,0
+    # 
hybridsequential:dense0:activation0:hybridsequential_dense0_activation0_fwd,8192,0,8192,0
+    # hybridsequential:dense0:arg_grad:bias,512,0,4096,0
+    # hybridsequential:dense0:arg_grad:weight,5120,0,8192,0
+    # hybridsequential:dense0:hybridsequential_dense0_fwd,8192,0,8192,0
+    # hybridsequential:dense0:in_arg:bias,512,0,4096,0
+    # hybridsequential:dense0:in_arg:weight,5120,0,8192,0
+    # 
hybridsequential:dense1:activation0:hybridsequential_dense1_activation0_fwd,4096,0,4096,0
+    # hybridsequential:dense1:arg_grad:bias,256,0,4096,0
+    # hybridsequential:dense1:arg_grad:weight,32768,0,32768,0
+    # hybridsequential:dense1:hybridsequential_dense1_fwd,4096,0,4096,0
+    # hybridsequential:dense1:in_arg:bias,256,0,4096,0
+    # hybridsequential:dense1:in_arg:weight,32768,0,32768,0
+    # hybridsequential:dense2:arg_grad:bias,128,0,4096,0
+    # hybridsequential:dense2:arg_grad:weight,8192,0,8192,0
+    # 
hybridsequential:dense2:hybridsequential_dense2_fwd_backward,4096,0,4096,1
+    # hybridsequential:dense2:in_arg:bias,128,0,4096,0
+    # hybridsequential:dense2:in_arg:weight,8192,0,8192,0
+    # hybridsequential:dropout0:hybridsequential_dropout0_fwd,8192,0,8192,0
+    # hybridsequential:dropout0:hybridsequential_dropout0_fwd,8192,0,8192,0
+    # resource:cudnn_dropout_state (dropout-inl.h +256),1474560,0,1474560,0
+    # resource:temp_space (fully_connected-inl.h +316),15360,0,16384,0
+
     # We are only checking for weight parameters here, also making sure that
     # there is no unknown entries in the memory profile.
     with open('gpu_memory_profile-pid_%d.csv' % (os.getpid()), mode='r') as 
csv_file:
         csv_reader = csv.DictReader(csv_file)
+        # TODO: Remove this print statement later on.
         for row in csv_reader:
             print(",".join(list(row.values())))
-        for scope in ['in_arg', 'arg_grad']:
-            for key, nd in model.collect_params().items():
-                expected_arg_name = "%s:%s:" % (model.name, scope) + nd.name
-                expected_arg_size = str(4 * np.prod(nd.shape))
-                csv_file.seek(0)
-                entry_found = False
-                for row in csv_reader:
-                    if row['Attribute Name'] == expected_arg_name:
-                        assert row['Requested Size'] == expected_arg_size, \
-                            "requested size={} is not equal to the expected 
size={}" \
-                            .format(row['Requested Size'], expected_arg_size)
-                        entry_found = True
-                        break
-                assert entry_found, \
-                    "Entry for attr_name={} has not been found" \
-                    .format(expected_arg_name)
+        for param in model.collect_params().values():
+            expected_arg_name = "%sin_arg:" % 
param.var().attr('__profiler_scope__') + \
+                                param.name
+            expected_arg_size = str(4 * np.prod(param.shape))
+            csv_file.seek(0)
+            entry_found = False
+            for row in csv_reader:
+                if row['Attribute Name'] == expected_arg_name and \
+                   row['Requested Size'] == expected_arg_size:
+                    entry_found = True
+                    break
+            assert entry_found, \
+                    "Entry for (attr_name={}, alloc_size={}) has not been 
found" \
+                        .format(expected_arg_name,
+                                expected_arg_size)
         # Make sure that there is no unknown allocation entry.
         csv_file.seek(0)
         for row in csv_reader:
             if row['Attribute Name'] == "<unk>:unknown" or \
                row['Attribute Name'] == "<unk>:":
                 assert False, "Unknown allocation entry has been encountered"
+
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()

Reply via email to