piiswrong closed pull request #11127: add import_ for SymbolBlock
URL: https://github.com/apache/incubator-mxnet/pull/11127
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
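In short: the PR deprecates `Block.save_params`/`Block.load_params` in favor of `save_parameters`/`load_parameters`, and adds a `SymbolBlock.imports` static method so models written out by `HybridBlock.export` can be loaded back into Gluon in one call. A minimal round trip with the new API, adapted from the docstring example further down in this diff (the file prefix 'net1' comes from that example):

    import mxnet as mx
    from mxnet import gluon

    # Build and hybridize a model so its graph can be exported.
    net1 = gluon.model_zoo.vision.resnet18_v1(prefix='resnet', pretrained=True)
    net1.hybridize()
    x = mx.nd.random.normal(shape=(1, 3, 32, 32))
    out1 = net1(x)

    # Writes net1-symbol.json and net1-0001.params.
    net1.export('net1', epoch=1)

    # Reconstructs the network (graph + weights) as a SymbolBlock.
    net2 = gluon.SymbolBlock.imports('net1-symbol.json', ['data'], 'net1-0001.params')
    out2 = net2(x)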
diff --git a/docs/tutorials/gluon/hybrid.md b/docs/tutorials/gluon/hybrid.md
index 3554a15fa3b..5c8372a51f4 100644
--- a/docs/tutorials/gluon/hybrid.md
+++ b/docs/tutorials/gluon/hybrid.md
@@ -117,7 +117,7 @@ x = mx.sym.var('data')
y = net(x)
print(y)
y.save('model.json')
-net.save_params('model.params')
+net.save_parameters('model.params')
```
If your network outputs more than one value, you can use `mx.sym.Group` to
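The context line above breaks off where the hunk ends: the tutorial goes on to show that `mx.sym.Group` combines several output symbols into one so the whole graph can be saved. A minimal sketch of that pattern (the two-output block here is illustrative, not part of the PR):

    import mxnet as mx
    from mxnet import gluon

    class TwoOutputs(gluon.HybridBlock):
        def __init__(self, **kwargs):
            super(TwoOutputs, self).__init__(**kwargs)
            with self.name_scope():
                self.fc1 = gluon.nn.Dense(4)
                self.fc2 = gluon.nn.Dense(2)

        def hybrid_forward(self, F, x):
            return self.fc1(x), self.fc2(x)

    net = TwoOutputs()
    net.initialize()
    net(mx.nd.ones((1, 8)))            # real forward pass to infer shapes

    y1, y2 = net(mx.sym.var('data'))   # symbolic forward pass
    mx.sym.Group([y1, y2]).save('model.json')
    net.save_parameters('model.params')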
diff --git a/docs/tutorials/gluon/naming.md b/docs/tutorials/gluon/naming.md
index 37b63fa08a9..3606a03dcbd 100644
--- a/docs/tutorials/gluon/naming.md
+++ b/docs/tutorials/gluon/naming.md
@@ -203,12 +203,12 @@ except Exception as e:
Parameter 'model1_dense0_weight' is missing in file 'model.params', which
contains parameters: 'model0_mydense_weight', 'model0_dense1_bias',
'model0_dense1_weight', 'model0_dense0_weight', 'model0_dense0_bias',
'model0_mydense_bias'. Please make sure source and target networks have the
same prefix.
-To solve this problem, we use `save_params`/`load_params` instead of `collect_params` and `save`/`load`. `save_params` uses model structure, instead of parameter name, to match parameters.
+To solve this problem, we use `save_parameters`/`load_parameters` instead of `collect_params` and `save`/`load`. `save_parameters` uses model structure, instead of parameter name, to match parameters.
```python
-model0.save_params('model.params')
-model1.load_params('model.params')
+model0.save_parameters('model.params')
+model1.load_parameters('model.params')
print(mx.nd.load('model.params').keys())
```
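This structural matching is easy to verify: with `save_parameters`, the file keys are positional paths into the block tree rather than prefixed parameter names, so a model built under a different prefix can still load them. A minimal sketch under that assumption (the `make_model` helper is illustrative):

    import mxnet as mx
    from mxnet import gluon

    def make_model(prefix):
        net = gluon.nn.Sequential(prefix=prefix)
        with net.name_scope():
            net.add(gluon.nn.Dense(4))
        return net

    model0 = make_model('model0_')
    model0.initialize()
    model0(mx.nd.ones((1, 3)))
    model0.save_parameters('model.params')

    model1 = make_model('model1_')          # different prefix, same structure
    model1.load_parameters('model.params')  # matches by structure, not by name

    print(mx.nd.load('model.params').keys())  # e.g. ['0.weight', '0.bias']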
diff --git a/docs/tutorials/gluon/save_load_params.md b/docs/tutorials/gluon/save_load_params.md
index cd876808a86..f5f48125cc1 100644
--- a/docs/tutorials/gluon/save_load_params.md
+++ b/docs/tutorials/gluon/save_load_params.md
@@ -10,7 +10,7 @@ Parameters of any Gluon model can be saved using the `save_params` and `load_par
**2. Save/load model parameters AND architecture**
-The Model architecture of `Hybrid` models stays static and don't change during execution. Therefore both model parameters AND architecture can be saved and loaded using `export`, `load_checkpoint` and `load` methods.
+The Model architecture of `Hybrid` models stays static and don't change during execution. Therefore both model parameters AND architecture can be saved and loaded using `export`, `imports` methods.
Let's look at the above methods in more detail. Let's start by importing the
modules we'll need.
@@ -61,7 +61,7 @@ def build_lenet(net):
net.add(gluon.nn.Dense(512, activation="relu"))
# Second fully connected layer with as many neurons as the number of classes
net.add(gluon.nn.Dense(num_outputs))
-
+
return net
# Train a given model using MNIST data
@@ -240,18 +240,10 @@ One of the main reasons to serialize model architecture into a JSON file is to l
### From Python
-Serialized Hybrid networks (saved as .JSON and .params file) can be loaded and used inside Python frontend using `mx.model.load_checkpoint` and `gluon.nn.SymbolBlock`. To demonstrate that, let's load the network we serialized above.
+Serialized Hybrid networks (saved as .JSON and .params file) can be loaded and used inside Python frontend using `gluon.nn.SymbolBlock`. To demonstrate that, let's load the network we serialized above.
```python
-# Load the network architecture and parameters
-sym = mx.sym.load('lenet-symbol.json')
-# Create a Gluon Block using the loaded network architecture.
-# 'inputs' parameter specifies the name of the symbol in the computation graph
-# that should be treated as input. 'data' is the default name used for input when
-# a model architecture is saved to a file.
-deserialized_net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
-# Load the parameters
-deserialized_net.collect_params().load('lenet-0001.params', ctx=ctx)
+deserialized_net = gluon.nn.SymbolBlock.imports("lenet-symbol.json", ['data'], "lenet-0001.params")
```
`deserialized_net` now contains the network we deserialized from files. Let's
test the deserialized network to make sure it works.
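One detail worth noting when using `imports`: per the updated `export` docstring later in this diff, a lone input is named `data` while multiple inputs are named `data0`, `data1`, and so on, and the `input_names` argument must match. A hedged sketch for a hypothetical two-input model exported under the prefix 'twoinput':

    from mxnet import gluon

    net = gluon.nn.SymbolBlock.imports(
        'twoinput-symbol.json', ['data0', 'data1'], 'twoinput-0001.params')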
diff --git a/example/gluon/dcgan.py b/example/gluon/dcgan.py
index 3233f430eea..8ac9c522cf5 100644
--- a/example/gluon/dcgan.py
+++ b/example/gluon/dcgan.py
@@ -229,8 +229,8 @@ def transformer(data, label):
logging.info('time: %f' % (time.time() - tic))
if check_point:
- netG.save_params(os.path.join(outf,'generator_epoch_%d.params' %epoch))
- netD.save_params(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
+ netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' %epoch))
+ netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
-netG.save_params(os.path.join(outf, 'generator.params'))
-netD.save_params(os.path.join(outf, 'discriminator.params'))
+netG.save_parameters(os.path.join(outf, 'generator.params'))
+netD.save_parameters(os.path.join(outf, 'discriminator.params'))
diff --git a/example/gluon/embedding_learning/train.py b/example/gluon/embedding_learning/train.py
index 46f76b55614..b8a5bf2716c 100644
--- a/example/gluon/embedding_learning/train.py
+++ b/example/gluon/embedding_learning/train.py
@@ -246,7 +246,7 @@ def train(epochs, ctx):
if val_accs[0] > best_val:
best_val = val_accs[0]
logging.info('Saving %s.' % opt.save_model_prefix)
- net.save_params('%s.params' % opt.save_model_prefix)
+ net.save_parameters('%s.params' % opt.save_model_prefix)
return best_val
diff --git a/example/gluon/image_classification.py b/example/gluon/image_classification.py
index 6e2f1d6a78d..b21e943f17f 100644
--- a/example/gluon/image_classification.py
+++ b/example/gluon/image_classification.py
@@ -122,7 +122,7 @@ def get_model(model, ctx, opt):
net = models.get_model(model, **kwargs)
if opt.resume:
- net.load_params(opt.resume)
+ net.load_parameters(opt.resume)
elif not opt.use_pretrained:
if model in ['alexnet']:
net.initialize(mx.init.Normal())
@@ -176,12 +176,12 @@ def update_learning_rate(lr, trainer, epoch, ratio, steps):
def save_checkpoint(epoch, top1, best_acc):
if opt.save_frequency and (epoch + 1) % opt.save_frequency == 0:
fname = os.path.join(opt.prefix, '%s_%d_acc_%.4f.params' % (opt.model, epoch, top1))
- net.save_params(fname)
+ net.save_parameters(fname)
logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f',
epoch, fname, top1)
if top1 > best_acc[0]:
best_acc[0] = top1
fname = os.path.join(opt.prefix, '%s_best.params' % (opt.model))
- net.save_params(fname)
+ net.save_parameters(fname)
logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f',
epoch, fname, top1)
def train(opt, ctx):
@@ -267,7 +267,7 @@ def main():
optimizer = 'sgd',
optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'multi_precision': True},
initializer = mx.init.Xavier(magnitude=2))
- mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
+ mod.save_parameters('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
else:
if opt.mode == 'hybrid':
net.hybridize()
diff --git a/example/gluon/mnist.py b/example/gluon/mnist.py
index 198d7ca5ab2..6aea3abc504 100644
--- a/example/gluon/mnist.py
+++ b/example/gluon/mnist.py
@@ -117,7 +117,7 @@ def train(epochs, ctx):
name, val_acc = test(ctx)
print('[Epoch %d] Validation: %s=%f'%(epoch, name, val_acc))
- net.save_params('mnist.params')
+ net.save_parameters('mnist.params')
if __name__ == '__main__':
diff --git a/example/gluon/style_transfer/main.py b/example/gluon/style_transfer/main.py
index cab8211bc9c..dde992ae700 100644
--- a/example/gluon/style_transfer/main.py
+++ b/example/gluon/style_transfer/main.py
@@ -55,7 +55,7 @@ def train(args):
style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
if args.resume is not None:
print('Resuming, initializing using weight from {}.'.format(args.resume))
- style_model.load_params(args.resume, ctx=ctx)
+ style_model.load_parameters(args.resume, ctx=ctx)
print('style_model:',style_model)
# optimizer and loss
trainer = gluon.Trainer(style_model.collect_params(), 'adam',
@@ -121,14 +121,14 @@ def train(args):
str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
args.content_weight) + "_" + str(args.style_weight) + ".params"
save_model_path = os.path.join(args.save_model_dir, save_model_filename)
- style_model.save_params(save_model_path)
+ style_model.save_parameters(save_model_path)
print("\nCheckpoint, trained model saved at", save_model_path)
# save model
save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
args.content_weight) + "_" + str(args.style_weight) + ".params"
save_model_path = os.path.join(args.save_model_dir, save_model_filename)
- style_model.save_params(save_model_path)
+ style_model.save_parameters(save_model_path)
print("\nDone, trained model saved at", save_model_path)
@@ -143,7 +143,7 @@ def evaluate(args):
style_image = utils.preprocess_batch(style_image)
# model
style_model = net.Net(ngf=args.ngf)
- style_model.load_params(args.model, ctx=ctx)
+ style_model.load_parameters(args.model, ctx=ctx)
# forward
style_model.set_target(style_image)
output = style_model(content_image)
diff --git a/example/gluon/super_resolution.py b/example/gluon/super_resolution.py
index 38c3bec8949..0f2f21f3c0a 100644
--- a/example/gluon/super_resolution.py
+++ b/example/gluon/super_resolution.py
@@ -168,13 +168,13 @@ def train(epoch, ctx):
print('training mse at epoch %d: %s=%f'%(i, name, acc))
test(ctx)
- net.save_params('superres.params')
+ net.save_parameters('superres.params')
def resolve(ctx):
from PIL import Image
if isinstance(ctx, list):
ctx = [ctx[0]]
- net.load_params('superres.params', ctx=ctx)
+ net.load_parameters('superres.params', ctx=ctx)
img = Image.open(opt.resolve_img).convert('YCbCr')
y, cb, cr = img.split()
data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0)
diff --git a/example/gluon/tree_lstm/main.py b/example/gluon/tree_lstm/main.py
index d2fe464638a..ad5d59f7a47 100644
--- a/example/gluon/tree_lstm/main.py
+++ b/example/gluon/tree_lstm/main.py
@@ -138,7 +138,7 @@ def test(ctx, data_iter, best, mode='validation', num_iter=-1):
if test_r >= best:
best = test_r
logging.info('New optimum found: {}. Checkpointing.'.format(best))
- net.save_params('childsum_tree_lstm_{}.params'.format(num_iter))
+ net.save_parameters('childsum_tree_lstm_{}.params'.format(num_iter))
test(ctx, test_iter, -1, 'test')
return best
diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model/train.py
index 9e152636bb0..7f0a916b79b 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model/train.py
@@ -185,7 +185,7 @@ def train():
if val_L < best_val:
best_val = val_L
test_L = eval(test_data)
- model.save_params(args.save)
+ model.save_parameters(args.save)
print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
else:
args.lr = args.lr*0.25
@@ -193,6 +193,6 @@ def train():
if __name__ == '__main__':
train()
- model.load_params(args.save, context)
+ model.load_parameters(args.save, context)
test_L = eval(test_data)
print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index f107da3c8da..77c6f88f111 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -16,7 +16,7 @@
# under the License.
# coding: utf-8
-# pylint: disable= arguments-differ
+# pylint: disable= arguments-differ, too-many-lines
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
@@ -307,7 +307,7 @@ def _collect_params_with_prefix(self, prefix=''):
ret.update(child._collect_params_with_prefix(prefix + name))
return ret
- def save_params(self, filename):
+ def save_parameters(self, filename):
"""Save parameters to file.
filename : str
@@ -317,8 +317,23 @@ def save_params(self, filename):
arg_dict = {key : val._reduce() for key, val in params.items()}
ndarray.save(filename, arg_dict)
- def load_params(self, filename, ctx=None, allow_missing=False,
- ignore_extra=False):
+ def save_params(self, filename):
+ """[Deprecated] Please use save_parameters.
+
+ Save parameters to file.
+
+ filename : str
+ Path to file.
+ """
+ warnings.warn("save_params is deprecated. Please use save_parameters.")
+ try:
+ self.collect_params().save(filename, strip_prefix=self.prefix)
+ except ValueError as e:
+ raise ValueError('%s\nsave_params is deprecated. Using ' \
+ 'save_parameters may resolve this error.'%e.message)
+
+ def load_parameters(self, filename, ctx=None, allow_missing=False,
+ ignore_extra=False):
"""Load parameters from file.
filename : str
@@ -357,6 +372,25 @@ def load_params(self, filename, ctx=None, allow_missing=False,
name, filename,
_brief_print_list(self._params.keys())))
params[name]._load_init(loaded[name], ctx)
+ def load_params(self, filename, ctx=None, allow_missing=False,
+ ignore_extra=False):
+ """[Deprecated] Please use load_parameters.
+
+ Load parameters from file.
+
+ filename : str
+ Path to parameter file.
+ ctx : Context or list of Context, default cpu()
+ Context(s) initialize loaded parameters on.
+ allow_missing : bool, default False
+ Whether to silently skip loading parameters not represents in the file.
+ ignore_extra : bool, default False
+ Whether to silently ignore parameters from the file that are not
+ present in this Block.
+ """
+ warnings.warn("load_params is deprecated. Please use load_parameters.")
+ self.load_parameters(filename, ctx, allow_missing, ignore_extra)
+
def register_child(self, block, name=None):
"""Registers block as a child of self. :py:class:`Block` s assigned to
self as
attributes will be registered automatically."""
@@ -770,8 +804,8 @@ def infer_type(self, *args):
self._infer_attrs('infer_type', 'dtype', *args)
def export(self, path, epoch=0):
- """Export HybridBlock to json format that can be loaded by
`mxnet.mod.Module`
- or the C++ interface.
+ """Export HybridBlock to json format that can be loaded by
+ `SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface.
.. note:: When there are only one input, it will have name `data`. When there
Are more than one inputs, they will be named as `data0`, `data1`, etc.
@@ -885,6 +919,50 @@ class SymbolBlock(HybridBlock):
>>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
>>> print(feat_model(x))
"""
+ @staticmethod
+ def imports(symbol_file, input_names, param_file=None, ctx=None):
+ """Import model previously saved by `HybridBlock.export` or
+ `Module.save_checkpoint` as a SymbolBlock for use in Gluon.
+
+ Parameters
+ ----------
+ symbol_file : str
+ Path to symbol file.
+ input_names : list of str
+ List of input variable names
+ param_file : str, optional
+ Path to parameter file.
+ ctx : Context, default None
+ The context to initialize SymbolBlock on.
+
+ Returns
+ -------
+ SymbolBlock
+ SymbolBlock loaded from symbol and parameter files.
+
+ Examples
+ --------
+ >>> net1 = gluon.model_zoo.vision.resnet18_v1(
+ ... prefix='resnet', pretrained=True)
+ >>> net1.hybridize()
+ >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
+ >>> out1 = net1(x)
+ >>> net1.export('net1', epoch=1)
+ >>>
+ >>> net2 = gluon.SymbolBlock.imports(
+ ... 'net1-symbol.json', ['data'], 'net1-0001.params')
+ >>> out2 = net2(x)
+ """
+ sym = symbol.load(symbol_file)
+ if isinstance(input_names, str):
+ input_names = [input_names]
+ inputs = [symbol.var(i) for i in input_names]
+ ret = SymbolBlock(sym, inputs)
+ if param_file is not None:
+ ret.collect_params().load(param_file, ctx=ctx)
+ return ret
+
+
def __init__(self, outputs, inputs, params=None):
super(SymbolBlock, self).__init__(prefix=None, params=None)
self._prefix = ''
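Note that the deprecated wrappers above keep the old on-disk behavior (`save_params` still writes prefixed names via `collect_params().save`, which the new `test_legacy_save_params` test exercises) and merely warn. A minimal sketch of observing that warning with the standard library (the file name is illustrative):

    import warnings
    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(2)
    net.initialize()
    net(mx.nd.ones((1, 4)))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        net.save_params('legacy.params')  # deprecated; warns, writes legacy format
    assert any('save_params is deprecated' in str(w.message) for w in caught)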
diff --git a/python/mxnet/gluon/model_zoo/vision/alexnet.py b/python/mxnet/gluon/model_zoo/vision/alexnet.py
index 55499470460..fdb006258c2 100644
--- a/python/mxnet/gluon/model_zoo/vision/alexnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/alexnet.py
@@ -83,5 +83,5 @@ def alexnet(pretrained=False, ctx=cpu(),
net = AlexNet(**kwargs)
if pretrained:
from ..model_store import get_model_file
- net.load_params(get_model_file('alexnet', root=root), ctx=ctx)
+ net.load_parameters(get_model_file('alexnet', root=root), ctx=ctx)
return net
diff --git a/python/mxnet/gluon/model_zoo/vision/densenet.py b/python/mxnet/gluon/model_zoo/vision/densenet.py
index 835336739a6..b03f5ce8d52 100644
--- a/python/mxnet/gluon/model_zoo/vision/densenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/densenet.py
@@ -141,7 +141,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(),
net = DenseNet(num_init_features, growth_rate, block_config, **kwargs)
if pretrained:
from ..model_store import get_model_file
- net.load_params(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx)
+ net.load_parameters(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx)
return net
def densenet121(**kwargs):
diff --git a/python/mxnet/gluon/model_zoo/vision/inception.py b/python/mxnet/gluon/model_zoo/vision/inception.py
index 6d75050b83f..7c54691f1b5 100644
--- a/python/mxnet/gluon/model_zoo/vision/inception.py
+++ b/python/mxnet/gluon/model_zoo/vision/inception.py
@@ -216,5 +216,5 @@ def inception_v3(pretrained=False, ctx=cpu(),
net = Inception3(**kwargs)
if pretrained:
from ..model_store import get_model_file
- net.load_params(get_model_file('inceptionv3', root=root), ctx=ctx)
+ net.load_parameters(get_model_file('inceptionv3', root=root), ctx=ctx)
return net
diff --git a/python/mxnet/gluon/model_zoo/vision/mobilenet.py b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
index 5b4c9a8e615..1a2c9b94619 100644
--- a/python/mxnet/gluon/model_zoo/vision/mobilenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
@@ -213,7 +213,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
version_suffix = '{0:.2f}'.format(multiplier)
if version_suffix in ('1.00', '0.50'):
version_suffix = version_suffix[:-1]
- net.load_params(
+ net.load_parameters(
get_model_file('mobilenet%s' % version_suffix, root=root), ctx=ctx)
return net
@@ -245,7 +245,7 @@ def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(),
version_suffix = '{0:.2f}'.format(multiplier)
if version_suffix in ('1.00', '0.50'):
version_suffix = version_suffix[:-1]
- net.load_params(
+ net.load_parameters(
get_model_file('mobilenetv2_%s' % version_suffix, root=root), ctx=ctx)
return net
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py
index 5ee67b510a8..da279b89583 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -386,8 +386,8 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
net = resnet_class(block_class, layers, channels, **kwargs)
if pretrained:
from ..model_store import get_model_file
- net.load_params(get_model_file('resnet%d_v%d'%(num_layers, version),
- root=root), ctx=ctx)
+ net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version),
+ root=root), ctx=ctx)
return net
def resnet18_v1(**kwargs):
diff --git a/python/mxnet/gluon/model_zoo/vision/squeezenet.py b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
index 09f62a52074..aaff4c36dfa 100644
--- a/python/mxnet/gluon/model_zoo/vision/squeezenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
@@ -132,7 +132,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(),
net = SqueezeNet(version, **kwargs)
if pretrained:
from ..model_store import get_model_file
- net.load_params(get_model_file('squeezenet%s'%version, root=root), ctx=ctx)
+ net.load_parameters(get_model_file('squeezenet%s'%version, root=root), ctx=ctx)
return net
def squeezenet1_0(**kwargs):
diff --git a/python/mxnet/gluon/model_zoo/vision/vgg.py b/python/mxnet/gluon/model_zoo/vision/vgg.py
index dbae5385898..a3b1685b413 100644
--- a/python/mxnet/gluon/model_zoo/vision/vgg.py
+++ b/python/mxnet/gluon/model_zoo/vision/vgg.py
@@ -114,8 +114,8 @@ def get_vgg(num_layers, pretrained=False, ctx=cpu(),
if pretrained:
from ..model_store import get_model_file
batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
- net.load_params(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
- root=root), ctx=ctx)
+ net.load_parameters(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
+ root=root), ctx=ctx)
return net
def vgg11(**kwargs):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index ced3063448b..71e65ec6810 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -202,20 +202,20 @@ def forward(self, x):
net1.collect_params().initialize()
net2(mx.nd.zeros((3, 5)))
- net1.save_params('net1.params')
+ net1.save_parameters('net1.params')
net3 = Net(prefix='net3_')
- net3.load_params('net1.params', mx.cpu())
+ net3.load_parameters('net1.params', mx.cpu())
net4 = Net(prefix='net4_')
net5 = Net(prefix='net5_', in_units=5, params=net4.collect_params())
net4.collect_params().initialize()
net5(mx.nd.zeros((3, 5)))
- net4.save_params('net4.params')
+ net4.save_parameters('net4.params')
net6 = Net(prefix='net6_')
- net6.load_params('net4.params', mx.cpu())
+ net6.load_parameters('net4.params', mx.cpu())
@with_seed()
@@ -776,7 +776,7 @@ def test_export():
model = gluon.model_zoo.vision.resnet18_v1(
prefix='resnet', ctx=ctx, pretrained=True)
model.hybridize()
- data = mx.nd.random.normal(shape=(1, 3, 224, 224))
+ data = mx.nd.random.normal(shape=(1, 3, 32, 32))
out = model(data)
model.export('gluon')
@@ -794,6 +794,22 @@ def test_export():
assert_almost_equal(out.asnumpy(), out2.asnumpy())
+@with_seed()
+def test_import():
+ ctx = mx.context.current_context()
+ net1 = gluon.model_zoo.vision.resnet18_v1(
+ prefix='resnet', ctx=ctx, pretrained=True)
+ net1.hybridize()
+ data = mx.nd.random.normal(shape=(1, 3, 32, 32))
+ out1 = net1(data)
+
+ net1.export('net1', epoch=1)
+
+ net2 = gluon.SymbolBlock.imports(
+ 'net1-symbol.json', ['data'], 'net1-0001.params', ctx)
+ out2 = net2(data)
+
+ assert_almost_equal(out1.asnumpy(), out2.asnumpy())
@with_seed()
def test_hybrid_stale_cache():
@@ -910,7 +926,7 @@ def test_fill_shape_load():
net1.hybridize()
net1.initialize(ctx=ctx)
net1(mx.nd.ones((2,3,5,7), ctx))
- net1.save_params('net_fill.params')
+ net1.save_parameters('net_fill.params')
net2 = nn.HybridSequential()
with net2.name_scope():
@@ -919,7 +935,7 @@ def test_fill_shape_load():
nn.Dense(10))
net2.hybridize()
net2.initialize()
- net2.load_params('net_fill.params', ctx)
+ net2.load_parameters('net_fill.params', ctx)
assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
@@ -1065,12 +1081,12 @@ def test_req():
@with_seed()
def test_save_load():
net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
- net.save_params('test_save_load.params')
+ net.save_parameters('test_save_load.params')
net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
net.output = mx.gluon.nn.Dense(1000)
- net.load_params('test_save_load.params')
+ net.load_parameters('test_save_load.params')
@with_seed()
def test_symbol_block_save_load():
@@ -1095,10 +1111,10 @@ def hybrid_forward(self, F, x):
net1.initialize(mx.init.Normal())
net1.hybridize()
net1(mx.nd.random.normal(shape=(1, 3, 32, 32)))
- net1.save_params('./test_symbol_block_save_load.params')
+ net1.save_parameters('./test_symbol_block_save_load.params')
net2 = Net()
- net2.load_params('./test_symbol_block_save_load.params', ctx=mx.cpu())
+ net2.load_parameters('./test_symbol_block_save_load.params', ctx=mx.cpu())
@with_seed()
@@ -1252,6 +1268,22 @@ def test_summary():
assert_raises(AssertionError, net.summary, mx.nd.ones((32, 3, 224, 224)))
+@with_seed()
+def test_legacy_save_params():
+ net = gluon.nn.HybridSequential(prefix='')
+ with net.name_scope():
+ net.add(gluon.nn.Conv2D(10, (3, 3)))
+ net.add(gluon.nn.Dense(50))
+ net.initialize()
+ net(mx.nd.ones((1,1,50,50)))
+ a = net(mx.sym.var('data'))
+ a.save('test.json')
+ net.save_params('test.params')
+ model = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(open('test.json', 'r').read()),
+ inputs=mx.sym.var('data'))
+ model.load_params('test.params', ctx=mx.cpu())
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services