You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/06 20:27:58 UTC

[GitHub] spidyDev closed pull request #10425: [MXNET-175] Raise user warning on mismatch between module data_names and data iter names

spidyDev closed pull request #10425: [MXNET-175] Raise user warning on mismatch between module data_names and data iter names
URL: https://github.com/apache/incubator-mxnet/pull/10425
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 884e9294741..3e25a2151ce 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -492,6 +492,7 @@ def getpad(self):
 def _init_data(data, allow_empty, default_name):
     """Convert data into canonical form."""
     assert (data is not None) or allow_empty
+    renamed_data = False
     if data is None:
         data = []
 
@@ -506,6 +507,9 @@ def _init_data(data, allow_empty, default_name):
         else:
             data = OrderedDict( # pylint: disable=redefined-variable-type
                 [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
+        # This is used to identify if the  data was originally a list, and was
+        # modified to a OrderedDict. Used when doing forward pass.
+        renamed_data = True
     if not isinstance(data, dict):
         raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                 "a list of them or dict with them as values")
@@ -517,7 +521,7 @@ def _init_data(data, allow_empty, default_name):
                 raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) + \
                                 "should be NDArray, numpy.ndarray or h5py.Dataset")
 
-    return list(sorted(data.items()))
+    return list(sorted(data.items())), renamed_data
 
 def _has_instance(data, dtype):
     """Return True if ``data`` has instance of ``dtype``.
@@ -645,8 +649,8 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
                  label_name='softmax_label'):
         super(NDArrayIter, self).__init__(batch_size)
 
-        self.data = _init_data(data, allow_empty=False, default_name=data_name)
-        self.label = _init_data(label, allow_empty=True, default_name=label_name)
+        self.data, self.renamed_data = _init_data(data, allow_empty=False, default_name=data_name)
+        self.label, _ = _init_data(label, allow_empty=True, default_name=label_name)
 
         if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
                 (last_batch_handle != 'discard')):
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index c03f8e73cda..ed836d29c02 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -25,6 +25,7 @@
 
 from .. import metric
 from .. import ndarray
+from .. import io
 
 from ..context import cpu
 from ..model import BatchEndParam
@@ -78,6 +79,17 @@ def _parse_data_desc(data_names, label_names, data_shapes, label_shapes):
     return data_shapes, label_shapes
 
 
+def _check_data_names(eval_data, data_names):
+    """ Check if iterator data names match the data names provided in module"""
+    if isinstance(eval_data, io.NDArrayIter) and \
+            not eval_data.renamed_data and \
+            len(eval_data.data) and isinstance(eval_data.data[0], tuple) and \
+            dict(eval_data.data).keys() != data_names:
+        msg = "Data provided in data_names don't match names specified by iterator" \
+              " (%s vs. %s)"%(str(data_names), str(dict(eval_data.data).keys()))
+        warnings.warn(msg)
+
+
 class BaseModule(object):
     """The base class of a module.
 
@@ -244,6 +256,7 @@ def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
         eval_metric.reset()
         actual_num_batch = 0
 
+        _check_data_names(eval_data, self.data_names)
         for nbatch, eval_batch in enumerate(eval_data):
             if num_batch is not None and nbatch == num_batch:
                 break
@@ -364,6 +377,7 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
 
         output_list = []
 
+        _check_data_names(eval_data, self.data_names)
         for nbatch, eval_batch in enumerate(eval_data):
             if num_batch is not None and nbatch == num_batch:
                 break


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services