You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/12 00:23:05 UTC

[GitHub] piiswrong closed pull request #11112: Support for data iterators returning lists of batches

piiswrong closed pull request #11112: Support for data iterators returning lists of batches
URL: https://github.com/apache/incubator-mxnet/pull/11112
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it):

diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index f5427feae2f..b646ef6edaa 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -159,8 +159,13 @@ def fit(args, network, data_loader, **kwargs):
     if args.test_io:
         tic = time.time()
         for i, batch in enumerate(train):
-            for j in batch.data:
-                j.wait_to_read()
+            if isinstance(batch, list):
+                for b in batch:
+                    for j in b.data:
+                        j.wait_to_read()
+            else:
+                for j in batch.data:
+                    j.wait_to_read()
             if (i + 1) % args.disp_batches == 0:
                 logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
                              args.disp_batches * args.batch_size / (time.time() - tic))
diff --git a/python/mxnet/executor_manager.py b/python/mxnet/executor_manager.py
index 33c6c976271..825aa76e43c 100644
--- a/python/mxnet/executor_manager.py
+++ b/python/mxnet/executor_manager.py
@@ -286,10 +286,13 @@ def backward(self):
         for texec in self.train_execs:
             texec.backward()
 
-    def update_metric(self, metric, labels):
+    def update_metric(self, metric, labels, pre_sliced=False):
         """Update evaluation metric with label and current outputs."""
-        for texec, islice in zip(self.train_execs, self.slices):
-            labels_slice = [label[islice] for label in labels]
+        for current_exec, (texec, islice) in enumerate(zip(self.train_execs, self.slices)):
+            if not pre_sliced:
+                labels_slice = [label[islice] for label in labels]
+            else:
+                labels_slice = labels[current_exec]
             metric.update(labels_slice, texec.outputs)
 
 class DataParallelExecutorManager(object):
@@ -436,6 +439,6 @@ def backward(self):
         """Run backward on the current executor."""
         self.curr_execgrp.backward()
 
-    def update_metric(self, metric, labels):
+    def update_metric(self, metric, labels, pre_sliced=False):
         """Update metric with the current executor."""
-        self.curr_execgrp.update_metric(metric, labels)
+        self.curr_execgrp.update_metric(metric, labels, pre_sliced)
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index 8f5fd4ab854..4b7355ffa92 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -146,7 +146,8 @@ class BaseModule(object):
         - `get_outputs()`: get outputs of the previous forward operation.
         - `get_input_grads()`: get the gradients with respect to the inputs computed
           in the previous backward operation.
-        - `update_metric(metric, labels)`: update performance metric for the previous forward
+        - `update_metric(metric, labels, pre_sliced=False)`: update performance metric
+          for the previous forward
           computed results.
 
     - other properties (mostly for backward compatibility)
@@ -249,7 +250,10 @@ def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
                 break
             self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
             self.forward(eval_batch, is_train=False)
