You are viewing a plain-text version of this content. The canonical link to it (originally a hyperlink) is not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/26 12:04:52 UTC

[incubator-mxnet] branch v1.x updated: add logic for no batch size while getting data arrays from executors (#17772) (#18075)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 63e2b19  add logic for no batch size while getting data arrays from executors (#17772) (#18075)
63e2b19 is described below

commit 63e2b19ea79f5840a9169441a98e64e0eb007956
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Sun Apr 26 05:03:24 2020 -0700

    add logic for no batch size while getting data arrays from executors (#17772) (#18075)
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-94-123.ec2.internal>
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-94-123.ec2.internal>
---
 python/mxnet/module/executor_group.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index d47665d..f2cb62f 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -308,8 +308,16 @@ class DataParallelExecutorGroup(object):
     def _collect_arrays(self):
         """Collect internal arrays from executors."""
         # convenient data structures
-        self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
-                            for name, _ in self.data_shapes]
+
+        # check if self.slices is populated, if not then that means that there is no batch size
+        if self.slices:
+            # based on batch size, slice up data for the given contexts (self.execs)
+            self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
+                                for name, _ in self.data_shapes]
+        else:
+            # just use the context index as index into the data
+            self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in enumerate(self.execs)]
+                                for name, _ in self.data_shapes]
 
         self.state_arrays = [[e.arg_dict[name] for e in self.execs]
                              for name in self.state_names]