This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new daabe5c  split_and_load can now handle num_ctx > num_data. Issue #13909 (#14607)
daabe5c is described below

commit daabe5c984e8e69fe0c799bfd04bf1dc548d7f86
Author: Young Seok Tony Kim <[email protected]>
AuthorDate: Mon Apr 8 00:23:04 2019 -0700

    split_and_load can now handle num_ctx > num_data. Issue #13909 (#14607)
---
 python/mxnet/gluon/utils.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 55edd95..b00cc04 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -63,10 +63,6 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
         Return value is a list even if `num_slice` is 1.
     """
     size = data.shape[batch_axis]
-    if size < num_slice:
-        raise ValueError(
-            "Too many slices for data with shape %s. Arguments are " \
-            "num_slice=%d and batch_axis=%d."%(str(data.shape), num_slice, batch_axis))
     if even_split and size % num_slice != 0:
         raise ValueError(
             "data with shape %s cannot be evenly split into %d slices along axis %d. " \
@@ -75,6 +71,12 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
                 str(data.shape), num_slice, batch_axis, num_slice))
 
     step = size // num_slice
+
+    # If size < num_slice, make fewer slices
+    if not even_split and size < num_slice:
+        step = 1
+        num_slice = size
+
     if batch_axis == 0:
         slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
                   for i in range(num_slice)]

Reply via email to