This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b929892  add naming tutorial (#10511)
b929892 is described below

commit b9298924043b9c8316613f5bdf9bc467cb140197
Author: Eric Junyuan Xie <piiswr...@users.noreply.github.com>
AuthorDate: Fri Apr 13 10:06:39 2018 -0700

    add naming tutorial (#10511)
    
    * add naming tutorial
    
    * fix
    
    * Update naming.md
    
    * Update index.md
    
    * fix save load
    
    * fix
    
    * fix
    
    * fix
    
    * fix
---
 docs/tutorials/gluon/datasets.md                |  14 +-
 docs/tutorials/gluon/gluon.md                   |   2 +-
 docs/tutorials/gluon/hybrid.md                  |   4 +-
 docs/tutorials/gluon/mnist.md                   |   5 +-
 docs/tutorials/gluon/naming.md                  | 255 ++++++++++++++++++++++++
 docs/tutorials/index.md                         |   2 +
 docs/tutorials/onnx/fine_tuning_gluon.md        |  24 +--
 example/gluon/embedding_learning/train.py       |   2 +-
 example/gluon/kaggle_k_fold_cross_validation.py |   2 +-
 example/gluon/learning_rate_manipulation.py     |   6 +-
 example/gluon/lstm_crf.py                       |   2 +-
 example/gluon/style_transfer/main.py            |  16 +-
 example/gluon/super_resolution.py               |   2 +-
 example/gluon/tree_lstm/main.py                 |   4 +-
 python/mxnet/gluon/block.py                     | 132 ++++++++----
 python/mxnet/gluon/contrib/nn/basic_layers.py   |   4 +-
 python/mxnet/gluon/nn/basic_layers.py           |  21 +-
 python/mxnet/gluon/parameter.py                 |  37 ++--
 python/mxnet/gluon/rnn/rnn_cell.py              |  28 +--
 python/mxnet/gluon/utils.py                     |   7 +
 tests/python/unittest/test_gluon.py             |  14 +-
 21 files changed, 451 insertions(+), 132 deletions(-)

diff --git a/docs/tutorials/gluon/datasets.md b/docs/tutorials/gluon/datasets.md
index 248ea02..0c9b537 100644
--- a/docs/tutorials/gluon/datasets.md
+++ b/docs/tutorials/gluon/datasets.md
@@ -33,7 +33,7 @@ print(sample)
 
     (
      [ 0.4375872   0.29753461  0.89177299]
-     <NDArray 3 @cpu(0)>, 
+     <NDArray 3 @cpu(0)>,
      [ 0.83261985]
      <NDArray 1 @cpu(0)>)
 
@@ -60,7 +60,7 @@ for X_batch, y_batch in data_loader:
     X_batch has shape (5, 3), and y_batch has shape (5, 1)
 
 
-We can see 2 mini-batches of data (and labels), each with 5 samples, which 
makes sense given we started with a dataset of 10 samples. When comparing the 
shape of the batches to the samples returned by the 
[`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset),
 we've gained an extra dimension at the start which is sometimes called the 
batch axis. 
+We can see 2 mini-batches of data (and labels), each with 5 samples, which 
makes sense given we started with a dataset of 10 samples. When comparing the 
shape of the batches to the samples returned by the 
[`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset),
 we've gained an extra dimension at the start which is sometimes called the 
batch axis.
 
 Our `data_loader` loop will stop when every sample of `dataset` has been 
returned as part of a batch. Sometimes the dataset length isn't divisible by 
the mini-batch size, leaving a final batch with a smaller number of samples. 
[`DataLoader`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataloader#mxnet.gluon.data.DataLoader)'s
 default behavior is to return this smaller mini-batch, but this can be changed 
by setting the `last_batch` parameter to `discard` (which [...]
 
@@ -137,7 +137,7 @@ def construct_net():
 ctx = mx.cpu()
 net = construct_net()
 net.hybridize()
-net.collect_params().initialize(mx.init.Xavier())
+net.initialize(mx.init.Xavier())
 # define loss and trainer.
 criterion = gluon.loss.SoftmaxCrossEntropyLoss()
 trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
@@ -159,7 +159,7 @@ for epoch in range(epochs):
         cumulative_train_loss += loss.sum()
         training_samples += data.shape[0]
     train_loss = cumulative_train_loss.asscalar()/training_samples
-        
+
     # validation loop
     cumulative_valid_loss = mx.nd.array([0])
     valid_samples = 0
@@ -171,7 +171,7 @@ for epoch in range(epochs):
         cumulative_valid_loss += loss.sum()
         valid_samples += data.shape[0]
     valid_loss = cumulative_valid_loss.asscalar()/valid_samples
-        
+
     print("Epoch {}, training loss: {:.2f}, validation loss: 
{:.2f}".format(epoch, train_loss, valid_loss))
 ```
 
@@ -184,7 +184,7 @@ for epoch in range(epochs):
 
 # Using own data with included `Dataset`s
 
-Gluon has a number of different 
[`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset)
 classes for working with your own image data straight out-of-the-box. You can 
get started quickly using the 
[`mxnet.gluon.data.vision.datasets.ImageFolderDataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=imagefolderdataset#mxnet.gluon.data.vision.datasets.ImageFolderDataset)
 which loads images directly from a [...]
+Gluon has a number of different 
[`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset)
 classes for working with your own image data straight out-of-the-box. You can 
get started quickly using the 
[`mxnet.gluon.data.vision.datasets.ImageFolderDataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=imagefolderdataset#mxnet.gluon.data.vision.datasets.ImageFolderDataset)
 which loads images directly from a [...]
 
 We will run through an example for image classification, but a similar process 
applies for other vision tasks. If you already have your own collection of 
images to work with you should partition your data into training and test sets, 
and place all objects of the same class into seperate folders. Similar to:
 
@@ -307,4 +307,4 @@ data_iter_loader = DataIterLoader(data_iter)
 for X_batch, y_batch in data_iter_loader:
     assert X_batch.shape == (5, 3)
     assert y_batch.shape == (5, 1)
-```
\ No newline at end of file
+```
diff --git a/docs/tutorials/gluon/gluon.md b/docs/tutorials/gluon/gluon.md
index a1688ea..518e999 100644
--- a/docs/tutorials/gluon/gluon.md
+++ b/docs/tutorials/gluon/gluon.md
@@ -70,7 +70,7 @@ A network must be created and initialized before it can be 
used:
 net = Net()
 # Initialize on CPU. Replace with `mx.gpu(0)`, or `[mx.gpu(0), mx.gpu(1)]`,
 # etc to use one or more GPUs.
-net.collect_params().initialize(mx.init.Xavier(), ctx=mx.cpu())
+net.initialize(mx.init.Xavier(), ctx=mx.cpu())
 ```
 
 Note that because we didn't specify input size to layers in Net's constructor,
diff --git a/docs/tutorials/gluon/hybrid.md b/docs/tutorials/gluon/hybrid.md
index 859ad93..3554a15 100644
--- a/docs/tutorials/gluon/hybrid.md
+++ b/docs/tutorials/gluon/hybrid.md
@@ -77,7 +77,7 @@ is called, its `hybrid_forward` will be run:
 
 ```python
 net = Net()
-net.collect_params().initialize()
+net.initialize()
 x = mx.nd.random_normal(shape=(16, 1, 28, 28))
 net(x)
 x = mx.nd.random_normal(shape=(16, 1, 28, 28))
@@ -117,7 +117,7 @@ x = mx.sym.var('data')
 y = net(x)
 print(y)
 y.save('model.json')
-net.collect_params().save('model.params')
+net.save_params('model.params')
 ```
 
 If your network outputs more than one value, you can use `mx.sym.Group` to
diff --git a/docs/tutorials/gluon/mnist.md b/docs/tutorials/gluon/mnist.md
index 86c493b..3a2a2cb 100644
--- a/docs/tutorials/gluon/mnist.md
+++ b/docs/tutorials/gluon/mnist.md
@@ -102,7 +102,7 @@ initialized parameters.
 ```python
 gpus = mx.test_utils.list_gpus()
 ctx =  [mx.gpu()] if gpus else [mx.cpu(0), mx.cpu(1)]
-net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
+net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
 trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.02})
 ```
 
@@ -252,10 +252,9 @@ Training and prediction can be done in the similar way as 
we did for MLP.
 We will initialize the network parameters as follows:
 
 ```python
-
 # set the context on GPU is available otherwise CPU
 ctx = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
-net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
+net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
 trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})
 ```
 
diff --git a/docs/tutorials/gluon/naming.md b/docs/tutorials/gluon/naming.md
new file mode 100644
index 0000000..37b63fa
--- /dev/null
+++ b/docs/tutorials/gluon/naming.md
@@ -0,0 +1,255 @@
+
+# Naming of Gluon Parameters and Blocks
+
+In Gluon, each Parameter or Block has a name (and prefix). Parameter names are 
specified by users and Block names can be either specified by users or 
automatically created.
+
+In this tutorial we talk about the best practices for naming. First, let's 
import MXNet and Gluon:
+
+
+```python
+from __future__ import print_function
+import mxnet as mx
+from mxnet import gluon
+```
+
+## Naming Blocks
+
+When creating a block, you can assign a prefix to it:
+
+
+```python
+mydense = gluon.nn.Dense(100, prefix='mydense_')
+print(mydense.prefix)
+```
+
+    mydense_
+
+
+When no prefix is given, Gluon will automatically generate one:
+
+
+```python
+dense0 = gluon.nn.Dense(100)
+print(dense0.prefix)
+```
+
+    dense0_
+
+
+When you create more Blocks of the same kind, they will be named with 
incrementing suffixes to avoid collision:
+
+
+```python
+dense1 = gluon.nn.Dense(100)
+print(dense1.prefix)
+```
+
+    dense1_
+
+
+## Naming Parameters
+
+Parameters within a Block will be named by prepending the prefix of the Block 
to the name of the Parameter:
+
+
+```python
+print(dense0.collect_params())
+```
+
+    dense0_ (
+      Parameter dense0_weight (shape=(100, 0), dtype=<type 'numpy.float32'>)
+      Parameter dense0_bias (shape=(100,), dtype=<type 'numpy.float32'>)
+    )
+
+
+## Name scopes
+
+To manage the names of nested Blocks, each Block has a `name_scope` attached 
to it. All Blocks created within a name scope will have their parent Block's 
prefix prepended to their names.
+
+Let's demonstrate this by first defining a simple neural net:
+
+
+```python
+class Model(gluon.Block):
+    def __init__(self, **kwargs):
+        super(Model, self).__init__(**kwargs)
+        with self.name_scope():
+            self.dense0 = gluon.nn.Dense(20)
+            self.dense1 = gluon.nn.Dense(20)
+            self.mydense = gluon.nn.Dense(20, prefix='mydense_')
+
+    def forward(self, x):
+        x = mx.nd.relu(self.dense0(x))
+        x = mx.nd.relu(self.dense1(x))
+        return mx.nd.relu(self.mydense(x))
+```
+
+Now let's instantiate our neural net.
+
+- Note that `model0.dense0` is named as `model0_dense0_` instead of `dense0_`.
+
+- Also note that although we specified `mydense_` as prefix for 
`model0.mydense`, its parent's prefix is automatically prepended to generate the 
prefix `model0_mydense_`.
+
+
+```python
+model0 = Model()
+model0.initialize()
+model0(mx.nd.zeros((1, 20)))
+print(model0.prefix)
+print(model0.dense0.prefix)
+print(model0.dense1.prefix)
+print(model0.mydense.prefix)
+```
+
+    model0_
+    model0_dense0_
+    model0_dense1_
+    model0_mydense_
+
+
+If we instantiate `Model` again, it will be given a different name, as shown 
before for `Dense`.
+
+- Note that `model1.dense0` is named `model1_dense0_` rather than 
`model1_dense2_`; its numbering does not continue from the dense layers in the 
previously created `model0`. This is because each model instance's name scope 
is independent of the others.
+
+
+```python
+model1 = Model()
+print(model1.prefix)
+print(model1.dense0.prefix)
+print(model1.dense1.prefix)
+print(model1.mydense.prefix)
+```
+
+    model1_
+    model1_dense0_
+    model1_dense1_
+    model1_mydense_
+
+
+**It is recommended that you manually specify a prefix for the top-level 
Block, i.e. `model = Model(prefix='mymodel_')`, to avoid potential confusion 
in naming.**
+
+The same principle also applies to container blocks like Sequential. 
`name_scope` can be used inside `__init__` as well as outside of `__init__`:
+
+
+```python
+net = gluon.nn.Sequential()
+with net.name_scope():
+    net.add(gluon.nn.Dense(20))
+    net.add(gluon.nn.Dense(20))
+print(net.prefix)
+print(net[0].prefix)
+print(net[1].prefix)
+```
+
+    sequential0_
+    sequential0_dense0_
+    sequential0_dense1_
+
+
+`gluon.model_zoo` also behaves similarly:
+
+
+```python
+net = gluon.nn.Sequential()
+with net.name_scope():
+    net.add(gluon.model_zoo.vision.alexnet(pretrained=True))
+    net.add(gluon.model_zoo.vision.alexnet(pretrained=True))
+print(net.prefix, net[0].prefix, net[1].prefix)
+```
+
+    sequential1_ sequential1_alexnet0_ sequential1_alexnet1_
+
+
+## Saving and loading
+
+Because model0 and model1 have different prefixes, their parameters also have 
different names:
+
+
+```python
+print(model0.collect_params(), '\n')
+print(model1.collect_params())
+```
+
+    model0_ (
+      Parameter model0_dense0_weight (shape=(20L, 20L), dtype=<type 
'numpy.float32'>)
+      Parameter model0_dense0_bias (shape=(20L,), dtype=<type 'numpy.float32'>)
+      Parameter model0_dense1_weight (shape=(20L, 20L), dtype=<type 
'numpy.float32'>)
+      Parameter model0_dense1_bias (shape=(20L,), dtype=<type 'numpy.float32'>)
+      Parameter model0_mydense_weight (shape=(20L, 20L), dtype=<type 
'numpy.float32'>)
+      Parameter model0_mydense_bias (shape=(20L,), dtype=<type 
'numpy.float32'>)
+    ) 
+    
+    model1_ (
+      Parameter model1_dense0_weight (shape=(20, 0), dtype=<type 
'numpy.float32'>)
+      Parameter model1_dense0_bias (shape=(20,), dtype=<type 'numpy.float32'>)
+      Parameter model1_dense1_weight (shape=(20, 0), dtype=<type 
'numpy.float32'>)
+      Parameter model1_dense1_bias (shape=(20,), dtype=<type 'numpy.float32'>)
+      Parameter model1_mydense_weight (shape=(20, 0), dtype=<type 
'numpy.float32'>)
+      Parameter model1_mydense_bias (shape=(20,), dtype=<type 'numpy.float32'>)
+    )
+
+
+As a result, if you try to save parameters from model0 and load them into 
model1, you'll get an error due to mismatched names:
+
+
+```python
+model0.collect_params().save('model.params')
+try:
+    model1.collect_params().load('model.params', mx.cpu())
+except Exception as e:
+    print(e)
+```
+
+    Parameter 'model1_dense0_weight' is missing in file 'model.params', which 
contains parameters: 'model0_mydense_weight', 'model0_dense1_bias', 
'model0_dense1_weight', 'model0_dense0_weight', 'model0_dense0_bias', 
'model0_mydense_bias'. Please make sure source and target networks have the 
same prefix.
+
+
+To solve this problem, we use `save_params`/`load_params` instead of 
`collect_params().save`/`collect_params().load`. `save_params` uses the model 
structure, instead of parameter names, to match parameters.
+
+
+```python
+model0.save_params('model.params')
+model1.load_params('model.params')
+print(mx.nd.load('model.params').keys())
+```
+
+    ['dense0.bias', 'mydense.bias', 'dense1.bias', 'dense1.weight', 
'dense0.weight', 'mydense.weight']
+
+
+## Replacing Blocks from networks and fine-tuning
+
+Sometimes you may want to load a pretrained model, and replace certain Blocks 
in it for fine-tuning.
+
+For example, the AlexNet in the model zoo has 1000 output dimensions, but maybe 
you only have 100 classes in your application.
+
+To see how to do this, we first load a pretrained AlexNet.
+
+- In the Gluon model zoo, all image classification models follow the format where 
the feature extraction layers are named `features` while the output layer is 
named `output`.
+- Note that the output layer is a dense block with 1000-dimensional outputs.
+
+
+```python
+alexnet = gluon.model_zoo.vision.alexnet(pretrained=True)
+print(alexnet.output)
+print(alexnet.output.prefix)
+```
+
+    Dense(4096 -> 1000, linear)
+    alexnet0_dense2_
+
+
+To change the output to 100 dimensions, we replace it with a new block.
+
+
+```python
+with alexnet.name_scope():
+    alexnet.output = gluon.nn.Dense(100)
+alexnet.output.initialize()
+print(alexnet.output)
+print(alexnet.output.prefix)
+```
+
+    Dense(None -> 100, linear)
+    alexnet0_dense3_
+
+
+<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 00a1504..04b7893 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -100,6 +100,8 @@ The Gluon and Module tutorials are in Python, but you can 
also find a variety of
 
 - [Designing a custom layer with 
gluon](http://gluon.mxnet.io/chapter03_deep-neural-networks/custom-layer.html)
 
+- [Block and Parameter naming](/tutorials/gluon/naming.html)
+
 - [Fast, portable neural networks with Gluon 
HybridBlocks](http://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html)
 
 - [Training on multiple GPUs with 
gluon](http://gluon.mxnet.io/chapter07_distributed-learning/multiple-gpus-gluon.html)
diff --git a/docs/tutorials/onnx/fine_tuning_gluon.md 
b/docs/tutorials/onnx/fine_tuning_gluon.md
index 4116ff6..c301542 100644
--- a/docs/tutorials/onnx/fine_tuning_gluon.md
+++ b/docs/tutorials/onnx/fine_tuning_gluon.md
@@ -7,7 +7,7 @@ Fine-tuning is a common practice in Transfer Learning. One can 
take advantage of
 [Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides 
an open source format for AI models. It defines an extensible computation graph 
model, as well as definitions of built-in operators and standard data types.
 
 In this tutorial we will:
-    
+
 - learn how to pick a specific layer from a pre-trained .onnx model file
 - learn how to load this model in Gluon and fine-tune it on a different dataset
 
@@ -63,7 +63,7 @@ We download a pre-trained model, in our case the 
[vgg16](https://arxiv.org/abs/1
 
 
 ```python
-base_url = "https://s3.amazonaws.com/download.onnx/models/"; 
+base_url = "https://s3.amazonaws.com/download.onnx/models/";
 current_model = "vgg16"
 model_folder = "model"
 archive_file = "{}.tar.gz".format(current_model)
@@ -135,7 +135,7 @@ We transform the dataset images using the following 
operations:
 def transform(image, label):
     resized = mx.image.resize_short(image, EDGE)
     cropped, crop_info = mx.image.center_crop(resized, SIZE)
-    transposed = nd.transpose(cropped, (2,0,1)) 
+    transposed = nd.transpose(cropped, (2,0,1))
     return transposed, label
 ```
 
@@ -162,7 +162,7 @@ We use num_workers=Number of CPU cores, which means the 
dataloading and pre-proc
 ```python
 dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, 
last_batch='discard',
                               shuffle=True, num_workers=NUM_WORKERS)
-dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, 
last_batch='discard', 
+dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, 
last_batch='discard',
                              shuffle=True, num_workers=NUM_WORKERS)
 print("Train dataset: {} images, Test dataset: {} 
images".format(len(dataset_train), len(dataset_test)))
 ```
@@ -274,7 +274,7 @@ We create the new dense layer with the right new number of 
classes (101) and ini
 
 ```python
 dense_layer = gluon.nn.Dense(NUM_CLASSES)
-dense_layer.collect_params().initialize(mx.init.Xavier(magnitude=2.24), 
ctx=ctx)
+dense_layer.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
 ```
 
 We add the SymbolBlock and the new dense layer to a HybridSequential network
@@ -309,8 +309,8 @@ The trainer will retrain and fine-tune the entire network. 
If we use `dense_laye
 
 ```python
 trainer = gluon.Trainer(net.collect_params(), 'sgd', 
-                        {'learning_rate': LEARNING_RATE, 
-                         'wd':WDECAY, 
+                        {'learning_rate': LEARNING_RATE,
+                         'wd':WDECAY,
                          'momentum':MOMENTUM})
 ```
 
@@ -353,20 +353,20 @@ for epoch in range(20):
     for i, (data, label) in enumerate(dataloader_train):
         data = data.astype(np.float32).as_in_context(ctx)
         label = label.as_in_context(ctx)
-        
+
         if i%20==0 and i >0:
             print('Batch [{0}] loss: {1:.4f}'.format(i, 
loss.mean().asscalar()))
-        
+
         with autograd.record():
             output = net(data)
             loss = softmax_cross_entropy(output, label)
         loss.backward()
         trainer.step(data.shape[0])
-    
+
     nd.waitall() # wait at the end of the epoch    
     new_val_accuracy = evaluate_accuracy_gluon(dataloader_test, net)    
     print("Epoch [{0}] Test Accuracy {1:.4f} ".format(epoch, new_val_accuracy))
-    
+
     # We perform early-stopping regularization, to prevent the model from 
overfitting
     if val_accuracy > new_val_accuracy:
         print('Validation accuracy is decreasing, stopping training')
@@ -385,7 +385,7 @@ Let's see if our network fine-tuned on Caltech101 is up for 
the task:
 
 ```python
 # Number of predictions to show
-TOP_P = 3 
+TOP_P = 3
 ```
 
 
diff --git a/example/gluon/embedding_learning/train.py 
b/example/gluon/embedding_learning/train.py
index 269caff..46f76b5 100644
--- a/example/gluon/embedding_learning/train.py
+++ b/example/gluon/embedding_learning/train.py
@@ -246,7 +246,7 @@ def train(epochs, ctx):
         if val_accs[0] > best_val:
             best_val = val_accs[0]
             logging.info('Saving %s.' % opt.save_model_prefix)
-            net.collect_params().save('%s.params' % opt.save_model_prefix)
+            net.save_params('%s.params' % opt.save_model_prefix)
     return best_val
 
 
diff --git a/example/gluon/kaggle_k_fold_cross_validation.py 
b/example/gluon/kaggle_k_fold_cross_validation.py
index 7911e4d..420e6fc 100644
--- a/example/gluon/kaggle_k_fold_cross_validation.py
+++ b/example/gluon/kaggle_k_fold_cross_validation.py
@@ -88,7 +88,7 @@ def train(net, X_train, y_train, epochs, verbose_epoch, 
learning_rate,
     trainer = gluon.Trainer(net.collect_params(), 'adam',
                             {'learning_rate': learning_rate,
                              'wd': weight_decay})
-    net.collect_params().initialize(force_reinit=True)
+    net.initialize(force_reinit=True)
     for epoch in range(epochs):
         for data, label in data_iter_train:
             with autograd.record():
diff --git a/example/gluon/learning_rate_manipulation.py 
b/example/gluon/learning_rate_manipulation.py
index 1523102..be1ffc2 100644
--- a/example/gluon/learning_rate_manipulation.py
+++ b/example/gluon/learning_rate_manipulation.py
@@ -32,13 +32,13 @@ Y = 2 * X[:, 0] - 3.4 * X[:, 1] + 4.2 + .01 * 
np.random.normal(size=10000)
 net = gluon.nn.Sequential()
 # The output dimension is 1.
 net.add(gluon.nn.Dense(1))
-net.collect_params().initialize()
+net.initialize()
 loss = gluon.loss.L2Loss()
 
 # Initialize the learning rate as 0.1.
 trainer = gluon.Trainer(net.collect_params(), 'sgd',
                         optimizer_params={'learning_rate': 0.1})
-net.collect_params().initialize(mx.init.Xavier(magnitude=2.24),
+net.initialize(mx.init.Xavier(magnitude=2.24),
                                 force_reinit=True)
 train_data = mx.io.NDArrayIter(X, Y, batch_size=10, shuffle=True)
 
@@ -60,4 +60,4 @@ for epoch in range(5):
 
 for para_name, para_value in net.collect_params().items():
     # Print all the parameter values after training.
-    print(para_name, para_value.data().asnumpy()[0])
\ No newline at end of file
+    print(para_name, para_value.data().asnumpy()[0])
diff --git a/example/gluon/lstm_crf.py b/example/gluon/lstm_crf.py
index 857bfca..561b4c2 100644
--- a/example/gluon/lstm_crf.py
+++ b/example/gluon/lstm_crf.py
@@ -197,7 +197,7 @@ for sentence, tags in training_data:
 tag2idx = {"B": 0, "I": 1, "O": 2, START_TAG: 3, STOP_TAG: 4}
 
 model = BiLSTM_CRF(len(word2idx), tag2idx, EMBEDDING_DIM, HIDDEN_DIM)
-model.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu())
+model.initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu())
 optimizer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': 
0.01, 'wd': 1e-4})
 
 # Check predictions before training
diff --git a/example/gluon/style_transfer/main.py 
b/example/gluon/style_transfer/main.py
index fa21a36..7fcc927 100644
--- a/example/gluon/style_transfer/main.py
+++ b/example/gluon/style_transfer/main.py
@@ -54,7 +54,7 @@ def train(args):
     style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
     if args.resume is not None:
         print('Resuming, initializing using weight from 
{}.'.format(args.resume))
-        style_model.collect_params().load(args.resume, ctx=ctx)
+        style_model.load_params(args.resume, ctx=ctx)
     print('style_model:',style_model)
     # optimizer and loss
     trainer = gluon.Trainer(style_model.collect_params(), 'adam',
@@ -96,7 +96,7 @@ def train(args):
 
                 total_loss = content_loss + style_loss
                 total_loss.backward()
-                
+
             trainer.step(args.batch_size)
             mx.nd.waitall()
 
@@ -112,20 +112,20 @@ def train(args):
                 )
                 print(mesg)
 
-            
+
             if (batch_id + 1) % (4 * args.log_interval) == 0:
                 # save model
                 save_model_filename = "Epoch_" + str(e) + "iters_" + 
str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                     args.content_weight) + "_" + str(args.style_weight) + 
".params"
                 save_model_path = os.path.join(args.save_model_dir, 
save_model_filename)
-                style_model.collect_params().save(save_model_path)
+                style_model.save_params(save_model_path)
                 print("\nCheckpoint, trained model saved at", save_model_path)
 
     # save model
     save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + 
str(time.ctime()).replace(' ', '_') + "_" + str(
         args.content_weight) + "_" + str(args.style_weight) + ".params"
     save_model_path = os.path.join(args.save_model_dir, save_model_filename)
-    style_model.collect_params().save(save_model_path)
+    style_model.save_params(save_model_path)
     print("\nDone, trained model saved at", save_model_path)
 
 
@@ -140,7 +140,7 @@ def evaluate(args):
     style_image = utils.preprocess_batch(style_image)
     # model
     style_model = net.Net(ngf=args.ngf)
-    style_model.collect_params().load(args.model, ctx=ctx)
+    style_model.load_params(args.model, ctx=ctx)
     # forward
     style_model.setTarget(style_image)
     output = style_model(content_image)
@@ -195,7 +195,7 @@ def optimize(args):
         trainer.step(1)
         if (e + 1) % args.log_interval == 0:
             print('loss:{:.2f}'.format(total_loss.asnumpy()[0]))
-        
+
     # save the image
     output = utils.add_imagenet_mean_batch(output.data())
     utils.tensor_save_bgrimage(output[0], args.output_image, args.cuda)
@@ -209,7 +209,7 @@ def main():
         raise ValueError("ERROR: specify the experiment type")
 
     if args.subcommand == "train":
-        # Training the model 
+        # Training the model
         train(args)
 
     elif args.subcommand == 'eval':
diff --git a/example/gluon/super_resolution.py 
b/example/gluon/super_resolution.py
index 7963590..38c3bec 100644
--- a/example/gluon/super_resolution.py
+++ b/example/gluon/super_resolution.py
@@ -144,7 +144,7 @@ def train(epoch, ctx):
         ctx = [ctx]
     net.initialize(mx.init.Orthogonal(), ctx=ctx)
     # re-initialize conv4's weight to be Orthogonal
-    net.conv4.collect_params().initialize(mx.init.Orthogonal(scale=1), 
force_reinit=True, ctx=ctx)
+    net.conv4.initialize(mx.init.Orthogonal(scale=1), force_reinit=True, 
ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 
opt.lr})
     loss = gluon.loss.L2Loss()
 
diff --git a/example/gluon/tree_lstm/main.py b/example/gluon/tree_lstm/main.py
index 67644f9..d2fe464 100644
--- a/example/gluon/tree_lstm/main.py
+++ b/example/gluon/tree_lstm/main.py
@@ -138,7 +138,7 @@ def test(ctx, data_iter, best, mode='validation', 
num_iter=-1):
         if test_r >= best:
             best = test_r
             logging.info('New optimum found: {}. Checkpointing.'.format(best))
-            
net.collect_params().save('childsum_tree_lstm_{}.params'.format(num_iter))
+            net.save_params('childsum_tree_lstm_{}.params'.format(num_iter))
             test(ctx, test_iter, -1, 'test')
         return best
 
@@ -148,7 +148,7 @@ def train(epoch, ctx, train_data, dev_data):
     # initialization with context
     if isinstance(ctx, mx.Context):
         ctx = [ctx]
-    net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx[0])
+    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx[0])
     net.embed.weight.set_data(vocab.embed.as_in_context(ctx[0]))
     train_data.set_context(ctx[0])
     dev_data.set_context(ctx[0])
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 6400358..2f8cdd8 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -23,14 +23,14 @@ __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 import copy
 import warnings
 import re
+from collections import OrderedDict
 
 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
 from ..ndarray import NDArray
 from .. import name as _name
-from ..context import cpu
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
-from .utils import _indent
+from .utils import _indent, _brief_print_list
 
 
 class _BlockScope(object):
@@ -134,7 +134,6 @@ class Block(object):
             def __init__(self, **kwargs):
                 super(Model, self).__init__(**kwargs)
                 # use name_scope to give child Blocks appropriate names.
-                # It also allows sharing Parameters between Blocks recursively.
                 with self.name_scope():
                     self.dense0 = nn.Dense(20)
                     self.dense1 = nn.Dense(20)
@@ -154,10 +153,11 @@ class Block(object):
     Parameters
     ----------
     prefix : str
-        Prefix acts like a name space. It will be prepended to the names of all
-        Parameters and child :py:class:`Block` s in this :py:class:`Block` 's
-        :py:meth:`name_scope` .
-        Prefix should be unique within one model to prevent name collisions.
+        Prefix acts like a name space. All child blocks created in the parent 
block's
+        :py:meth:`name_scope` will have the parent block's prefix in their names.
+        Please refer to
+        `naming tutorial 
<http://mxnet.incubator.apache.org/tutorials/gluon/naming.html>`_
+        for more info on prefix and naming.
     params : ParameterDict or None
         :py:class:`ParameterDict` for sharing weights with the new 
:py:class:`Block`. For example,
         if you want ``dense1`` to share ``dense0``'s weights, you can do::
@@ -170,15 +170,15 @@ class Block(object):
         self._prefix, self._params = _BlockScope.create(prefix, params, 
self._alias())
         self._name = self._prefix[:-1] if self._prefix.endswith('_') else 
self._prefix
         self._scope = _BlockScope(self)
-        self._children = []
+        self._children = OrderedDict()
+        self._reg_params = {}
 
     def __repr__(self):
         s = '{name}(\n{modstr}\n)'
         modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
                                                         
block=_indent(block.__repr__(), 2))
                             for key, block in self.__dict__.items() if 
isinstance(block, Block)])
-        return s.format(name=self.__class__.__name__,
-                        modstr=modstr)
+        return s.format(name=self.__class__.__name__, modstr=modstr)
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -187,17 +187,17 @@ class Block(object):
             existing = getattr(self, name)
             if isinstance(existing, (Parameter, Block)) and not 
isinstance(value, type(existing)):
                 raise TypeError('Changing attribute type for {name} from 
{type1} to {type2}' \
-                                'is not allowed.'.format(name=name,
-                                                         type1=type(existing),
-                                                         type2=type(value)))
-            if isinstance(existing, Block):
-                for i, c in enumerate(self._children):
-                    if c is existing:
-                        self._children[i] = value
-            elif isinstance(value, Block):
-                self.register_child(value)
-        elif isinstance(value, Block):
-            self.register_child(value)
+                                'is not allowed.'.format(
+                                    name=name, type1=type(existing), 
type2=type(value)))
+
+        if isinstance(value, Block):
+            self.register_child(value, name)
+        elif isinstance(value, Parameter):
+            assert name not in self._reg_params, \
+                "Overriding Parameter attribute %s is not allowed. " \
+                "If you want to share parameters between blocks, please set " \
+                "'params' at Block construction instead."
+            self._reg_params[name] = value
 
         super(Block, self).__setattr__(name, value)
 
@@ -247,6 +247,10 @@ class Block(object):
 
             with self.name_scope():
                 self.dense = nn.Dense(20)
+
+        Please refer to
+        `naming tutorial 
<http://mxnet.incubator.apache.org/tutorials/basic/naming.html>`_
+        for more info on prefix and naming.
         """
         return self._scope
 
@@ -288,19 +292,29 @@ class Block(object):
         else:
             pattern = re.compile(select)
             ret.update({name:value for name, value in self.params.items() if 
pattern.match(name)})
-        for cld in self._children:
+        for cld in self._children.values():
             ret.update(cld.collect_params(select=select))
         return ret
 
+    def _collect_params_with_prefix(self, prefix=''):
+        # Keys are '.'-joined attribute paths, e.g. '<child>.<param>'.
+        base = prefix + '.' if prefix else prefix
+        collected = {base + key: val for key, val in self._reg_params.items()}
+        for child_name, child in self._children.items():
+            collected.update(child._collect_params_with_prefix(base + child_name))
+        return collected
+
     def save_params(self, filename):
         """Save parameters to file.
 
         filename : str
             Path to file.
         """
-        self.collect_params().save(filename, strip_prefix=self.prefix)
+        params = self._collect_params_with_prefix()
+        arg_dict = {key : val._reduce() for key, val in params.items()}
+        ndarray.save(filename, arg_dict)
 
-    def load_params(self, filename, ctx=cpu(), allow_missing=False,
+    def load_params(self, filename, ctx=None, allow_missing=False,
                     ignore_extra=False):
         """Load parameters from file.
 
@@ -314,20 +328,58 @@ class Block(object):
             Whether to silently ignore parameters from the file that are not
             present in this Block.
         """
-        self.collect_params().load(filename, ctx, allow_missing, ignore_extra,
-                                   self.prefix)
+        loaded = ndarray.load(filename)
+        params = self._collect_params_with_prefix()
+        if not loaded and not params:
+            return
 
-    def register_child(self, block):
+        if not any('.' in i for i in loaded.keys()):
+            # legacy loading
+            del loaded
+            self.collect_params().load(
+                filename, ctx, allow_missing, ignore_extra, self.prefix)
+            return
+
+        if not allow_missing:
+            for name in params.keys():
+                assert name in loaded, \
+                    "Parameter '%s' is missing in file '%s', which contains 
parameters: %s. " \
+                    "Set allow_missing=True to ignore missing parameters."%(
+                        name, filename, _brief_print_list(loaded.keys()))
+        for name in loaded:
+            if not ignore_extra and name not in params:
+                raise ValueError(
+                    "Parameter '%s' loaded from file '%s' is not present in 
ParameterDict, " \
+                    "which contains parameters %s. Set ignore_extra=True to 
ignore. "%(
+                        name, filename, 
_brief_print_list(self._params.keys())))
+            params[name]._load_init(loaded[name], ctx)
+
+
+    def register_child(self, block, name=None):
         """Registers block as a child of self. :py:class:`Block` s assigned to 
self as
         attributes will be registered automatically."""
-        self._children.append(block)
+        if name is None:
+            name = str(len(self._children))
+        self._children[name] = block
 
-    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
+    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
+                   force_reinit=False):
         """Initializes :py:class:`Parameter` s of this :py:class:`Block` and 
its children.
-
         Equivalent to ``block.collect_params().initialize(...)``
+
+        Parameters
+        ----------
+        init : Initializer
+            Global default Initializer to be used when 
:py:meth:`Parameter.init` is ``None``.
+            Otherwise, :py:meth:`Parameter.init` takes precedence.
+        ctx : Context or list of Context
+            Keeps a copy of Parameters on one or many context(s).
+        verbose : bool, default False
+            Whether to verbosely print out details on initialization.
+        force_reinit : bool, default False
+            Whether to force re-initialization if parameter is already 
initialized.
         """
-        self.collect_params().initialize(init, ctx, verbose)
+        self.collect_params().initialize(init, ctx, verbose, force_reinit)
 
     def hybridize(self, active=True, **kwargs):
         """Activates or deactivates :py:class:`HybridBlock` s recursively. Has 
no effect on
@@ -340,7 +392,7 @@ class Block(object):
         **kwargs : string
             Additional flags for hybridized operator.
         """
-        for cld in self._children:
+        for cld in self._children.values():
             cld.hybridize(active, **kwargs)
 
     def cast(self, dtype):
@@ -351,7 +403,7 @@ class Block(object):
         dtype : str or numpy.dtype
             The new data type.
         """
-        for child in self._children:
+        for child in self._children.values():
             child.cast(dtype)
         for _, param in self.params.items():
             param.cast(dtype)
@@ -393,7 +445,6 @@ class HybridBlock(Block):
     """
     def __init__(self, prefix=None, params=None):
         super(HybridBlock, self).__init__(prefix=prefix, params=params)
-        self._reg_params = {}
         self._cached_graph = ()
         self._cached_op = None
         self._cached_op_args = None
@@ -407,13 +458,6 @@ class HybridBlock(Block):
         super(HybridBlock, self).__setattr__(name, value)
         if isinstance(value, HybridBlock):
             self._clear_cached_op()
-        if isinstance(value, Parameter):
-            assert name not in self._reg_params or \
-                not isinstance(self._reg_params[name], Parameter), \
-                "Overriding Parameter attribute %s is not allowed. " \
-                "Please pass in Parameters by specifying `params` at " \
-                "Block construction instead."
-            self._reg_params[name] = value
 
     def _get_graph(self, *args):
         if not self._cached_graph:
@@ -491,14 +535,14 @@ class HybridBlock(Block):
         self._cached_op = None
         self._cached_op_args = None
 
-    def register_child(self, block):
+    def register_child(self, block, name=None):
         if not isinstance(block, HybridBlock):
             raise ValueError(
                 "Children of HybridBlock must also be HybridBlock, " \
                 "but %s has type %s. If you are using Sequential, " \
-                "please try HybridSequential instead"%(
+                "please try HybridSequential instead."%(
                     str(block), str(type(block))))
-        super(HybridBlock, self).register_child(block)
+        super(HybridBlock, self).register_child(block, name)
         self._clear_cached_op()
 
     def hybridize(self, active=True, **kwargs):
diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py 
b/python/mxnet/gluon/contrib/nn/basic_layers.py
index 8870888..eccdf18 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -51,7 +51,7 @@ class Concurrent(Sequential):
 
     def forward(self, x):
         out = []
-        for block in self._children:
+        for block in self._children.values():
             out.append(block(x))
         out = nd.concat(*out, dim=self.axis)
         return out
@@ -84,7 +84,7 @@ class HybridConcurrent(HybridSequential):
 
     def hybrid_forward(self, F, x):
         out = []
-        for block in self._children:
+        for block in self._children.values():
             out.append(block(x))
         out = F.concat(*out, dim=self.axis)
         return out
diff --git a/python/mxnet/gluon/nn/basic_layers.py 
b/python/mxnet/gluon/nn/basic_layers.py
index f6113cc..efca0c3 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -49,7 +49,7 @@ class Sequential(Block):
             self.register_child(block)
 
     def forward(self, x):
-        for block in self._children:
+        for block in self._children.values():
             x = block(x)
         return x
 
@@ -57,13 +57,12 @@ class Sequential(Block):
         s = '{name}(\n{modstr}\n)'
         modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
                                                         
block=_indent(block.__repr__(), 2))
-                            for key, block in enumerate(self._children)
-                            if isinstance(block, Block)])
+                            for key, block in self._children.items()])
         return s.format(name=self.__class__.__name__,
                         modstr=modstr)
 
     def __getitem__(self, key):
-        return self._children[key]
+        return self._children[str(key)]
 
     def __len__(self):
         return len(self._children)
@@ -79,9 +78,10 @@ class Sequential(Block):
         **kwargs : string
             Additional flags for hybridized operator.
         """
-        if self._children and all(isinstance(c, HybridBlock) for c in 
self._children):
-            warnings.warn('All children of this Sequential layer are 
HybridBlocks. Consider ' \
-                          'using HybridSequential for the best performance.', 
stacklevel=2)
+        if self._children and all(isinstance(c, HybridBlock) for c in 
self._children.values()):
+            warnings.warn(
+                "All children of this Sequential layer '%s' are HybridBlocks. 
Consider "
+                "using HybridSequential for the best 
performance."%self.prefix, stacklevel=2)
         super(Sequential, self).hybridize(active, **kwargs)
 
 
@@ -106,7 +106,7 @@ class HybridSequential(HybridBlock):
             self.register_child(block)
 
     def hybrid_forward(self, F, x):
-        for block in self._children:
+        for block in self._children.values():
             x = block(x)
         return x
 
@@ -114,13 +114,12 @@ class HybridSequential(HybridBlock):
         s = '{name}(\n{modstr}\n)'
         modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
                                                         
block=_indent(block.__repr__(), 2))
-                            for key, block in enumerate(self._children)
-                            if isinstance(block, Block)])
+                            for key, block in self._children.items()])
         return s.format(name=self.__class__.__name__,
                         modstr=modstr)
 
     def __getitem__(self, key):
-        return self._children[key]
+        return self._children[str(key)]
 
     def __len__(self):
         return len(self._children)
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 8d0c5ba..ce82171 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -29,9 +29,9 @@ import numpy as np
 
 from ..base import mx_real_t, MXNetError
 from .. import symbol, ndarray, initializer, context
-from ..context import Context
+from ..context import Context, cpu
 from .. import autograd
-from .utils import _indent
+from .utils import _indent, _brief_print_list
 
 # pylint: disable= invalid-name
 tensor_types = (symbol.Symbol, ndarray.NDArray)
@@ -206,13 +206,16 @@ class Parameter(object):
             ctx = [ctx]
         if self._data is None:
             if self._deferred_init:
-                assert set(ctx) == set(self._deferred_init[1]), \
+                assert ctx is None or set(ctx) == set(self._deferred_init[1]), 
\
                     "Failed to load Parameter '%s' on %s because it was " \
                     "previous initialized on %s."%(
                         self.name, str(ctx), str(self.list_ctx()))
+                ctx = self._deferred_init[1]
+            elif ctx is None:
+                ctx = [cpu()]
             self._init_impl(data, ctx)
         else:
-            assert set(ctx) == set(self.list_ctx()), \
+            assert ctx is None or set(ctx) == set(self.list_ctx()), \
                 "Failed to load Parameter '%s' on %s because it was " \
                 "previous initialized on %s."%(
                     self.name, str(ctx), str(self.list_ctx()))
@@ -497,13 +500,9 @@ class Constant(Parameter):
             name, grad_req='null', shape=value.shape, dtype=value.dtype,
             init=init_name)
 
-
-def _brief_print_list(lst, limit=7):
-    """Print at most `limit` elements of list."""
-    if len(lst) > limit:
-        return _brief_print_list(lst[:limit//2], limit) + ', ..., ' + \
-            _brief_print_list(lst[-limit//2:], limit)
-    return ', '.join(["'%s'"%str(i) for i in lst])
+    def __repr__(self):
+        return 'Constant {name} (shape={shape}, dtype={dtype})'.format(
+            name=self.name, shape=self.shape, dtype=self.dtype)
 
 
 class ParameterDict(object):
@@ -677,6 +676,8 @@ class ParameterDict(object):
             Otherwise, :py:meth:`Parameter.init` takes precedence.
         ctx : Context or list of Context
             Keeps a copy of Parameters on one or many context(s).
+        verbose : bool, default False
+            Whether to verbosely print out details on initialization.
         force_reinit : bool, default False
             Whether to force re-initialization if parameter is already 
initialized.
         """
@@ -735,17 +736,17 @@ class ParameterDict(object):
             weight = param._reduce()
             if not param.name.startswith(strip_prefix):
                 raise ValueError(
-                    "Prefix '%s' is to be striped before saving, but Parameter 
" \
-                    "'%s' does not start with '%s'. If you are using 
Block.save_params, " \
-                    "This may be due to your Block shares parameters from 
other " \
-                    "Blocks or you forgot to use ``with name_scope()`` during 
init. " \
-                    "Consider switching to Block.collect_params.save and " \
-                    "Block.collect_params.load instead."%(
+                    "Prefix '%s' is to be striped before saving, but 
Parameter's "
+                    "name '%s' does not start with '%s'. "
+                    "this may be due to your Block shares parameters from 
other "
+                    "Blocks or you forgot to use 'with name_scope()' when 
creating "
+                    "child blocks. For more info on naming, please see "
+                    
"http://mxnet.incubator.apache.org/tutorials/basic/naming.html"%(
                         strip_prefix, param.name, strip_prefix))
             arg_dict[param.name[len(strip_prefix):]] = weight
         ndarray.save(filename, arg_dict)
 
-    def load(self, filename, ctx, allow_missing=False,
+    def load(self, filename, ctx=None, allow_missing=False,
              ignore_extra=False, restore_prefix=''):
         """Load parameters from file.
 
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py 
b/python/mxnet/gluon/rnn/rnn_cell.py
index f5c72f5..281aba4 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -124,7 +124,7 @@ class RecurrentCell(Block):
         """Reset before re-using the cell for another graph."""
         self._init_counter = -1
         self._counter = -1
-        for cell in self._children:
+        for cell in self._children.values():
             cell.reset()
 
     def state_info(self, batch_size=0):
@@ -639,7 +639,7 @@ class SequentialRNNCell(RecurrentCell):
         s = '{name}(\n{modstr}\n)'
         return s.format(name=self.__class__.__name__,
                         modstr='\n'.join(['({i}): {m}'.format(i=i, 
m=_indent(m.__repr__(), 2))
-                                          for i, m in 
enumerate(self._children)]))
+                                          for i, m in self._children.items()]))
 
     def add(self, cell):
         """Appends a cell into the stack.
@@ -652,19 +652,19 @@ class SequentialRNNCell(RecurrentCell):
         self.register_child(cell)
 
     def state_info(self, batch_size=0):
-        return _cells_state_info(self._children, batch_size)
+        return _cells_state_info(self._children.values(), batch_size)
 
     def begin_state(self, **kwargs):
         assert not self._modified, \
             "After applying modifier cells (e.g. ZoneoutCell) the base " \
             "cell cannot be called directly. Call the modifier cell instead."
-        return _cells_begin_state(self._children, **kwargs)
+        return _cells_begin_state(self._children.values(), **kwargs)
 
     def __call__(self, inputs, states):
         self._counter += 1
         next_states = []
         p = 0
-        for cell in self._children:
+        for cell in self._children.values():
             assert not isinstance(cell, BidirectionalCell)
             n = len(cell.state_info())
             state = states[p:p+n]
@@ -683,7 +683,7 @@ class SequentialRNNCell(RecurrentCell):
 
         p = 0
         next_states = []
-        for i, cell in enumerate(self._children):
+        for i, cell in enumerate(self._children.values()):
             n = len(cell.state_info())
             states = begin_state[p:p+n]
             p += n
@@ -696,7 +696,7 @@ class SequentialRNNCell(RecurrentCell):
         return inputs, next_states
 
     def __getitem__(self, i):
-        return self._children[i]
+        return self._children[str(i)]
 
     def __len__(self):
         return len(self._children)
@@ -900,8 +900,8 @@ class BidirectionalCell(HybridRecurrentCell):
     """
     def __init__(self, l_cell, r_cell, output_prefix='bi_'):
         super(BidirectionalCell, self).__init__(prefix='', params=None)
-        self.register_child(l_cell)
-        self.register_child(r_cell)
+        self.register_child(l_cell, 'l_cell')
+        self.register_child(r_cell, 'r_cell')
         self._output_prefix = output_prefix
 
     def __call__(self, inputs, states):
@@ -910,17 +910,17 @@ class BidirectionalCell(HybridRecurrentCell):
     def __repr__(self):
         s = '{name}(forward={l_cell}, backward={r_cell})'
         return s.format(name=self.__class__.__name__,
-                        l_cell=self._children[0],
-                        r_cell=self._children[1])
+                        l_cell=self._children['l_cell'],
+                        r_cell=self._children['r_cell'])
 
     def state_info(self, batch_size=0):
-        return _cells_state_info(self._children, batch_size)
+        return _cells_state_info(self._children.values(), batch_size)
 
     def begin_state(self, **kwargs):
         assert not self._modified, \
             "After applying modifier cells (e.g. DropoutCell) the base " \
             "cell cannot be called directly. Call the modifier cell instead."
-        return _cells_begin_state(self._children, **kwargs)
+        return _cells_begin_state(self._children.values(), **kwargs)
 
     def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None,
                valid_length=None):
