This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f14b2eb  fix print_summary bug and add groups of convolution (#9492)
f14b2eb is described below

commit f14b2ebcb5475b0cde205648d0b80e3378e3dbed
Author: chinakook <chinak...@msn.com>
AuthorDate: Tue Feb 20 04:11:43 2018 +0800

    fix print_summary bug and add groups of convolution (#9492)
    
    * fix print_summary bug and add groups of convolution
    
    1. fix the crash in "int(node["attrs"]["no_bias"])" when no_bias is stored as the string "True"
    2. take num_group into account when computing Convolution parameter counts
    
    * Update visualization.py
    
    lint
    
    * Update visualization.py
    
    * Update visualization.py
    
    * Update visualization.py
    
    * Update visualization.py
    
    * Update visualization.py
    
    * Update visualization.py
---
 python/mxnet/visualization.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 64ac77e..2b9da15 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -134,20 +134,24 @@ def print_summary(symbol, shape=None, line_length=120, 
positions=[.44, .64, .74,
                             pre_filter = pre_filter + int(shape[0])
         cur_param = 0
         if op == 'Convolution':
-            if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]):
-                cur_param = pre_filter * int(node["attrs"]["num_filter"])
+            if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 
'True':
+                num_group = int(node['attrs'].get('num_group', '1'))
+                cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
+                   // num_group
                 for k in _str2tuple(node["attrs"]["kernel"]):
                     cur_param *= int(k)
             else:
-                cur_param = pre_filter * int(node["attrs"]["num_filter"])
+                num_group = int(node['attrs'].get('num_group', '1'))
+                cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
+                   // num_group
                 for k in _str2tuple(node["attrs"]["kernel"]):
                     cur_param *= int(k)
                 cur_param += int(node["attrs"]["num_filter"])
         elif op == 'FullyConnected':
-            if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]):
-                cur_param = pre_filter * (int(node["attrs"]["num_hidden"]))
+            if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 
'True':
+                cur_param = pre_filter * int(node["attrs"]["num_hidden"])
             else:
-                cur_param = (pre_filter+1) * (int(node["attrs"]["num_hidden"]))
+                cur_param = (pre_filter+1) * int(node["attrs"]["num_hidden"])
         elif op == 'BatchNorm':
             key = node["name"] + "_output"
             if show_shape:

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to