-            self.update_metric(eval_metric, eval_batch.label)
+            if isinstance(eval_batch, list):
+                self.update_metric(eval_metric, [eb.label for eb in eval_batch], pre_sliced=True)
+            else:
+                self.update_metric(eval_metric, eval_batch.label)
 
             if batch_end_callback is not None:
                 batch_end_params = BatchEndParam(epoch=epoch,
@@ -517,7 +521,12 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
                 except StopIteration:
                     end_of_batch = True
 
-                self.update_metric(eval_metric, data_batch.label)
+                if isinstance(data_batch, list):
+                    self.update_metric(eval_metric,
+                                       [db.label for db in data_batch],
+                                       pre_sliced=True)
+                else:
+                    self.update_metric(eval_metric, data_batch.label)
 
                 if monitor is not None:
                     monitor.toc_print()
@@ -943,7 +952,7 @@ def update(self):
         """
         raise NotImplementedError()
 
-    def update_metric(self, eval_metric, labels):
+    def update_metric(self, eval_metric, labels, pre_sliced=False):
         """Evaluates and accumulates evaluation metric on outputs of the last forward
         computation.
 
@@ -951,8 +960,10 @@ def update_metric(self, eval_metric, labels):
         ----------
         eval_metric : EvalMetric
             Evaluation metric to use.
-        labels : list of NDArray
-            Typically `data_batch.label`.
+        labels : list of NDArray if `pre_sliced` parameter is set to `False`,
+            list of lists of NDArray otherwise. Typically `data_batch.label`.
+        pre_sliced: bool
+            Whether the labels are already sliced per device (default: False).
 
         Examples
         --------
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 18cec29b409..9b568618566 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -517,7 +517,7 @@ def get_input_grads(self, merge_multi_context=True):
         assert self.binded and self.params_initialized and self.inputs_need_grad
         return self._curr_module.get_input_grads(merge_multi_context=merge_multi_context)
 
-    def update_metric(self, eval_metric, labels):
+    def update_metric(self, eval_metric, labels, pre_sliced=False):
         """Evaluates and accumulates evaluation metric on outputs of the last forward computation.
 
         Parameters
@@ -527,7 +527,7 @@ def update_metric(self, eval_metric, labels):
             Typically ``data_batch.label``.
         """
         assert self.binded and self.params_initialized
-        self._curr_module.update_metric(eval_metric, labels)
+        self._curr_module.update_metric(eval_metric, labels, pre_sliced)
 
     @property
     def symbol(self):
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 32400c11dbc..5d8e95077c4 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -64,12 +64,26 @@ def _load_general(data, targets, major_axis):
 
 def _load_data(batch, targets, major_axis):
     """Load data into sliced arrays."""
-    _load_general(batch.data, targets, major_axis)
+    if isinstance(batch, list):
+        new_batch = []
+        for i in range(len(targets)):
+            new_batch.append([b.data[i] for b in batch])
+        new_targets = [[dst for _, dst in d_target] for d_target in targets]
+        _load_general(new_batch, new_targets, major_axis)
+    else:
+        _load_general(batch.data, targets, major_axis)
 
 
 def _load_label(batch, targets, major_axis):
     """Load label into sliced arrays."""
-    _load_general(batch.label, targets, major_axis)
+    if isinstance(batch, list):
+        new_batch = []
+        for i in range(len(targets)):
+            new_batch.append([b.label[i] for b in batch])
+        new_targets = [[dst for _, dst in d_target] for d_target in targets]
+        _load_general(new_batch, new_targets, major_axis)
+    else:
+        _load_general(batch.label, targets, major_axis)
 
 
 def _merge_multi_context(outputs, major_axis):
@@ -437,8 +451,12 @@ def forward(self, data_batch, is_train=None):
         if is_train is None:
             is_train = self.for_training
 
-        if self.label_arrays is not None and data_batch.label:
-            _load_label(data_batch, self.label_arrays, self.label_layouts)
+        if isinstance(data_batch, list):
+            if self.label_arrays is not None and data_batch is not None and data_batch[0].label:
+                _load_label(data_batch, self.label_arrays, self.label_layouts)
+        else:
+            if self.label_arrays is not None and data_batch.label:
+                _load_label(data_batch, self.label_arrays, self.label_layouts)
 
         for exec_ in self.execs:
             exec_.forward(is_train=is_train)
@@ -580,7 +598,7 @@ def backward(self, out_grads=None):
                     out_grads_slice.append(grad.copyto(self.contexts[i]))
             exec_.backward(out_grads=out_grads_slice)
 
-    def update_metric(self, eval_metric, labels):
+    def update_metric(self, eval_metric, labels, pre_sliced):
         """Accumulate the performance according to `eval_metric` on all devices
         by comparing outputs from [begin, end) to labels. By default use all
         outputs.
@@ -591,25 +609,30 @@ def update_metric(self, eval_metric, labels):
             The metric used for evaluation.
         labels : list of NDArray
             Typically comes from `label` of a `DataBatch`.
+        pre_sliced : bool
+            Whether labels are already sliced.
         begin : int
             Starting index of used outputs.
         end : int or None
             Ending index of used outputs.
         """
-        for texec, islice in zip(self.execs, self.slices):
-            labels_slice = []
-            for label, axis in zip(labels, self.label_layouts):
-                if axis == 0:
-                    # slicing NDArray along axis 0 can avoid copying
-                    labels_slice.append(label[islice])
-                elif axis > 0:
-                    # pylint: disable=no-member
-                    label_my_slice = nd.slice_axis(label, axis=axis, begin=islice.start,
-                                                   end=islice.stop).as_in_context(label.context)
-                    # pylint: enable=no-member
-                    labels_slice.append(label_my_slice)
-                else:
-                    labels_slice.append(label)
+        for current_exec, (texec, islice) in enumerate(zip(self.execs, self.slices)):
+            if not pre_sliced:
+                labels_slice = []
+                for label, axis in zip(labels, self.label_layouts):
+                    if axis == 0:
+                        # slicing NDArray along axis 0 can avoid copying
+                        labels_slice.append(label[islice])
+                    elif axis > 0:
+                        # pylint: disable=no-member
+                        label_my_slice = nd.slice_axis(label, axis=axis, begin=islice.start,
+                                                       end=islice.stop).as_in_context(label.context)
+                        # pylint: enable=no-member
+                        labels_slice.append(label_my_slice)
+                    else:
+                        labels_slice.append(label)
+            else:
+                labels_slice = labels[current_exec]
 
             labels_ = OrderedDict(zip(self.label_names, labels_slice))
             preds = OrderedDict(zip(self.output_names, texec.outputs))
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index a05c3a31cd2..4d77e0e4d8c 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -590,7 +590,19 @@ def forward(self, data_batch, is_train=None):
         assert self.binded and self.params_initialized
 
         curr_data_shapes = tuple(i.shape for i in self._data_shapes)
-        new_data_shapes = tuple(i.shape for i in data_batch.data)
+        if isinstance(data_batch, list):
+            assert data_batch is not None, "Encountered empty data batch"
+            new_data_shapes = []
+            for i in range(len(data_batch[0].data)):
+                shape = data_batch[0].data[i].shape
+                for db in data_batch:
+                    assert shape == db.data[i].shape, \
+                        "All data batches in a list need to have the same shape"
+                new_batch_size = len(data_batch) * shape[0]
+                new_data_shapes.append((new_batch_size,) + shape[1:])
+            new_data_shapes = tuple(new_data_shapes)
+        else:
+            new_data_shapes = tuple(i.shape for i in data_batch.data)
 
         if curr_data_shapes != new_data_shapes:
             if hasattr(data_batch, "provide_data") and data_batch.provide_data:
@@ -741,7 +753,7 @@ def set_states(self, states=None, value=None):
         assert self.binded and self.params_initialized
         self._exec_group.set_states(states, value)
 
-    def update_metric(self, eval_metric, labels):
+    def update_metric(self, eval_metric, labels, pre_sliced=False):
         """Evaluates and accumulates evaluation metric on outputs of the last forward computation.
 
         See Also
@@ -751,10 +763,13 @@ def update_metric(self, eval_metric, labels):
         Parameters
         ----------
         eval_metric : EvalMetric
-        labels : list of NDArray
-            Typically ``data_batch.label``.
+            Evaluation metric to use.
+        labels : list of NDArray if `pre_sliced` parameter is set to `False`,
+            list of lists of NDArray otherwise. Typically `data_batch.label`.
+        pre_sliced: bool
+            Whether the labels are already sliced per device (default: False).
         """
-        self._exec_group.update_metric(eval_metric, labels)
+        self._exec_group.update_metric(eval_metric, labels, pre_sliced)
 
     def _sync_params_from_devices(self):
         """Synchronizes parameters from devices to CPU. This function should be called after
diff --git a/python/mxnet/module/python_module.py b/python/mxnet/module/python_module.py
index 2d4343c80c7..886851efc30 100644
--- a/python/mxnet/module/python_module.py
+++ b/python/mxnet/module/python_module.py
@@ -138,7 +138,7 @@ def update(self):
         """
         pass
 
-    def update_metric(self, eval_metric, labels):
+    def update_metric(self, eval_metric, labels, pre_sliced=False):
         """Evaluates and accumulates evaluation metric on outputs of the last forward computation.
         Subclass should override this method if needed.
 
@@ -153,6 +153,9 @@ def update_metric(self, eval_metric, labels):
             # function or predictions, so just ignore this call
             return
 
+        if pre_sliced:
+            raise RuntimeError("PythonModule does not support presliced labels")
+
         # by default we expect our outputs are some scores that could be evaluated
         eval_metric.update(labels, self.get_outputs())
 
diff --git a/python/mxnet/module/sequential_module.py b/python/mxnet/module/sequential_module.py
index 642a398c08d..8d563a4def7 100644
--- a/python/mxnet/module/sequential_module.py
+++ b/python/mxnet/module/sequential_module.py
@@ -416,7 +416,7 @@ def get_input_grads(self, merge_multi_context=True):
         assert self.binded and self.params_initialized and self.inputs_need_grad
         return self._modules[0].get_input_grads(merge_multi_context=merge_multi_context)
 
-    def update_metric(self, eval_metric, labels):
+    def update_metric(self, eval_metric, labels, pre_sliced=False):
         """Evaluates and accumulates evaluation metric on outputs of the last forward computation.
 
         Parameters
@@ -430,7 +430,7 @@ def update_metric(self, eval_metric, labels):
         for meta, module in zip(self._metas, self._modules):
             if SequentialModule.META_TAKE_LABELS in meta and \
                     meta[SequentialModule.META_TAKE_LABELS]:
-                module.update_metric(eval_metric, labels)
+                module.update_metric(eval_metric, labels, pre_sliced)
 
     def install_monitor(self, mon):
         """Installs monitor on all executors."""


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services