piiswrong closed pull request #9151: removes python path insert of tests folder for examples
URL: https://github.com/apache/incubator-mxnet/pull/9151
This is a PR merged from a forked repository. Because GitHub hides the
original diff when a pull request from a fork is merged, the diff is
reproduced below for the sake of provenance:
diff --git a/example/adversary/adversary_generation.ipynb
b/example/adversary/adversary_generation.ipynb
index 8adae3d3b7..b8804bd813 100644
--- a/example/adversary/adversary_generation.ipynb
+++ b/example/adversary/adversary_generation.ipynb
@@ -28,10 +28,7 @@
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm as cm\n",
"\n",
- "import os\n",
- "import sys\n",
- "sys.path.append(os.path.join(os.getcwd(),
\"../../tests/python/common\"))\n",
- "from get_data import MNISTIterator"
+ "from mxnet.test_utils import get_mnist_iterator"
]
},
{
@@ -53,7 +50,7 @@
"source": [
"dev = mx.cpu()\n",
"batch_size = 100\n",
- "train_iter, val_iter = mnist_iterator(batch_size=batch_size, input_shape
= (1,28,28))"
+ "train_iter, val_iter = get_mnist_iterator(batch_size=batch_size,
input_shape = (1,28,28))"
]
},
{
diff --git a/example/caffe/data.py b/example/caffe/data.py
index fac8e11989..15276c4236 100644
--- a/example/caffe/data.py
+++ b/example/caffe/data.py
@@ -15,19 +15,14 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-import os
-# code to automatically download dataset
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
-import get_data
import mxnet as mx
+from mxnet.test_utils import get_mnist_ubyte
def get_iterator(data_shape, use_caffe_data):
def get_iterator_impl_mnist(args, kv):
"""return train and val iterators for mnist"""
# download data
- get_data.GetMNIST_ubyte()
+ get_mnist_ubyte()
flat = False if len(data_shape) != 1 else True
train = mx.io.MNISTIter(
diff --git a/example/gluon/data.py b/example/gluon/data.py
index 67519e6a20..dc8f12e81f 100644
--- a/example/gluon/data.py
+++ b/example/gluon/data.py
@@ -19,39 +19,11 @@
""" data iterator for mnist """
import os
import random
-import sys
-# code to automatically download dataset
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
-import get_data
import mxnet as mx
+from mxnet.test_utils import get_cifar10
-def mnist_iterator(batch_size, input_shape):
- """return train and val iterators for mnist"""
- # download data
- get_data.GetMNIST_ubyte()
- flat = False if len(input_shape) == 3 else True
-
- train_dataiter = mx.io.MNISTIter(
- image="data/train-images-idx3-ubyte",
- label="data/train-labels-idx1-ubyte",
- input_shape=input_shape,
- batch_size=batch_size,
- shuffle=True,
- flat=flat)
-
- val_dataiter = mx.io.MNISTIter(
- image="data/t10k-images-idx3-ubyte",
- label="data/t10k-labels-idx1-ubyte",
- input_shape=input_shape,
- batch_size=batch_size,
- flat=flat)
-
- return (train_dataiter, val_dataiter)
-
-
-def cifar10_iterator(batch_size, data_shape, resize=-1):
- get_data.GetCifar10()
+def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1,
part_index=0):
+ get_cifar10()
train = mx.io.ImageRecordIter(
path_imgrec = "data/cifar/train.rec",
@@ -60,7 +32,9 @@ def cifar10_iterator(batch_size, data_shape, resize=-1):
data_shape = data_shape,
batch_size = batch_size,
rand_crop = True,
- rand_mirror = True)
+ rand_mirror = True,
+ num_parts=num_parts,
+ part_index=part_index)
val = mx.io.ImageRecordIter(
path_imgrec = "data/cifar/test.rec",
@@ -69,11 +43,14 @@ def cifar10_iterator(batch_size, data_shape, resize=-1):
rand_crop = False,
rand_mirror = False,
data_shape = data_shape,
- batch_size = batch_size)
+ batch_size = batch_size,
+ num_parts=num_parts,
+ part_index=part_index)
return train, val
-def imagenet_iterator(train_data, val_data, batch_size, data_shape, resize=-1):
+
+def get_imagenet_iterator(train_data, val_data, batch_size, data_shape,
resize=-1, num_parts=1, part_index=0):
train = mx.io.ImageRecordIter(
path_imgrec = train_data,
data_shape = data_shape,
@@ -96,7 +73,9 @@ def imagenet_iterator(train_data, val_data, batch_size,
data_shape, resize=-1):
max_random_shear_ratio = 0.1,
max_random_aspect_ratio = 0.25,
fill_value = 127,
- min_random_scale = 0.533)
+ min_random_scale = 0.533,
+ num_parts = num_parts,
+ part_index = part_index)
val = mx.io.ImageRecordIter(
path_imgrec = val_data,
@@ -109,7 +88,9 @@ def imagenet_iterator(train_data, val_data, batch_size,
data_shape, resize=-1):
std_b = 57.375,
preprocess_threads = 32,
batch_size = batch_size,
- resize = resize)
+ resize = resize,
+ num_parts = num_parts,
+ part_index = part_index)
return train, val
diff --git a/example/gluon/image_classification.py
b/example/gluon/image_classification.py
index a2fb757683..529b977a79 100644
--- a/example/gluon/image_classification.py
+++ b/example/gluon/image_classification.py
@@ -26,12 +26,13 @@
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import vision as models
from mxnet import autograd as ag
+from mxnet.test_utils import get_mnist_iterator
from data import *
# CLI
parser = argparse.ArgumentParser(description='Train a model for image
classification.')
-parser.add_argument('--dataset', type=str, default='mnist',
+parser.add_argument('--dataset', type=str, default='cifar10',
help='dataset to use. options are mnist, cifar10, and
dummy.')
parser.add_argument('--train-data', type=str, default='',
help='training record file to use, required for imagenet.')
@@ -92,25 +93,31 @@
net = models.get_model(opt.model, **kwargs)
-# get dataset iterators
-if dataset == 'mnist':
- train_data, val_data = mnist_iterator(batch_size, (1, 32, 32))
-elif dataset == 'cifar10':
- train_data, val_data = cifar10_iterator(batch_size, (3, 32, 32))
-elif dataset == 'imagenet':
- if model_name == 'inceptionv3':
- train_data, val_data = imagenet_iterator(opt.train_data, opt.val_data,
- batch_size, (3, 299, 299))
- else:
- train_data, val_data = imagenet_iterator(opt.train_data, opt.val_data,
- batch_size, (3, 224, 224))
-elif dataset == 'dummy':
- if model_name == 'inceptionv3':
- train_data, val_data = dummy_iterator(batch_size, (3, 299, 299))
- else:
- train_data, val_data = dummy_iterator(batch_size, (3, 224, 224))
-
-def test(ctx):
+def get_data_iters(dataset, batch_size, num_workers=1, rank=0):
+ # get dataset iterators
+ if dataset == 'mnist':
+ train_data, val_data = get_mnist_iterator(batch_size, (1, 28, 28),
+ num_parts=num_workers,
part_index=rank)
+ elif dataset == 'cifar10':
+ train_data, val_data = get_cifar10_iterator(batch_size, (3, 32, 32),
+ num_parts=num_workers,
part_index=rank)
+ elif dataset == 'imagenet':
+ if model_name == 'inceptionv3':
+ train_data, val_data = get_imagenet_iterator(opt.train_data,
opt.val_data,
+ batch_size, (3, 299,
299),
+
num_parts=num_workers, part_index=rank)
+ else:
+ train_data, val_data = get_imagenet_iterator(opt.train_data,
opt.val_data,
+ batch_size, (3, 224,
224),
+
num_parts=num_workers, part_index=rank)
+ elif dataset == 'dummy':
+ if model_name == 'inceptionv3':
+ train_data, val_data = dummy_iterator(batch_size, (3, 299, 299))
+ else:
+ train_data, val_data = dummy_iterator(batch_size, (3, 224, 224))
+ return train_data, val_data
+
+def test(ctx, val_data):
metric = mx.metric.Accuracy()
val_data.reset()
for batch in val_data:
@@ -127,9 +134,11 @@ def train(epochs, ctx):
if isinstance(ctx, mx.Context):
ctx = [ctx]
net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
+ kv = mx.kv.create(opt.kvstore)
+ train_data, val_data = get_data_iters(dataset, batch_size, kv.num_workers,
kv.rank)
trainer = gluon.Trainer(net.collect_params(), 'sgd',
{'learning_rate': opt.lr, 'wd': opt.wd,
'momentum': opt.momentum},
- kvstore = opt.kvstore)
+ kvstore = kv)
metric = mx.metric.Accuracy()
loss = gluon.loss.SoftmaxCrossEntropyLoss()
@@ -164,7 +173,7 @@ def train(epochs, ctx):
name, acc = metric.get()
logging.info('[Epoch %d] training: %s=%f'%(epoch, name, acc))
logging.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
- name, val_acc = test(ctx)
+ name, val_acc = test(ctx, val_data)
logging.info('[Epoch %d] validation: %s=%f'%(epoch, name, val_acc))
net.save_params('image-classifier-%s-%d.params'%(opt.model, epochs))
@@ -175,10 +184,12 @@ def main():
out = net(data)
softmax = mx.sym.SoftmaxOutput(out, name='softmax')
mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in
range(num_gpus)] if num_gpus > 0 else [mx.cpu()])
+ kv = mx.kv.create(opt.kvstore)
+ train_data, val_data = get_data_iters(dataset, batch_size,
kv.num_workers, kv.rank)
mod.fit(train_data,
eval_data = val_data,
num_epoch=opt.epochs,
- kvstore=opt.kvstore,
+ kvstore=kv,
batch_end_callback = mx.callback.Speedometer(batch_size,
max(1, opt.log_interval)),
epoch_end_callback =
mx.callback.do_checkpoint('image-classifier-%s'% opt.model),
optimizer = 'sgd',
diff --git a/example/multi-task/example_multi_task.py
b/example/multi-task/example_multi_task.py
index 9ea9ad0173..9e898494a1 100644
--- a/example/multi-task/example_multi_task.py
+++ b/example/multi-task/example_multi_task.py
@@ -16,12 +16,8 @@
# under the License.
# pylint: skip-file
-import sys
-import os
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
-from get_data import MNISTIterator
import mxnet as mx
+from mxnet.test_utils import get_mnist_iterator
import numpy as np
import logging
import time
@@ -142,7 +138,7 @@ def get_name_value(self):
lr = 0.01
network = build_network()
-train, val = MNISTIterator(batch_size=batch_size, input_shape = (784,))
+train, val = get_mnist_iterator(batch_size=batch_size, input_shape = (784,))
train = Multi_mnist_iterator(train)
val = Multi_mnist_iterator(val)
diff --git a/example/numpy-ops/custom_softmax.py
b/example/numpy-ops/custom_softmax.py
index a2ec5d54b7..ab94401185 100644
--- a/example/numpy-ops/custom_softmax.py
+++ b/example/numpy-ops/custom_softmax.py
@@ -16,12 +16,8 @@
# under the License.
# pylint: skip-file
-import sys
-import os
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
-from get_data import MNISTIterator
import mxnet as mx
+from mxnet.test_utils import get_mnist_iterator
import numpy as np
import logging
@@ -75,7 +71,7 @@ def create_operator(self, ctx, shapes, dtypes):
# data
-train, val = MNISTIterator(batch_size=100, input_shape = (784,))
+train, val = get_mnist_iterator(batch_size=100, input_shape = (784,))
# train
diff --git a/example/numpy-ops/ndarray_softmax.py
b/example/numpy-ops/ndarray_softmax.py
index 4ced2c5cd8..58eab3d538 100644
--- a/example/numpy-ops/ndarray_softmax.py
+++ b/example/numpy-ops/ndarray_softmax.py
@@ -16,16 +16,11 @@
# under the License.
# pylint: skip-file
-import os
-import sys
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
-from get_data import MNISTIterator
import mxnet as mx
+from mxnet.test_utils import get_mnist_iterator
import numpy as np
import logging
-
class NDArraySoftmax(mx.operator.NDArrayOp):
def __init__(self):
super(NDArraySoftmax, self).__init__(False)
@@ -97,7 +92,7 @@ def backward(self, out_grad, in_data, out_data, in_grad):
# data
-train, val = MNISTIterator(batch_size=100, input_shape = (784,))
+train, val = get_mnist_iterator(batch_size=100, input_shape = (784,))
# train
diff --git a/example/numpy-ops/numpy_softmax.py
b/example/numpy-ops/numpy_softmax.py
index c10dfe3779..88d2473492 100644
--- a/example/numpy-ops/numpy_softmax.py
+++ b/example/numpy-ops/numpy_softmax.py
@@ -16,12 +16,8 @@
# under the License.
# pylint: skip-file
-import sys
-import os
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
-from get_data import MNISTIterator
import mxnet as mx
+from mxnet.test_utils import get_mnist_iterator
import numpy as np
import logging
@@ -70,7 +66,7 @@ def backward(self, out_grad, in_data, out_data, in_grad):
# data
-train, val = MNISTIterator(batch_size=100, input_shape = (784,))
+train, val = get_mnist_iterator(batch_size=100, input_shape = (784,))
# train
diff --git a/example/python-howto/monitor_weights.py
b/example/python-howto/monitor_weights.py
index ab77b4908b..929b0e7bf7 100644
--- a/example/python-howto/monitor_weights.py
+++ b/example/python-howto/monitor_weights.py
@@ -16,12 +16,8 @@
# under the License.
# pylint: skip-file
-import sys
-import os
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
-from get_data import MNISTIterator
import mxnet as mx
+from mxnet.test_utils import get_mnist_iterator
import numpy as np
import logging
@@ -35,7 +31,7 @@
mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
# data
-train, val = MNISTIterator(batch_size=100, input_shape = (784,))
+train, val = get_mnist_iterator(batch_size=100, input_shape = (784,))
# monitor
def norm_stat(d):
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 53814b766f..1b41011fd9 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1441,6 +1441,74 @@ def read_data(label_url, image_url):
return {'train_data':train_img, 'train_label':train_lbl,
'test_data':test_img, 'test_label':test_lbl}
+def get_mnist_pkl():
+ """Downloads MNIST dataset as a pkl.gz into a directory in the current
directory
+ with the name `data`
+ """
+ if not os.path.isdir("data"):
+ os.makedirs('data')
+ if not os.path.exists('data/mnist.pkl.gz'):
+ download('http://deeplearning.net/data/mnist/mnist.pkl.gz',
+ dirname='data')
+
+def get_mnist_ubyte():
+ """Downloads ubyte version of the MNIST dataset into a directory in the
current directory
+ with the name `data` and extracts all files in the zip archive to this
directory.
+ """
+ if not os.path.isdir("data"):
+ os.makedirs('data')
+ if (not os.path.exists('data/train-images-idx3-ubyte')) or \
+ (not os.path.exists('data/train-labels-idx1-ubyte')) or \
+ (not os.path.exists('data/t10k-images-idx3-ubyte')) or \
+ (not os.path.exists('data/t10k-labels-idx1-ubyte')):
+ zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
+ dirname='data')
+ with zipfile.ZipFile(zip_file_path) as zf:
+ zf.extractall('data')
+
+def get_cifar10():
+ """Downloads CIFAR10 dataset into a directory in the current directory
with the name `data`,
+ and then extracts all files into the directory `data/cifar`.
+ """
+ if not os.path.isdir("data"):
+ os.makedirs('data')
+ if (not os.path.exists('data/cifar/train.rec')) or \
+ (not os.path.exists('data/cifar/test.rec')) or \
+ (not os.path.exists('data/cifar/train.lst')) or \
+ (not os.path.exists('data/cifar/test.lst')):
+ zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip',
+ dirname='data')
+ with zipfile.ZipFile(zip_file_path) as zf:
+ zf.extractall('data')
+
+def get_mnist_iterator(batch_size, input_shape, num_parts=1, part_index=0):
+ """Returns training and validation iterators for MNIST dataset
+ """
+
+ get_mnist_ubyte()
+ flat = False if len(input_shape) == 3 else True
+
+ train_dataiter = mx.io.MNISTIter(
+ image="data/train-images-idx3-ubyte",
+ label="data/train-labels-idx1-ubyte",
+ input_shape=input_shape,
+ batch_size=batch_size,
+ shuffle=True,
+ flat=flat,
+ num_parts=num_parts,
+ part_index=part_index)
+
+ val_dataiter = mx.io.MNISTIter(
+ image="data/t10k-images-idx3-ubyte",
+ label="data/t10k-labels-idx1-ubyte",
+ input_shape=input_shape,
+ batch_size=batch_size,
+ flat=flat,
+ num_parts=num_parts,
+ part_index=part_index)
+
+ return (train_dataiter, val_dataiter)
+
def get_zip_data(data_dir, url, data_origin_name):
"""Download and extract zip data.
diff --git a/tests/python/common/get_data.py b/tests/python/common/get_data.py
deleted file mode 100644
index 5802a06919..0000000000
--- a/tests/python/common/get_data.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: skip-file
-import os, gzip
-import pickle as pickle
-import sys
-from mxnet.test_utils import download
-import zipfile
-import mxnet as mx
-
-# download mnist.pkl.gz
-def GetMNIST_pkl():
- if not os.path.isdir("data"):
- os.makedirs('data')
- if not os.path.exists('data/mnist.pkl.gz'):
- download('http://deeplearning.net/data/mnist/mnist.pkl.gz',
- dirname='data')
-
-# download ubyte version of mnist and untar
-def GetMNIST_ubyte():
- if not os.path.isdir("data"):
- os.makedirs('data')
- if (not os.path.exists('data/train-images-idx3-ubyte')) or \
- (not os.path.exists('data/train-labels-idx1-ubyte')) or \
- (not os.path.exists('data/t10k-images-idx3-ubyte')) or \
- (not os.path.exists('data/t10k-labels-idx1-ubyte')):
- zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
- dirname='data')
- with zipfile.ZipFile(zip_file_path) as zf:
- zf.extractall('data')
-
-# download cifar
-def GetCifar10():
- if not os.path.isdir("data"):
- os.makedirs('data')
- if (not os.path.exists('data/cifar/train.rec')) or \
- (not os.path.exists('data/cifar/test.rec')) or \
- (not os.path.exists('data/cifar/train.lst')) or \
- (not os.path.exists('data/cifar/test.lst')):
- zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip',
- dirname='data')
- with zipfile.ZipFile(zip_file_path) as zf:
- zf.extractall('data')
-
-def MNISTIterator(batch_size, input_shape):
- """return train and val iterators for mnist"""
- # download data
- GetMNIST_ubyte()
- flat = False if len(input_shape) == 3 else True
-
- train_dataiter = mx.io.MNISTIter(
- image="data/train-images-idx3-ubyte",
- label="data/train-labels-idx1-ubyte",
- input_shape=input_shape,
- batch_size=batch_size,
- shuffle=True,
- flat=flat)
-
- val_dataiter = mx.io.MNISTIter(
- image="data/t10k-images-idx3-ubyte",
- label="data/t10k-labels-idx1-ubyte",
- input_shape=input_shape,
- batch_size=batch_size,
- flat=flat)
-
- return (train_dataiter, val_dataiter)
diff --git a/tests/python/train/test_autograd.py
b/tests/python/train/test_autograd.py
index c9921ecf4f..712672cd0a 100644
--- a/tests/python/train/test_autograd.py
+++ b/tests/python/train/test_autograd.py
@@ -21,9 +21,9 @@
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
+from mxnet.test_utils import get_mnist_ubyte
import numpy as np
import logging
-from common import get_data
from mxnet import autograd
logging.basicConfig(level=logging.DEBUG)
@@ -36,7 +36,7 @@ def get_net():
net.add(nn.Dense(10, prefix='fc3_'))
return net
-get_data.GetMNIST_ubyte()
+get_mnist_ubyte()
batch_size = 100
train_data = mx.io.MNISTIter(
diff --git a/tests/python/train/test_conv.py b/tests/python/train/test_conv.py
index 46e06848f8..adceb3ebc1 100644
--- a/tests/python/train/test_conv.py
+++ b/tests/python/train/test_conv.py
@@ -19,10 +19,10 @@
import sys
sys.path.insert(0, '../../python')
import mxnet as mx
+from mxnet.test_utils import get_mnist_ubyte
import numpy as np
import os, pickle, gzip, argparse
import logging
-from common import get_data
def get_model(use_gpu):
# symbol net
@@ -52,7 +52,7 @@ def get_model(use_gpu):
def get_iters():
# check data
- get_data.GetMNIST_ubyte()
+ get_mnist_ubyte()
batch_size = 100
train_dataiter = mx.io.MNISTIter(
diff --git a/tests/python/train/test_dtype.py b/tests/python/train/test_dtype.py
index 52f04bf9a1..2e3ff06d2e 100644
--- a/tests/python/train/test_dtype.py
+++ b/tests/python/train/test_dtype.py
@@ -22,7 +22,7 @@
import numpy as np
import os, pickle, gzip
import logging
-from common import get_data
+from mxnet.test_utils import get_cifar10
batch_size = 128
@@ -39,7 +39,7 @@ def get_net():
return softmax
# check data
-get_data.GetCifar10()
+get_cifar10()
def get_iterator_uint8(kv):
data_shape = (3, 28, 28)
diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py
index a0a45b41e1..1b8e06f530 100644
--- a/tests/python/train/test_mlp.py
+++ b/tests/python/train/test_mlp.py
@@ -21,7 +21,7 @@
import os, sys
import pickle as pickle
import logging
-from common import get_data
+from mxnet.test_utils import get_mnist_ubyte
# symbol net
batch_size = 100
@@ -41,7 +41,7 @@ def accuracy(label, pred):
prefix = './mlp'
#check data
-get_data.GetMNIST_ubyte()
+get_mnist_ubyte()
train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index 12ed60d2bc..782534bf80 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -21,8 +21,6 @@
sys.path.insert(0, os.path.join(curr_path, '../../../python'))
import models
-import get_data
-
def assertRaises(expected_exception, func, *args, **kwargs):
try:
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index fa314e0f8b..03d829efd7 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -27,13 +27,12 @@
except ImportError:
h5py = None
import sys
-from common import get_data, assertRaises
+from common import assertRaises
import unittest
-
def test_MNISTIter():
# prepare data
- get_data.GetMNIST_ubyte()
+ get_mnist_ubyte()
batch_size = 100
train_dataiter = mx.io.MNISTIter(
@@ -61,7 +60,7 @@ def test_MNISTIter():
assert(sum(label_0 - label_1) == 0)
def test_Cifar10Rec():
- get_data.GetCifar10()
+ get_cifar10()
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the URL above
to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services