This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.7.x by this push:
new 8695537 add logic for no batch size while getting data arrays from executors (#17772) (#18122)
8695537 is described below
commit 86955370cd868b5d4f46f2f80f7632fd864773e3
Author: Manu Seth <[email protected]>
AuthorDate: Thu Apr 23 01:14:44 2020 -0700
add logic for no batch size while getting data arrays from executors (#17772) (#18122)
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
---
python/mxnet/module/executor_group.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index d47665d..f2cb62f 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -308,8 +308,16 @@ class DataParallelExecutorGroup(object):
def _collect_arrays(self):
"""Collect internal arrays from executors."""
# convenient data structures
- self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in
enumerate(self.execs)]
- for name, _ in self.data_shapes]
+
+ # check if self.slices is populated, if not then that means that there
is no batch size
+ if self.slices:
+ # based on batch size, slice up data for the given contexts
(self.execs)
+ self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e
in enumerate(self.execs)]
+ for name, _ in self.data_shapes]
+ else:
+ # just use the context index as index into the data
+ self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in
enumerate(self.execs)]
+ for name, _ in self.data_shapes]
self.state_arrays = [[e.arg_dict[name] for e in self.execs]
for name in self.state_names]