Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/11 05:15:02 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12285: Change the way NDArrayIter handle the last batch
URL: https://github.com/apache/incubator-mxnet/pull/12285
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 8d8aeaca73e..1c005d57c4a 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -178,3 +178,4 @@ List of Contributors
 * [Aaron Markham](https://github.com/aaronmarkham)
 * [Sam Skalicky](https://github.com/samskalicky)
 * [Per Goncalves da Silva](https://github.com/perdasilva)
+* [Cheng-Che Lee](https://github.com/stu1130)
diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
new file mode 100644
index 00000000000..5c5e2e68d84
--- /dev/null
+++ b/python/mxnet/io/__init__.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+""" Data iterators for common data formats and utility functions."""
+from __future__ import absolute_import
+
+from . import io
+from .io import *
+
+from . import utils
+from .utils import *
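
The new __init__.py above makes mxnet/io a package; the wildcard re-exports
are what keep the public import surface unchanged. A minimal sketch of that
contract, assuming an mxnet build that includes this change:

    import mxnet as mx

    # NDArrayIter now lives in mxnet/io/io.py, but the re-export in
    # mxnet/io/__init__.py keeps the old access path working unchanged.
    data_iter = mx.io.NDArrayIter(data=mx.nd.ones((10, 2)), batch_size=5)
    for batch in data_iter:
        print(batch.data[0].shape)  # (5, 2) twice: 10 samples, batch_size 5
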
diff --git a/python/mxnet/io.py b/python/mxnet/io/io.py
similarity index 82%
rename from python/mxnet/io.py
rename to python/mxnet/io/io.py
index 884e9294741..2ae3e70045f 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io/io.py
@@ -17,30 +17,26 @@
 
 """Data iterators for common data formats."""
 from __future__ import absolute_import
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
 
 import sys
 import ctypes
 import logging
 import threading
-try:
-    import h5py
-except ImportError:
-    h5py = None
 import numpy as np
-from .base import _LIB
-from .base import c_str_array, mx_uint, py_str
-from .base import DataIterHandle, NDArrayHandle
-from .base import mx_real_t
-from .base import check_call, build_param_doc as _build_param_doc
-from .ndarray import NDArray
-from .ndarray.sparse import CSRNDArray
-from .ndarray.sparse import array as sparse_array
-from .ndarray import _ndarray_cls
-from .ndarray import array
-from .ndarray import concatenate
-from .ndarray import arange
-from .ndarray.random import shuffle as random_shuffle
+
+from ..base import _LIB
+from ..base import c_str_array, mx_uint, py_str
+from ..base import DataIterHandle, NDArrayHandle
+from ..base import mx_real_t
+from ..base import check_call, build_param_doc as _build_param_doc
+from ..ndarray import NDArray
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray import _ndarray_cls
+from ..ndarray import array
+from ..ndarray import concat
+
+from .utils import init_data, has_instance, getdata_by_idx
 
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout
@@ -489,59 +485,6 @@ def getindex(self):
     def getpad(self):
         return self.current_batch.pad
 
-def _init_data(data, allow_empty, default_name):
-    """Convert data into canonical form."""
-    assert (data is not None) or allow_empty
-    if data is None:
-        data = []
-
-    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
-                  if h5py else (np.ndarray, NDArray)):
-        data = [data]
-    if isinstance(data, list):
-        if not allow_empty:
-            assert(len(data) > 0)
-        if len(data) == 1:
-            data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type
-        else:
-            data = OrderedDict( # pylint: disable=redefined-variable-type
-                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
-    if not isinstance(data, dict):
-        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
-                "a list of them or dict with them as values")
-    for k, v in data.items():
-        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
-            try:
-                data[k] = array(v)
-            except:
-                raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) + \
-                                "should be NDArray, numpy.ndarray or h5py.Dataset")
-
-    return list(sorted(data.items()))
-
-def _has_instance(data, dtype):
-    """Return True if ``data`` has instance of ``dtype``.
-    This function is called after _init_data.
-    ``data`` is a list of (str, NDArray)"""
-    for item in data:
-        _, arr = item
-        if isinstance(arr, dtype):
-            return True
-    return False
-
-def _shuffle(data, idx):
-    """Shuffle the data."""
-    shuffle_data = []
-
-    for k, v in data:
-        if (isinstance(v, h5py.Dataset) if h5py else False):
-            shuffle_data.append((k, v))
-        elif isinstance(v, CSRNDArray):
-            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
-        else:
-            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
-
-    return shuffle_data
 
 class NDArrayIter(DataIter):
     """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
@@ -601,6 +544,22 @@ class NDArrayIter(DataIter):
     ...
     >>> batchidx # Remaining examples are discarded. So, 10/3 batches are created.
     3
+    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
+    >>> batchidx = 0
+    >>> for batch in dataiter:
+    ...     batchidx += 1
+    ...
+    >>> batchidx # Remaining examples are rolled over to the next iteration.
+    3
+    >>> dataiter.reset()
+    >>> dataiter.next().data[0].asnumpy()
+    [[[ 36.  37.]
+      [ 38.  39.]]
+     [[ 0.  1.]
+      [ 2.  3.]]
+     [[ 4.  5.]
+      [ 6.  7.]]]
+    (3L, 2L, 2L)
 
     `NDArrayIter` also supports multiple input and labels.
 
@@ -633,8 +592,11 @@ class NDArrayIter(DataIter):
         Only supported if no h5py.Dataset inputs are used.
     last_batch_handle : str, optional
         How to handle the last batch. This parameter can be 'pad', 'discard' or
-        'roll_over'. 'roll_over' is intended for training and can cause problems
-        if used for prediction.
+        'roll_over'.
+        If 'pad', the last batch will be padded with data starting from the beginning.
+        If 'discard', the last batch will be discarded.
+        If 'roll_over', the remaining elements will be rolled over to the next iteration;
+        note that this is intended for training and can cause problems if used for prediction.
     data_name : str, optional
         The data name.
     label_name : str, optional
@@ -645,36 +607,28 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
                  label_name='softmax_label'):
         super(NDArrayIter, self).__init__(batch_size)
 
-        self.data = _init_data(data, allow_empty=False, default_name=data_name)
-        self.label = _init_data(label, allow_empty=True, default_name=label_name)
+        self.data = init_data(data, allow_empty=False, default_name=data_name)
+        self.label = init_data(label, allow_empty=True, default_name=label_name)
 
-        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
+        if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, CSRNDArray)) and
                 (last_batch_handle != 'discard')):
             raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                       " with `last_batch_handle` set to `discard`.")
 
-        # shuffle data
-        if shuffle:
-            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
-            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
-            self.data = _shuffle(self.data, self.idx)
-            self.label = _shuffle(self.label, self.idx)
-        else:
-            self.idx = np.arange(self.data[0][1].shape[0])
-
-        # batching
-        if last_batch_handle == 'discard':
-            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
-            self.idx = self.idx[:new_n]
+        self.idx = np.arange(self.data[0][1].shape[0])
+        self.shuffle = shuffle
+        self.last_batch_handle = last_batch_handle
+        self.batch_size = batch_size
+        self.cursor = -self.batch_size
+        self.num_data = self.idx.shape[0]
+        # shuffle
+        self.reset()
 
         self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
         self.num_source = len(self.data_list)
-        self.num_data = self.idx.shape[0]
-        assert self.num_data >= batch_size, \
-            "batch_size needs to be smaller than data size."
-        self.cursor = -batch_size
-        self.batch_size = batch_size
-        self.last_batch_handle = last_batch_handle
+        # used for 'roll_over'
+        self._cache_data = None
+        self._cache_label = None
 
     @property
     def provide_data(self):
@@ -694,74 +648,126 @@ def provide_label(self):
 
     def hard_reset(self):
         """Ignore roll over data and set to start."""
+        if self.shuffle:
+            self._shuffle_data()
         self.cursor = -self.batch_size
+        self._cache_data = None
+        self._cache_label = None
 
     def reset(self):
-        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
-            self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+        """Resets the iterator to the beginning of the data."""
+        if self.shuffle:
+            self._shuffle_data()
+        # the range below indicates that the cursor sits in a partial last batch
+        if self.last_batch_handle == 'roll_over' and \
+            self.num_data - self.batch_size < self.cursor < self.num_data:
+            # self.num_data - self.cursor samples remain for the roll-over batch
+            self.cursor = self.cursor - self.num_data - self.batch_size
         else:
             self.cursor = -self.batch_size
 
     def iter_next(self):
+        """Increments the coursor by batch_size for next batch
+        and check current cursor if it exceed the number of data points."""
         self.cursor += self.batch_size
         return self.cursor < self.num_data
 
     def next(self):
-        if self.iter_next():
-            return DataBatch(data=self.getdata(), label=self.getlabel(), \
-                    pad=self.getpad(), index=None)
-        else:
+        """Returns the next batch of data."""
+        if not self.iter_next():
+            raise StopIteration
+        data = self.getdata()
+        label = self.getlabel()
+        # the iterator stops when the last batch is incomplete
+        if data[0].shape[0] != self.batch_size:
+            # in this case, cache it for the next epoch
+            self._cache_data = data
+            self._cache_label = label
             raise StopIteration
+        return DataBatch(data=data, label=label, \
+            pad=self.getpad(), index=None)
+
+    def _getdata(self, data_source, start=None, end=None):
+        """Load data from underlying arrays."""
+        assert start is not None or end is not None, 'should at least specify start or end'
+        start = start if start is not None else 0
+        end = end if end is not None else data_source[0][1].shape[0]
+        s = slice(start, end)
+        return [
+            x[1][s]
+            if isinstance(x[1], (np.ndarray, NDArray)) else
+            # h5py (only supports indices in increasing order)
+            array(x[1][sorted(self.idx[s])][[
+                list(self.idx[s]).index(i)
+                for i in sorted(self.idx[s])
+            ]]) for x in data_source
+        ]
 
-    def _getdata(self, data_source):
+    def _concat(self, first_data, second_data):
+        """Helper function to concat two NDArrays."""
+        return [
+            concat(first_data[0], second_data[0], dim=0)
+        ]
+
+    def _batchify(self, data_source):
         """Load data from underlying arrays, internal use only."""
-        assert(self.cursor < self.num_data), "DataIter needs reset."
-        if self.cursor + self.batch_size <= self.num_data:
-            return [
-                # np.ndarray or NDArray case
-                x[1][self.cursor:self.cursor + self.batch_size]
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                array(x[1][sorted(self.idx[
-                    self.cursor:self.cursor + self.batch_size])][[
-                        list(self.idx[self.cursor:
-                                      self.cursor + self.batch_size]).index(i)
-                        for i in sorted(self.idx[
-                            self.cursor:self.cursor + self.batch_size])
-                    ]]) for x in data_source
-            ]
-        else:
+        assert self.cursor < self.num_data, 'DataIter needs reset.'
+        # first batch of next epoch with 'roll_over'
+        if self.last_batch_handle == 'roll_over' and \
+            -self.batch_size < self.cursor < 0:
+            assert self._cache_data is not None or self._cache_label is not None, \
+                'next epoch should have cached data'
+            cache_data = self._cache_data if self._cache_data is not None else self._cache_label
+            second_data = self._getdata(
+                data_source, end=self.cursor + self.batch_size)
+            if self._cache_data is not None:
+                self._cache_data = None
+            else:
+                self._cache_label = None
+            return self._concat(cache_data, second_data)
+        # last batch with 'pad'
+        elif self.last_batch_handle == 'pad' and \
+            self.cursor + self.batch_size > self.num_data:
             pad = self.batch_size - self.num_data + self.cursor
-            return [
-                # np.ndarray or NDArray case
-                concatenate([x[1][self.cursor:], x[1][:pad]])
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                concatenate([
-                    array(x[1][sorted(self.idx[self.cursor:])][[
-                        list(self.idx[self.cursor:]).index(i)
-                        for i in sorted(self.idx[self.cursor:])
-                    ]]),
-                    array(x[1][sorted(self.idx[:pad])][[
-                        list(self.idx[:pad]).index(i)
-                        for i in sorted(self.idx[:pad])
-                    ]])
-                ]) for x in data_source
-            ]
+            first_data = self._getdata(data_source, start=self.cursor)
+            second_data = self._getdata(data_source, end=pad)
+            return self._concat(first_data, second_data)
+        # normal case
+        else:
+            if self.cursor + self.batch_size < self.num_data:
+                end_idx = self.cursor + self.batch_size
+            # get incomplete last batch
+            else:
+                end_idx = self.num_data
+            return self._getdata(data_source, self.cursor, end_idx)
 
     def getdata(self):
-        return self._getdata(self.data)
+        """Get data."""
+        return self._batchify(self.data)
 
     def getlabel(self):
-        return self._getdata(self.label)
+        """Get label."""
+        return self._batchify(self.label)
 
     def getpad(self):
+        """Get pad value of DataBatch."""
         if self.last_batch_handle == 'pad' and \
            self.cursor + self.batch_size > self.num_data:
             return self.cursor + self.batch_size - self.num_data
+        # the first batch of the next epoch under 'roll_over' reports the rolled-over count
+        elif self.last_batch_handle == 'roll_over' and \
+            -self.batch_size < self.cursor < 0:
+            return -self.cursor
         else:
             return 0
 
+    def _shuffle_data(self):
+        """Shuffle the data."""
+        # shuffle index
+        np.random.shuffle(self.idx)
+        # reorder the data to match the shuffled indices
+        self.data = getdata_by_idx(self.data, self.idx)
+        self.label = getdata_by_idx(self.label, self.idx)
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.
@@ -773,7 +779,7 @@ class MXDataIter(DataIter):
     underlying C++ data iterators.
 
     Usually you don't need to interact with `MXDataIter` directly unless you are
-    implementing your own data iterators in C++. To do that, please refer to
+    implementing your own data iterators in C++. To do that, please refer to
     examples under the `src/io` folder.
 
     Parameters
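
To make the three last_batch_handle modes concrete, here is a short sketch
of the semantics the new cursor logic implies (worked out from the code
above, not captured output; assumes an mxnet build with this change):

    import numpy as np
    import mxnet as mx

    data = np.arange(10).reshape((10, 1))  # 10 samples, batch_size 3

    # 'pad': the fourth batch wraps around to the beginning and reports pad=2
    it = mx.io.NDArrayIter(data, batch_size=3, last_batch_handle='pad')
    print([(b.data[0].shape[0], b.pad) for b in it])  # [(3, 0), (3, 0), (3, 0), (3, 2)]

    # 'discard': the incomplete fourth batch is dropped
    it = mx.io.NDArrayIter(data, batch_size=3, last_batch_handle='discard')
    print(sum(1 for _ in it))  # 3

    # 'roll_over': the leftover sample is cached and, after reset(), is
    # prepended to the first batch of the next epoch
    it = mx.io.NDArrayIter(data, batch_size=3, last_batch_handle='roll_over')
    print(sum(1 for _ in it))  # 3
    it.reset()
    print(it.next().data[0].asnumpy().ravel())  # [9. 0. 1.]
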
diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py
new file mode 100644
index 00000000000..872e6410d7d
--- /dev/null
+++ b/python/mxnet/io/utils.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""utility functions for io.py"""
+from collections import OrderedDict
+
+import numpy as np
+try:
+    import h5py
+except ImportError:
+    h5py = None
+
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray.sparse import array as sparse_array
+from ..ndarray import NDArray
+from ..ndarray import array
+
+def init_data(data, allow_empty, default_name):
+    """Convert data into canonical form."""
+    assert (data is not None) or allow_empty
+    if data is None:
+        data = []
+
+    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
+                  if h5py else (np.ndarray, NDArray)):
+        data = [data]
+    if isinstance(data, list):
+        if not allow_empty:
+            assert(len(data) > 0)
+        if len(data) == 1:
+            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
+        else:
+            data = OrderedDict(  # pylint: disable=redefined-variable-type
+                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
+    if not isinstance(data, dict):
+        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
+                        "a list of them or dict with them as values")
+    for k, v in data.items():
+        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
+            try:
+                data[k] = array(v)
+            except:
+                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) +
+                                "should be NDArray, numpy.ndarray or h5py.Dataset")
+
+    return list(sorted(data.items()))
+
+
+def has_instance(data, dtype):
+    """Return True if ``data`` has instance of ``dtype``.
+    This function is called after _init_data.
+    ``data`` is a list of (str, NDArray)"""
+    for item in data:
+        _, arr = item
+        if isinstance(arr, dtype):
+            return True
+    return False
+
+
+def getdata_by_idx(data, idx):
+    """Shuffle the data."""
+    shuffle_data = []
+
+    for k, v in data:
+        if (isinstance(v, h5py.Dataset) if h5py else False):
+            shuffle_data.append((k, v))
+        elif isinstance(v, CSRNDArray):
+            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
+        else:
+            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
+
+    return shuffle_data
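
For a quick feel of the "canonical form" init_data produces, a hedged
sketch (the module path mxnet/io/utils.py is the one added by this PR):

    import numpy as np
    from mxnet.io.utils import init_data

    # a bare array becomes a single ('data', NDArray) pair
    print(init_data(np.zeros((4, 2)), allow_empty=False, default_name='data'))
    # e.g. [('data', <NDArray 4x2 @cpu(0)>)]

    # a list of arrays gets enumerated names: '_0_data', '_1_data', ...
    pairs = init_data([np.zeros((4, 2)), np.ones((4, 2))],
                      allow_empty=False, default_name='data')
    print([name for name, _ in pairs])  # ['_0_data', '_1_data']
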
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 4dfa69cc105..ae686261b81 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -88,80 +88,88 @@ def test_Cifar10Rec():
         assert(labelcount[i] == 5000)
 
 
-def test_NDArrayIter():
+def _init_NDArrayIter_data():
     data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
+    labels = np.ones([1000, 1])
     for i in range(1000):
         data[i] = i / 100
-        label[i] = i / 100
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, True, last_batch_handle='pad')
-    batchidx = 0
+        labels[i] = i / 100
+    return data, labels
+
+
+def _test_last_batch_handle(data, labels):
+    # test the three 'last_batch_handle' modes: 'pad', 'discard', 'roll_over'
+    last_batch_handle_list = ['pad', 'discard', 'roll_over']
+    labelcount_list = [(124, 100), (100, 96), (100, 96)]
+    batch_count_list = [8, 7, 7]
+    
+    for idx in range(len(last_batch_handle_list)):
+        dataiter = mx.io.NDArrayIter(
+            data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx])
+        batch_count = 0
+        labelcount = [0 for i in range(10)]
+        for batch in dataiter:
+            label = batch.label[0].asnumpy().flatten()
+            # check that the data matches the corresponding labels
+            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()), last_batch_handle_list[idx]
+            for i in range(label.shape[0]):
+                labelcount[int(label[i])] += 1
+            batch_count += 1
+            # keep the last batch of 'pad' to be used later
+            # to test the first batch of 'roll_over' in the second iteration
+            if last_batch_handle_list[idx] == 'pad' and \
+                batch_count == 8:
+                cache = batch.data[0].asnumpy()
+        # check that the batchifying functionality works properly
+        assert labelcount[0] == labelcount_list[idx][0], last_batch_handle_list[idx]
+        assert labelcount[8] == labelcount_list[idx][1], last_batch_handle_list[idx]
+        assert batch_count == batch_count_list[idx]
+    # 'roll_over': the first batch after reset should equal the cached last 'pad' batch
+    dataiter.reset()
+    assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
+
+
+def _test_shuffle(data, labels):
+    dataiter = mx.io.NDArrayIter(data, labels, 1, False)
+    batch_list = []
     for batch in dataiter:
-        batchidx += 1
-    assert(batchidx == 8)
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, False, last_batch_handle='pad')
-    batchidx = 0
-    labelcount = [0 for i in range(10)]
+        # cache the original data
+        batch_list.append(batch.data[0].asnumpy())
+    dataiter = mx.io.NDArrayIter(data, labels, 1, True)
+    idx_list = dataiter.idx
+    i = 0
     for batch in dataiter:
-        label = batch.label[0].asnumpy().flatten()
-        assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-        for i in range(label.shape[0]):
-            labelcount[int(label[i])] += 1
+        # check that each data point has been shuffled to its corresponding position
+        assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
+        i += 1
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
+
+def test_NDArrayIter():
+    data, labels = _init_NDArrayIter_data()
+    _test_last_batch_handle(data, labels)
+    _test_shuffle(data, labels)
 
 
 def test_NDArrayIter_h5py():
     if not h5py:
         return
 
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
-    for i in range(1000):
-        data[i] = i / 100
-        label[i] = i / 100
+    data, labels = _init_NDArrayIter_data()
 
     try:
-        os.remove("ndarraytest.h5")
+        os.remove('ndarraytest.h5')
     except OSError:
         pass
-    with h5py.File("ndarraytest.h5") as f:
-        f.create_dataset("data", data=data)
-        f.create_dataset("label", data=label)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, True, last_batch_handle='pad')
-        batchidx = 0
-        for batch in dataiter:
-            batchidx += 1
-        assert(batchidx == 8)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, False, last_batch_handle='pad')
-        labelcount = [0 for i in range(10)]
-        for batch in dataiter:
-            label = batch.label[0].asnumpy().flatten()
-            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-            for i in range(label.shape[0]):
-                labelcount[int(label[i])] += 1
+    with h5py.File('ndarraytest.h5') as f:
+        f.create_dataset('data', data=data)
+        f.create_dataset('label', data=labels)
 
+        _test_last_batch_handle(f['data'], f['label'])
     try:
         os.remove("ndarraytest.h5")
     except OSError:
         pass
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
-
 
 def test_NDArrayIter_csr():
     # creating toy data
@@ -182,12 +190,20 @@ def test_NDArrayIter_csr():
                      {'data': train_data}, dns, batch_size)
     except ImportError:
         pass
+    # scipy.sparse.csr_matrix with shuffle
+    num_batch = 0
+    csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
+                                      shuffle=True, last_batch_handle='discard'))
+    for _ in csr_iter:
+        num_batch += 1
+
+    assert(num_batch == num_rows // batch_size)
 
     # CSRNDArray with shuffle
     csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size,
                                       shuffle=True, last_batch_handle='discard'))
     num_batch = 0
-    for batch in csr_iter:
+    for _ in csr_iter:
         num_batch += 1
 
     assert(num_batch == num_rows // batch_size)
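
The shuffle test relies on the iterator exposing its permutation as
dataiter.idx; a minimal sketch of that contract, under the same
assumptions as above:

    import numpy as np
    import mxnet as mx

    data = np.arange(8).reshape((8, 1))
    it = mx.io.NDArrayIter(data, batch_size=1, shuffle=True)
    perm = it.idx  # permutation applied once by _shuffle_data() during reset
    for i, batch in enumerate(it):
        # batch i serves the sample that originally sat at position perm[i]
        assert batch.data[0].asnumpy()[0, 0] == data[perm[i], 0]
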


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services