@@ -938,7 +938,7 @@ class BidirectionalCell(HybridRecurrentCell):
         begin_state = _get_begin_state(self, F, begin_state, inputs, 
batch_size)
 
         states = begin_state
-        l_cell, r_cell = self._children
+        l_cell, r_cell = self._children.values()
         l_outputs, l_states = l_cell.unroll(length, inputs=inputs,
                                             
begin_state=states[:len(l_cell.state_info(batch_size))],
                                             layout=layout, 
merge_outputs=merge_outputs,
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index cb784b7..7dd2a1a 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -242,3 +242,10 @@ def _get_repo_file_url(namespace, filename):
     return '{base_url}{namespace}/{filename}'.format(base_url=_get_repo_url(),
                                                      namespace=namespace,
                                                      filename=filename)
+
+def _brief_print_list(lst, limit=7):
+    """Print at most `limit` elements of `lst`; views such as dict.keys() are accepted."""
+    if len(lst) > limit:
+        return _brief_print_list(list(lst)[:limit//2], limit) + ', ..., ' + \
+            _brief_print_list(list(lst)[-limit//2:], limit)
+    return ', '.join(["'%s'"%str(i) for i in lst])
diff --git a/tests/python/unittest/test_gluon.py 
b/tests/python/unittest/test_gluon.py
index d91b3f0..ca1e121 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -557,7 +557,7 @@ def test_block_attr_regular():
     b.c = gluon.Block()
     c2 = gluon.Block()
     b.c = c2
-    assert b.c is c2 and b._children[0] is c2
+    assert b.c is c2 and list(b._children.values())[0] is c2
 
 
 @with_seed()
@@ -589,18 +589,22 @@ def test_block_attr_list_of_block():
                 self.data = {'a': '4', 'b': 123}
 
     with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter('always')
         model = Model1()
         model.collect_params()
         assert len(w) > 0
     with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter('always')
         model = Model2()
         model.collect_params()
         assert len(w) > 0
     with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter('always')
         model = Model3()
         model.collect_params()
         assert len(w) == 0
     with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter('always')
         model = Model4()
         model.collect_params()
         assert len(w) == 0
@@ -882,6 +886,14 @@ def test_dropout():
         check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
 
 
+def test_save_load():
+    net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
+    net.save_params('test.params')  # NOTE(review): written to CWD and never removed
+    # Rebuild the same architecture without pretrained weights, then replace its `output` child.
+    net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
+    net.output = mx.gluon.nn.Dense(1000)
+    # load_params matches by attribute path ('output.weight', ...), presumably not by name prefix.
+    net.load_params('test.params')
 
 
 if __name__ == '__main__':

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to