You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by gi...@git.apache.org on 2017/08/23 17:39:54 UTC

[GitHub] reminisce commented on a change in pull request #7572: Fix import error of broadcast max, min, mod in ndarray.py

reminisce commented on a change in pull request #7572: Fix import error of broadcast max, min, mod in ndarray.py
URL: https://github.com/apache/incubator-mxnet/pull/7572#discussion_r134819794
 
 

 ##########
 File path: tests/python/unittest/test_operator.py
 ##########
 @@ -1177,51 +1221,65 @@ def test_bneq(a, b):
     test_bpow(a, b)
     test_bneq(a, b)
 
+
 def test_broadcast_binary_op():
     a = mx.sym.Variable('a')
     b = mx.sym.Variable('b')
 
     def test_bplus(a, b):
         c = mx.sym.broadcast_plus(a, b)
-        check_binary_op_forward(c, lambda a, b: a + b, gen_broadcast_data)
+        check_binary_op_forward(c, lambda a, b: a + b, gen_broadcast_data, mx_nd_func=mx.nd.add)
         check_binary_op_backward(c, lambda g_out, a, b: (g_out, g_out), gen_broadcast_data)
 
     def test_bminus(a, b):
         c = mx.sym.broadcast_minus(a, b)
-        check_binary_op_forward(c, lambda a, b: a - b, gen_broadcast_data)
+        check_binary_op_forward(c, lambda a, b: a - b, gen_broadcast_data, mx_nd_func=mx.nd.subtract)
         check_binary_op_backward(c, lambda g_out, a, b: (g_out, - g_out), gen_broadcast_data)
 
     def test_bmul(a, b):
         c = mx.sym.broadcast_mul(a, b)
-        check_binary_op_forward(c, lambda a, b: a * b, gen_broadcast_data)
+        check_binary_op_forward(c, lambda a, b: a * b, gen_broadcast_data, mx_nd_func=mx.nd.multiply)
         check_binary_op_backward(c, lambda g_out, a, b: (g_out * b, g_out * a), gen_broadcast_data)
 
     def test_bdiv(a, b):
         c = mx.sym.broadcast_div(a, b)
-        check_binary_op_forward(c, lambda a, b: a / b, gen_broadcast_data)
+        check_binary_op_forward(c, lambda a, b: a / b, gen_broadcast_data, mx_nd_func=mx.nd.divide)
         check_binary_op_backward(c, lambda g_out, a, b: (g_out / b, - g_out * a / (b * b)), gen_broadcast_data)
 
     def test_bmod(a, b):
         c = mx.sym.broadcast_mod(a, b)
-        check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data, atol=1)
+        check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data, atol=1, mx_nd_func=mx.nd.modulo)
         check_binary_op_backward(c, lambda g_out, a, b: (g_out, - g_out * (a // b)), gen_broadcast_data, atol=1)
 
     def test_bmod_int(a, b):
         c = mx.sym.broadcast_mod(mx.sym.cast(a, dtype='int32'), mx.sym.cast(b, dtype='int32'))
-        check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data_int)
+        check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data_int, mx_nd_func=mx.nd.modulo)
         check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_broadcast_data_int)
 
     def test_bpow(a, b):
         c = mx.sym.broadcast_power(a, b)
-        check_binary_op_forward(c, lambda a, b: a ** b, gen_broadcast_data)
+        check_binary_op_forward(c, lambda a, b: a ** b, gen_broadcast_data, mx_nd_func=mx.nd.power)
         check_binary_op_backward(c, lambda g_out, a, b: (g_out * a **(b - 1) * b,
                                         g_out * a ** b * np.log(a)), gen_broadcast_data)
 
     def test_bequal(a, b):
         c = mx.sym.broadcast_equal(a, b)
-        check_binary_op_forward(c, lambda a, b: (a == b).astype(a.dtype), gen_broadcast_data_int)
+        check_binary_op_forward(c, lambda a, b: (a == b).astype(a.dtype), gen_broadcast_data_int,
+                                mx_nd_func=mx.nd.equal)
         check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_broadcast_data_int)
 
+    def test_bmax(a, b):
+        c = mx.sym.broadcast_maximum(a, b)
+        check_binary_op_forward(c, lambda x, y: np.maximum(x, y), gen_broadcast_data, mx_nd_func=mx.nd.maximum)
+        # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
+        check_numeric_gradient(c, gen_broadcast_data(idx=200), rtol=1e-2, atol=1e-3)
 
 Review comment:
   It's too complicated to implement the backward functions of broadcast_maximum and broadcast_minimum in Python. Therefore, I used check_numeric_gradient to verify the backward passes of these two ops.
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services