You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/10/29 17:48:37 UTC

[GitHub] [incubator-mxnet] szha commented on a change in pull request #20692: [API STD][SEARCH FUNC] Add keepdims=False to argmax/argmin

szha commented on a change in pull request #20692:
URL: https://github.com/apache/incubator-mxnet/pull/20692#discussion_r739428844



##########
File path: tests/python/unittest/test_numpy_op.py
##########
@@ -4437,73 +4437,84 @@ def GetDimSize(shp, axis):
 
 
 @use_np
-def test_np_argmin_argmax():
-    workloads = [
-        ((), 0, False),
-        ((), -1, False),
-        ((), 1, True),
-        ((5, 3), None, False),
-        ((5, 3), -1, False),
-        ((5, 3), 1, False),
-        ((5, 3), 3, True),
-        ((5, 0, 3), 0, False),
-        ((5, 0, 3), -1, False),
-        ((5, 0, 3), None, True),
-        ((5, 0, 3), 1, True),
-        ((3, 5, 7), None, False),
-        ((3, 5, 7), 0, False),
-        ((3, 5, 7), 1, False),
-        ((3, 5, 7), 2, False),
-        ((3, 5, 7, 9, 11), -3, False),
-    ]
-    dtypes = ['float16', 'float32', 'float64', 'bool', 'int32']
-    ops = ['argmin', 'argmax']
-
+@pytest.mark.parametrize('shape,axis,throw_exception', [
+    ((), 0, False),
+    ((), -1, False),
+    ((), 1, True),
+    ((5, 3), None, False),
+    ((5, 3), -1, False),
+    ((5, 3), 1, False),
+    ((5, 3), 3, True),
+    ((5, 0, 3), 0, False),
+    ((5, 0, 3), -1, False),
+    ((5, 0, 3), None, True),
+    ((5, 0, 3), 1, True),
+    ((3, 5, 7), None, False),
+    ((3, 5, 7), 0, False),
+    ((3, 5, 7), 1, False),
+    ((3, 5, 7), 2, False),
+    ((3, 5, 7, 9, 11), -3, False),
+])
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'bool', 'int32'])
+@pytest.mark.parametrize('op_name', ['argmin', 'argmax'])
+@pytest.mark.parametrize('keepdims', [True, False])
+def test_np_argmin_argmax(shape, axis, throw_exception, dtype, op_name, keepdims):
     class TestArgExtreme(HybridBlock):
-        def __init__(self, op_name, axis=None):
+        def __init__(self, op_name, axis=None, keepdims=False):
             super(TestArgExtreme, self).__init__()
             self._op_name = op_name
             self._axis = axis
+            self.keepdims = keepdims
 
         def forward(self, x):
-            return getattr(x, self._op_name)(self._axis)
-
-    for op_name in ops:
-        for shape, axis, throw_exception in workloads:
-            for dtype in dtypes:
-                a = np.random.uniform(low=0, high=100, size=shape).astype(dtype)
-                if throw_exception:
-                    # Cannot use assert_exception because sometimes the main thread
-                    # proceeds to `assert False` before the exception is thrown
-                    # in the worker thread. Have to use mx.nd.waitall() here
-                    # to block the main thread.
-                    try:
-                        getattr(np, op_name)(a, axis)
-                        mx.nd.waitall()
-                        assert False
-                    except mx.MXNetError:
-                        pass
-                else:
-                    mx_ret = getattr(np, op_name)(a, axis=axis)
-                    np_ret = getattr(onp, op_name)(a.asnumpy(), axis=axis)
-                    assert mx_ret.dtype == np_ret.dtype
-                    assert same(mx_ret.asnumpy(), np_ret)
+            return getattr(x, self._op_name)(self._axis, keepdims=self.keepdims)
+
+    a = np.random.uniform(low=0, high=100, size=shape).astype(dtype)
+    if throw_exception:
+        # Cannot use assert_exception because sometimes the main thread
+        # proceeds to `assert False` before the exception is thrown
+        # in the worker thread. Have to use mx.nd.waitall() here
+        # to block the main thread.
+        try:
+            getattr(np, op_name)(a, axis)
+            mx.nd.waitall()
+            assert False
+        except mx.MXNetError:
+            pass

Review comment:
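       This can be simplified with `pytest.raises` (see https://docs.pytest.org/en/6.2.x/assert.html#assertions-about-expected-exceptions), keeping `mx.nd.waitall()` inside the block so the asynchronously raised error is still caught, e.g. `with pytest.raises(mx.MXNetError): getattr(np, op_name)(a, axis); mx.nd.waitall()`.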




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org