sandeep-krishnamurthy closed pull request #9877: Better even_split=False
support in gluon.split_data()
URL: https://github.com/apache/incubator-mxnet/pull/9877
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a foreign (forked) pull request
once it is merged, the diff is reproduced below for the sake of provenance:
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 88effc99b9b..8ac929dfa8a 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -56,13 +56,10 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
Returns
-------
list of NDArray
- Return value is a list even if `num_slice` is 1.
+ Return value is a list even if `num_slice` is 1. When `even_split`
+ is `False` this may be shorter than `num_slice`.
"""
size = data.shape[batch_axis]
- if size < num_slice:
- raise ValueError(
- "Too many slices for data with shape %s. Arguments are " \
- "num_slice=%d and batch_axis=%d."%(str(data.shape), num_slice,
batch_axis))
if even_split and size % num_slice != 0:
raise ValueError(
"data with shape %s cannot be evenly split into %d slices along
axis %d. " \
@@ -71,16 +68,19 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
str(data.shape), num_slice, batch_axis, num_slice))
step = size // num_slice
- if batch_axis == 0:
- slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else
data[i*step:size]
- for i in range(num_slice)]
- elif even_split:
+ rem = size % num_slice
+
+ if rem == 0:
slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
else:
- slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
- if i < num_slice - 1 else
- ndarray.slice_axis(data, batch_axis, i*step, size)
- for i in range(num_slice)]
+ # First `rem` slices will have an extra sample
+ slices = [ndarray.slice_axis(data, batch_axis, i*(step+1),
(i+1)*(step+1))
+ for i in range(rem)]
+ offset = rem*(step+1)
+ # Create the remaining slices
+ if step > 0:
+ slices += [ndarray.slice_axis(data, batch_axis, offset+i*step,
offset+(i+1)*step)
+ for i in range(num_slice-rem)]
return slices
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 50b60a2db3e..12c2e394d11 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -404,9 +404,11 @@ def test_deferred_init():
layer(x)
-def check_split_data(x, num_slice, batch_axis, **kwargs):
+def check_split_data(x, num_slice, batch_axis, slice_shapes, **kwargs):
res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs)
- assert len(res) == num_slice
+ assert len(res) == len(slice_shapes)
+ for res_slice, slice_shape in zip(res, slice_shapes):
+ assert res_slice.shape == slice_shape
mx.test_utils.assert_almost_equal(mx.nd.concat(*res,
dim=batch_axis).asnumpy(),
x.asnumpy())
@@ -415,15 +417,16 @@ def check_split_data(x, num_slice, batch_axis, **kwargs):
def test_split_data():
x = mx.nd.random.uniform(shape=(128, 33, 64))
- check_split_data(x, 8, 0)
- check_split_data(x, 3, 1)
- check_split_data(x, 4, 1, even_split=False)
- check_split_data(x, 15, 1, even_split=False)
+ check_split_data(x, 8, 0, ((16, 33, 64),)*8)
+ check_split_data(x, 3, 1, ((128, 11, 64),)*3)
+ check_split_data(x, 4, 1, ((128, 9, 64),) + ((128, 8, 64),)*3,
even_split=False)
+ check_split_data(x, 15, 1, ((128, 3, 64),)*3 + ((128, 2, 64),)*12,
even_split=False)
+ check_split_data(x, 70, 2, ((128, 33, 1),)*64, even_split=False)
try:
- check_split_data(x, 4, 1)
+ check_split_data(x, 4, 1, ((128, 9, 64),) + ((128, 8, 64),)*3)
except ValueError:
return
- assert False, "Should have failed"
+ assert False, "Should have failed because even_split=True"
@with_seed()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services