You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/13 17:26:21 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12975: Implemented a regression unit test for #11793

sandeep-krishnamurthy closed pull request #12975: Implemented a regression unit test for #11793
URL: https://github.com/apache/incubator-mxnet/pull/12975
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it:

diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 39fcd81642d..7347723a39c 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -812,6 +812,65 @@ def test_forward_types():
     assert mod.predict(data1).shape == (1, 10)
 
 
+def test_reference_single_batch_during_fit():
+    """
+    Regression test for #11793. C++-based iterators are exposed to Python through a C API with no reference
+    counting, so they typically free the current batch as soon as next() is called. BaseModule.fit() must therefore
+    reference only one batch at a time; touching a batch after advancing the iterator reads freed memory and yields
+    either (a) garbage accuracy or (b) a segmentation fault. The mocks model this: MockBatch.label asserts that its
+    batch is still the most recent one produced by MockTrainData, so fit() fails here if it reads a stale batch.
+    """
+    state = {'current_batch_i': None}  # cell shared with the mock classes; Python 2 has no `nonlocal`
+
+    class MockBatch(object):
+        def __init__(self, i):
+            self.i = i
+
+        @property
+        def label(self):
+            # Only the batch most recently returned by next() may still be read.
+            assert self.i == state['current_batch_i']
+
+    class MockTrainData(object):
+        def __init__(self, batches):
+            self._i = 0
+            self._batches = batches
+            self.provide_data = None
+            self.provide_label = None
+            self.reset = lambda: None
+
+        def __iter__(self):
+            self._i = 0
+            return self
+
+        def __next__(self):
+            if self._i >= self._batches:
+                raise StopIteration
+
+            # Publish the index of the batch being handed out; every earlier batch now counts as freed.
+            state['current_batch_i'] = self._i
+            self._i += 1
+            return MockBatch(state['current_batch_i'])
+
+        def next(self):  # Python 2 iterator protocol
+            return self.__next__()
+
+    mod = mx.mod.BaseModule()
+    # Stub out the training steps with no-ops so fit() only exercises its batch iteration.
+    def empty_fn(*args, **kwargs):
+        pass
+    mod.bind = empty_fn
+    mod.init_params = empty_fn
+    mod.init_optimizer = empty_fn
+    mod.forward = empty_fn
+    mod.backward = empty_fn
+    mod.update = empty_fn
+    mod.update_metric = empty_fn
+    mod.get_params = lambda: (None, None)
+
+    train_data = MockTrainData(batches=2)
+    mod.fit(train_data, num_epoch=1)
+
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services