This is an automated email from the ASF dual-hosted git repository.
indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 865255a fix cnn visualization tutorial (#12719)
865255a is described below
commit 865255a398223510d857820f3061a26065bcdb74
Author: Thomas Delteil <[email protected]>
AuthorDate: Wed Oct 10 11:31:14 2018 -0700
fix cnn visualization tutorial (#12719)
---
docs/tutorials/vision/cnn_visualization.md | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index ea027df..940c261 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -99,12 +99,18 @@ def get_vgg(num_layers, ctx=mx.cpu(), root=os.path.join('~', '.mxnet', 'models')
# Get the number of convolution layers and filters
layers, filters = vgg_spec[num_layers]
- # Build the VGG network
+ # Build the modified VGG network
net = VGG(layers, filters, **kwargs)
-
- # Load pretrained weights from model zoo
- from mxnet.gluon.model_zoo.model_store import get_model_file
- net.load_params(get_model_file('vgg%d' % num_layers, root=root), ctx=ctx)
+ net.initialize(ctx=ctx)
+
+ # Get the pretrained model
+ vgg = mx.gluon.model_zoo.vision.get_vgg(num_layers, pretrained=True, ctx=ctx)
+
+ # Set the parameters in the new network
+ params = vgg.collect_params()
+ for key in params:
+ param = params[key]
+ net.collect_params()[net.prefix+key.replace(vgg.prefix, '')].set_data(param.data())
return net