You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/26 12:04:52 UTC
[incubator-mxnet] branch v1.x updated: add logic for no batch size while getting data arrays from executors (#17772) (#18075)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 63e2b19 add logic for no batch size while getting data arrays from executors (#17772) (#18075)
63e2b19 is described below
commit 63e2b19ea79f5840a9169441a98e64e0eb007956
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Sun Apr 26 05:03:24 2020 -0700
add logic for no batch size while getting data arrays from executors (#17772) (#18075)
Co-authored-by: Ubuntu <ub...@ip-172-31-94-123.ec2.internal>
Co-authored-by: Ubuntu <ub...@ip-172-31-94-123.ec2.internal>
---
python/mxnet/module/executor_group.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index d47665d..f2cb62f 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -308,8 +308,16 @@ class DataParallelExecutorGroup(object):
def _collect_arrays(self):
"""Collect internal arrays from executors."""
# convenient data structures
- self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
- for name, _ in self.data_shapes]
+
+ # check whether self.slices is populated; if it is not, there is no batch size
+ if self.slices:
+ # based on batch size, slice up data for the given contexts (self.execs)
+ self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
+ for name, _ in self.data_shapes]
+ else:
+ # no slices available: pair executor i with the one-element slice(i, i+1) instead
+ self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in enumerate(self.execs)]
+ for name, _ in self.data_shapes]
self.state_arrays = [[e.arg_dict[name] for e in self.execs]
for name in self.state_names]