This is an automated email from the ASF dual-hosted git repository.

indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 865255a  fix cnn visualization tutorial (#12719)
865255a is described below

commit 865255a398223510d857820f3061a26065bcdb74
Author: Thomas Delteil <[email protected]>
AuthorDate: Wed Oct 10 11:31:14 2018 -0700

    fix cnn visualization tutorial (#12719)
---
 docs/tutorials/vision/cnn_visualization.md | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index ea027df..940c261 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -99,12 +99,18 @@ def get_vgg(num_layers, ctx=mx.cpu(), root=os.path.join('~', '.mxnet', 'models')
     # Get the number of convolution layers and filters
     layers, filters = vgg_spec[num_layers]
 
-    # Build the VGG network
+    # Build the modified VGG network
     net = VGG(layers, filters, **kwargs)
-
-    # Load pretrained weights from model zoo
-    from mxnet.gluon.model_zoo.model_store import get_model_file
-    net.load_params(get_model_file('vgg%d' % num_layers, root=root), ctx=ctx)
+    net.initialize(ctx=ctx)
+    
+    # Get the pretrained model
+    vgg = mx.gluon.model_zoo.vision.get_vgg(num_layers, pretrained=True, 
ctx=ctx)
+    
+    # Set the parameters in the new network
+    params = vgg.collect_params()
+    for key in params:
+        param = params[key]
+        net.collect_params()[net.prefix+key.replace(vgg.prefix, 
'')].set_data(param.data())
 
     return net
 

Reply via email to