You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/12 00:23:05 UTC
[GitHub] piiswrong closed pull request #11112: Support for data iterators returning lists of batches
piiswrong closed pull request #11112: Support for data iterators returning lists of batches
URL: https://github.com/apache/incubator-mxnet/pull/11112
This is a PR merged from a forked repository.
As this is a foreign pull request (from a fork) and GitHub hides the
original diff once such a request is merged, the diff is reproduced
below for the sake of provenance:
diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index f5427feae2f..b646ef6edaa 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -159,8 +159,13 @@ def fit(args, network, data_loader, **kwargs):
if args.test_io:
tic = time.time()
for i, batch in enumerate(train):
- for j in batch.data:
- j.wait_to_read()
+ if isinstance(batch, list):
+ for b in batch:
+ for j in b.data:
+ j.wait_to_read()
+ else:
+ for j in batch.data:
+ j.wait_to_read()
if (i + 1) % args.disp_batches == 0:
logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
args.disp_batches * args.batch_size / (time.time() - tic))
diff --git a/python/mxnet/executor_manager.py b/python/mxnet/executor_manager.py
index 33c6c976271..825aa76e43c 100644
--- a/python/mxnet/executor_manager.py
+++ b/python/mxnet/executor_manager.py
@@ -286,10 +286,13 @@ def backward(self):
for texec in self.train_execs:
texec.backward()
- def update_metric(self, metric, labels):
+ def update_metric(self, metric, labels, pre_sliced=False):
"""Update evaluation metric with label and current outputs."""
- for texec, islice in zip(self.train_execs, self.slices):
- labels_slice = [label[islice] for label in labels]
+ for current_exec, (texec, islice) in enumerate(zip(self.train_execs, self.slices)):
+ if not pre_sliced:
+ labels_slice = [label[islice] for label in labels]
+ else:
+ labels_slice = labels[current_exec]
metric.update(labels_slice, texec.outputs)
class DataParallelExecutorManager(object):
@@ -436,6 +439,6 @@ def backward(self):
"""Run backward on the current executor."""
self.curr_execgrp.backward()
- def update_metric(self, metric, labels):
+ def update_metric(self, metric, labels, pre_sliced=False):
"""Update metric with the current executor."""
- self.curr_execgrp.update_metric(metric, labels)
+ self.curr_execgrp.update_metric(metric, labels, pre_sliced)
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index 8f5fd4ab854..4b7355ffa92 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -146,7 +146,8 @@ class BaseModule(object):
- `get_outputs()`: get outputs of the previous forward operation.
- `get_input_grads()`: get the gradients with respect to the inputs computed
in the previous backward operation.
- - `update_metric(metric, labels)`: update performance metric for the previous forward
+ - `update_metric(metric, labels, pre_sliced=False)`: update performance metric
+ for the previous forward
computed results.
- other properties (mostly for backward compatibility)
@@ -249,7 +250,10 @@ def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
break
self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
self.forward(eval_batch, is_train=False)
- self.update_metric(eval_metric, eval_batch.label)
+ if isinstance(eval_batch, list):
+ self.update_metric(eval_metric, [eb.label for eb in eval_batch], pre_sliced=True)
+ else:
+ self.update_metric(eval_metric, eval_batch.label)
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch,
@@ -517,7 +521,12 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
except StopIteration:
end_of_batch = True
- self.update_metric(eval_metric, data_batch.label)
+ if isinstance(data_batch, list):
+ self.update_metric(eval_metric,
+ [db.label for db in data_batch],
+ pre_sliced=True)
+ else:
+ self.update_metric(eval_metric, data_batch.label)
if monitor is not None:
monitor.toc_print()
@@ -943,7 +952,7 @@ def update(self):
"""
raise NotImplementedError()
- def update_metric(self, eval_metric, labels):
+ def update_metric(self, eval_metric, labels, pre_sliced=False):
"""Evaluates and accumulates evaluation metric on outputs of the last forward
computation.
@@ -951,8 +960,10 @@ def update_metric(self, eval_metric, labels):
----------
eval_metric : EvalMetric
Evaluation metric to use.
- labels : list of NDArray
- Typically `data_batch.label`.
+ labels : list of NDArray if `pre_sliced` parameter is set to `False`,
+ list of lists of NDArray otherwise. Typically `data_batch.label`.
+ pre_sliced: bool
+ Whether the labels are already sliced per device (default: False).
Examples
--------
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 18cec29b409..9b568618566 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -517,7 +517,7 @@ def get_input_grads(self, merge_multi_context=True):
assert self.binded and self.params_initialized and self.inputs_need_grad
return self._curr_module.get_input_grads(merge_multi_context=merge_multi_context)
- def update_metric(self, eval_metric, labels):
+ def update_metric(self, eval_metric, labels, pre_sliced=False):
"""Evaluates and accumulates evaluation metric on outputs of the last forward computation.
Parameters
@@ -527,7 +527,7 @@ def update_metric(self, eval_metric, labels):
Typically ``data_batch.label``.
"""
assert self.binded and self.params_initialized
- self._curr_module.update_metric(eval_metric, labels)
+ self._curr_module.update_metric(eval_metric, labels, pre_sliced)
@property
def symbol(self):
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 32400c11dbc..5d8e95077c4 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -64,12 +64,26 @@ def _load_general(data, targets, major_axis):
def _load_data(batch, targets, major_axis):
"""Load data into sliced arrays."""
- _load_general(batch.data, targets, major_axis)
+ if isinstance(batch, list):
+ new_batch = []
+ for i in range(len(targets)):
+ new_batch.append([b.data[i] for b in batch])
+ new_targets = [[dst for _, dst in d_target] for d_target in targets]
+ _load_general(new_batch, new_targets, major_axis)
+ else:
+ _load_general(batch.data, targets, major_axis)
def _load_label(batch, targets, major_axis):
"""Load label into sliced arrays."""
- _load_general(batch.label, targets, major_axis)
+ if isinstance(batch, list):
+ new_batch = []
+ for i in range(len(targets)):
+ new_batch.append([b.label[i] for b in batch])
+ new_targets = [[dst for _, dst in d_target] for d_target in targets]
+ _load_general(new_batch, new_targets, major_axis)
+ else:
+ _load_general(batch.label, targets, major_axis)
def _merge_multi_context(outputs, major_axis):
@@ -437,8 +451,12 @@ def forward(self, data_batch, is_train=None):
if is_train is None:
is_train = self.for_training
- if self.label_arrays is not None and data_batch.label:
- _load_label(data_batch, self.label_arrays, self.label_layouts)
+ if isinstance(data_batch, list):
+ if self.label_arrays is not None and data_batch is not None and data_batch[0].label:
+ _load_label(data_batch, self.label_arrays, self.label_layouts)
+ else:
+ if self.label_arrays is not None and data_batch.label:
+ _load_label(data_batch, self.label_arrays, self.label_layouts)
for exec_ in self.execs:
exec_.forward(is_train=is_train)
@@ -580,7 +598,7 @@ def backward(self, out_grads=None):
out_grads_slice.append(grad.copyto(self.contexts[i]))
exec_.backward(out_grads=out_grads_slice)
- def update_metric(self, eval_metric, labels):
+ def update_metric(self, eval_metric, labels, pre_sliced):
"""Accumulate the performance according to `eval_metric` on all devices
by comparing outputs from [begin, end) to labels. By default use all
outputs.
@@ -591,25 +609,30 @@ def update_metric(self, eval_metric, labels):
The metric used for evaluation.
labels : list of NDArray
Typically comes from `label` of a `DataBatch`.
+ pre_sliced : bool
+ Whether labels are already sliced.
begin : int
Starting index of used outputs.
end : int or None
Ending index of used outputs.
"""
- for texec, islice in zip(self.execs, self.slices):
- labels_slice = []
- for label, axis in zip(labels, self.label_layouts):
- if axis == 0:
- # slicing NDArray along axis 0 can avoid copying
- labels_slice.append(label[islice])
- elif axis > 0:
- # pylint: disable=no-member
- label_my_slice = nd.slice_axis(label, axis=axis, begin=islice.start,
- end=islice.stop).as_in_context(label.context)
- # pylint: enable=no-member
- labels_slice.append(label_my_slice)
- else:
- labels_slice.append(label)
+ for current_exec, (texec, islice) in enumerate(zip(self.execs, self.slices)):
+ if not pre_sliced:
+ labels_slice = []
+ for label, axis in zip(labels, self.label_layouts):
+ if axis == 0:
+ # slicing NDArray along axis 0 can avoid copying
+ labels_slice.append(label[islice])
+ elif axis > 0:
+ # pylint: disable=no-member
+ label_my_slice = nd.slice_axis(label, axis=axis, begin=islice.start,
+ end=islice.stop).as_in_context(label.context)
+ # pylint: enable=no-member
+ labels_slice.append(label_my_slice)
+ else:
+ labels_slice.append(label)
+ else:
+ labels_slice = labels[current_exec]
labels_ = OrderedDict(zip(self.label_names, labels_slice))
preds = OrderedDict(zip(self.output_names, texec.outputs))
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index a05c3a31cd2..4d77e0e4d8c 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -590,7 +590,19 @@ def forward(self, data_batch, is_train=None):
assert self.binded and self.params_initialized
curr_data_shapes = tuple(i.shape for i in self._data_shapes)
- new_data_shapes = tuple(i.shape for i in data_batch.data)
+ if isinstance(data_batch, list):
+ assert data_batch is not None, "Encountered empty data batch"
+ new_data_shapes = []
+ for i in range(len(data_batch[0].data)):
+ shape = data_batch[0].data[i].shape
+ for db in data_batch:
+ assert shape == db.data[i].shape, \
+ "All data batches in a list need to have the same shape"
+ new_batch_size = len(data_batch) * shape[0]
+ new_data_shapes.append((new_batch_size,) + shape[1:])
+ new_data_shapes = tuple(new_data_shapes)
+ else:
+ new_data_shapes = tuple(i.shape for i in data_batch.data)
if curr_data_shapes != new_data_shapes:
if hasattr(data_batch, "provide_data") and data_batch.provide_data:
@@ -741,7 +753,7 @@ def set_states(self, states=None, value=None):
assert self.binded and self.params_initialized
self._exec_group.set_states(states, value)
- def update_metric(self, eval_metric, labels):
+ def update_metric(self, eval_metric, labels, pre_sliced=False):
"""Evaluates and accumulates evaluation metric on outputs of the last forward computation.
See Also
@@ -751,10 +763,13 @@ def update_metric(self, eval_metric, labels):
Parameters
----------
eval_metric : EvalMetric
- labels : list of NDArray
- Typically ``data_batch.label``.
+ Evaluation metric to use.
+ labels : list of NDArray if `pre_sliced` parameter is set to `False`,
+ list of lists of NDArray otherwise. Typically `data_batch.label`.
+ pre_sliced: bool
+ Whether the labels are already sliced per device (default: False).
"""
- self._exec_group.update_metric(eval_metric, labels)
+ self._exec_group.update_metric(eval_metric, labels, pre_sliced)
def _sync_params_from_devices(self):
"""Synchronizes parameters from devices to CPU. This function should be called after
diff --git a/python/mxnet/module/python_module.py b/python/mxnet/module/python_module.py
index 2d4343c80c7..886851efc30 100644
--- a/python/mxnet/module/python_module.py
+++ b/python/mxnet/module/python_module.py
@@ -138,7 +138,7 @@ def update(self):
"""
pass
- def update_metric(self, eval_metric, labels):
+ def update_metric(self, eval_metric, labels, pre_sliced=False):
"""Evaluates and accumulates evaluation metric on outputs of the last forward computation.
Subclass should override this method if needed.
@@ -153,6 +153,9 @@ def update_metric(self, eval_metric, labels):
# function or predictions, so just ignore this call
return
+ if pre_sliced:
+ raise RuntimeError("PythonModule does not support presliced labels")
+
# by default we expect our outputs are some scores that could be evaluated
eval_metric.update(labels, self.get_outputs())
diff --git a/python/mxnet/module/sequential_module.py b/python/mxnet/module/sequential_module.py
index 642a398c08d..8d563a4def7 100644
--- a/python/mxnet/module/sequential_module.py
+++ b/python/mxnet/module/sequential_module.py
@@ -416,7 +416,7 @@ def get_input_grads(self, merge_multi_context=True):
assert self.binded and self.params_initialized and self.inputs_need_grad
return self._modules[0].get_input_grads(merge_multi_context=merge_multi_context)
- def update_metric(self, eval_metric, labels):
+ def update_metric(self, eval_metric, labels, pre_sliced=False):
"""Evaluates and accumulates evaluation metric on outputs of the last forward computation.
Parameters
@@ -430,7 +430,7 @@ def update_metric(self, eval_metric, labels):
for meta, module in zip(self._metas, self._modules):
if SequentialModule.META_TAKE_LABELS in meta and \
meta[SequentialModule.META_TAKE_LABELS]:
- module.update_metric(eval_metric, labels)
+ module.update_metric(eval_metric, labels, pre_sliced)
def install_monitor(self, mon):
"""Installs monitor on all executors."""
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services