You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/19 17:51:32 UTC

[GitHub] zhreshold commented on a change in pull request #12545: Change the way NDArrayIter handle the last batch

zhreshold commented on a change in pull request #12545: Change the way NDArrayIter handle the last batch
URL: https://github.com/apache/incubator-mxnet/pull/12545#discussion_r218902059
 
 

 ##########
 File path: tests/python/unittest/test_io.py
 ##########
 @@ -88,79 +89,122 @@ def test_Cifar10Rec():
         assert(labelcount[i] == 5000)
 
 
-def test_NDArrayIter():
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
+def _init_NDArrayIter_data(data_type):
+    """Build 1000 samples whose data and label values both equal index / 100.
+
+    Returns NDArray data/labels when `data_type` is 'NDArray', otherwise
+    numpy arrays. Data has shape (1000, 2, 2) and labels (1000, 1) in both
+    cases, so integer labels 0..9 each cover 100 consecutive samples.
+    """
+    if data_type == 'NDArray':
+        data = nd.ones((1000, 2, 2))
+        # must match the numpy branch: one label per sample, not a (2, 2) block
+        labels = nd.ones((1000, 1))
+    else:
+        data = np.ones((1000, 2, 2))
+        labels = np.ones((1000, 1))
     for i in range(1000):
         data[i] = i / 100
-        label[i] = i / 100
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, True, last_batch_handle='pad')
-    batchidx = 0
+        labels[i] = i / 100
+    return data, labels
+
+
+def _test_last_batch_handle(data, labels=None):
+    """Check NDArrayIter batching under each `last_batch_handle` mode.
+
+    Iterates with batch size 128 and no shuffling for 'pad', 'discard' and
+    'roll_over', checking the number of batches, that data matches labels,
+    and the per-label sample counts. Finally checks that, after reset, the
+    first 'roll_over' batch equals the last 'pad' batch.
+
+    `data` may be a single array or a list of arrays; `labels` may be an
+    array, an empty list, or None (no labels).
+    """
+    last_batch_handle_list = ['pad', 'discard', 'roll_over']
+    has_labels = labels is not None and len(labels) != 0
+    # 1000 samples = 7 full batches of 128 + 104 leftover samples.
+    # Expected (count of label 0, count of label 8) for each mode above.
+    labelcount_list = [(124, 100), (100, 96), (100, 96)]
+    batch_count_list = [8, 7, 7]
+    cache = None
+
+    for idx, last_batch_handle in enumerate(last_batch_handle_list):
+        dataiter = mx.io.NDArrayIter(
+            data, labels, 128, False, last_batch_handle=last_batch_handle)
+        batch_count = 0
+        labelcount = [0] * 10
+        for batch in dataiter:
+            if isinstance(data, (list, tuple)):
+                # one output array per input array
+                assert len(batch.data) == len(data)
+            if has_labels:
+                label = batch.label[0].asnumpy().flatten()
+                # check data if it matches corresponding labels
+                assert (batch.data[0].asnumpy()[:, 0, 0] == label).all(), last_batch_handle
+                for value in label:
+                    labelcount[int(value)] += 1
+            else:
+                assert not batch.label, 'label is not empty list'
+            batch_count += 1
+            # keep the last batch of 'pad' to be used later
+            # to test first batch of roll_over in second iteration
+            if last_batch_handle == 'pad' and batch_count == batch_count_list[idx]:
+                cache = batch.data[0].asnumpy()
+        # check if batchifying functionality work properly
+        if has_labels:
+            assert labelcount[0] == labelcount_list[idx][0], last_batch_handle
+            assert labelcount[8] == labelcount_list[idx][1], last_batch_handle
+        assert batch_count == batch_count_list[idx], last_batch_handle
+    # roll_over option: `dataiter` is still the 'roll_over' iterator here
+    assert cache is not None, 'pad did not produce its final batch'
+    dataiter.reset()
+    assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
+
+
+def _test_shuffle(data, labels=None):
+    """Check that shuffling reorders samples exactly as `dataiter.idx` says."""
+    dataiter = mx.io.NDArrayIter(data, labels, 1, False)
+    batch_list = []
     for batch in dataiter:
-        batchidx += 1
-    assert(batchidx == 8)
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, False, last_batch_handle='pad')
-    batchidx = 0
-    labelcount = [0 for i in range(10)]
+        # cache the original (unshuffled) data, one sample per batch
+        batch_list.append(batch.data[0].asnumpy())
+    dataiter = mx.io.NDArrayIter(data, labels, 1, True)
+    idx_list = dataiter.idx
-    for batch in dataiter:
-        label = batch.label[0].asnumpy().flatten()
-        assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-        for i in range(label.shape[0]):
-            labelcount[int(label[i])] += 1
+    for i, batch in enumerate(dataiter):
+        # check if each data point have been shuffled to corresponding positions
+        assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
+
+def test_NDArrayIter():
+    """Run the last-batch and shuffle checks on NDArray and numpy inputs."""
+    dtype_list = ['NDArray', 'ndarray']
+    for dtype in dtype_list:
+        # pass the current dtype (not the whole list) so both branches run
+        data, labels = _init_NDArrayIter_data(dtype)
+        _test_last_batch_handle(data, labels)
+        _test_last_batch_handle([data, data], labels)
+        _test_last_batch_handle(data, [])
+        _test_last_batch_handle(data)
+        _test_shuffle(data, labels)
+        _test_shuffle([data, data], labels)
+        _test_shuffle(data, [])
+        _test_shuffle(data)
 
 
 def test_NDArrayIter_h5py():
+    """Run the last-batch-handle checks on h5py datasets; returns early without h5py."""
     if not h5py:
         return
 
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
-    for i in range(1000):
-        data[i] = i / 100
-        label[i] = i / 100
+    data, labels = _init_NDArrayIter_data('ndarray')
 
     try:
-        os.remove("ndarraytest.h5")
+        os.remove('ndarraytest.h5')
     except OSError:
         pass
-    with h5py.File("ndarraytest.h5") as f:
-        f.create_dataset("data", data=data)
-        f.create_dataset("label", data=label)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, True, last_batch_handle='pad')
-        batchidx = 0
-        for batch in dataiter:
-            batchidx += 1
-        assert(batchidx == 8)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, False, last_batch_handle='pad')
-        labelcount = [0 for i in range(10)]
-        for batch in dataiter:
-            label = batch.label[0].asnumpy().flatten()
-            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-            for i in range(label.shape[0]):
-                labelcount[int(label[i])] += 1
-
-    try:
-        os.remove("ndarraytest.h5")
-    except OSError:
-        pass
+    # Remove the scratch file in `finally` so a failed assertion below
+    # does not leave ndarraytest.h5 behind for the next run.
+    try:
+        # Request write mode explicitly: newer h5py versions open files
+        # read-only by default, which would fail on the just-removed path.
+        with h5py.File('ndarraytest.h5', 'w') as f:
+            f.create_dataset('data', data=data)
+            f.create_dataset('label', data=labels)
+
+            _test_last_batch_handle(f['data'], f['label'])
+            _test_last_batch_handle(f['data'], [])
+            _test_last_batch_handle(f['data'])
+    finally:
+        try:
+            os.remove('ndarraytest.h5')
+        except OSError:
+            pass
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
+
+def _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list, csr_iter_None, num_rows, batch_size):
 
 Review comment:
  Does this cover all the failure cases from last time? I just want to be extra careful with all existing use cases. We can contact those users if we need more test cases.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services