You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/06 20:27:58 UTC
[GitHub] spidyDev closed pull request #10425: [MXNET-175] Raise user warning on mismatch between module data_names and data iter names
spidyDev closed pull request #10425: [MXNET-175] Raise user warning on mismatch between module data_names and data iter names
URL: https://github.com/apache/incubator-mxnet/pull/10425
This is a PR merged from a forked repository.
Because this is a foreign pull request (from a fork), GitHub hides the
original diff once it is merged, so the diff is reproduced below for
the sake of provenance:
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 884e9294741..3e25a2151ce 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -492,6 +492,7 @@ def getpad(self):
def _init_data(data, allow_empty, default_name):
"""Convert data into canonical form."""
assert (data is not None) or allow_empty
+ renamed_data = False
if data is None:
data = []
@@ -506,6 +507,9 @@ def _init_data(data, allow_empty, default_name):
else:
data = OrderedDict( # pylint: disable=redefined-variable-type
[('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
+ # This is used to identify if the data was originally a list, and was
+ # modified to a OrderedDict. Used when doing forward pass.
+ renamed_data = True
if not isinstance(data, dict):
raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
"a list of them or dict with them as values")
@@ -517,7 +521,7 @@ def _init_data(data, allow_empty, default_name):
raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
"should be NDArray, numpy.ndarray or h5py.Dataset")
- return list(sorted(data.items()))
+ return list(sorted(data.items())), renamed_data
def _has_instance(data, dtype):
"""Return True if ``data`` has instance of ``dtype``.
@@ -645,8 +649,8 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
label_name='softmax_label'):
super(NDArrayIter, self).__init__(batch_size)
- self.data = _init_data(data, allow_empty=False, default_name=data_name)
- self.label = _init_data(label, allow_empty=True, default_name=label_name)
+ self.data, self.renamed_data = _init_data(data, allow_empty=False, default_name=data_name)
+ self.label, _ = _init_data(label, allow_empty=True, default_name=label_name)
if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
(last_batch_handle != 'discard')):
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index c03f8e73cda..ed836d29c02 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -25,6 +25,7 @@
from .. import metric
from .. import ndarray
+from .. import io
from ..context import cpu
from ..model import BatchEndParam
@@ -78,6 +79,17 @@ def _parse_data_desc(data_names, label_names, data_shapes, label_shapes):
return data_shapes, label_shapes
+def _check_data_names(eval_data, data_names):
+ """ Check if iterator data names match the data names provided in module"""
+ if isinstance(eval_data, io.NDArrayIter) and \
+ not eval_data.renamed_data and \
+ len(eval_data.data) and isinstance(eval_data.data[0], tuple) and \
+ dict(eval_data.data).keys() != data_names:
+ msg = "Data provided in data_names don't match names specified by iterator" \
+ " (%s vs. %s)"%(str(data_names), str(dict(eval_data.data).keys()))
+ warnings.warn(msg)
+
+
class BaseModule(object):
"""The base class of a module.
@@ -244,6 +256,7 @@ def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
eval_metric.reset()
actual_num_batch = 0
+ _check_data_names(eval_data, self.data_names)
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
@@ -364,6 +377,7 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
output_list = []
+ _check_data_names(eval_data, self.data_names)
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services