Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/25 21:11:06 UTC

[tvm] branch main updated: [Relay][Training] Add more gradients (#7323)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f75cff  [Relay][Training] Add more gradients (#7323)
6f75cff is described below

commit 6f75cffb64f20e72a2fad425ce58d0fd32c0d4c8
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Mon Jan 25 13:10:48 2021 -0800

    [Relay][Training] Add more gradients (#7323)
    
    * add more gradients
    
    * add documentation
---
 python/tvm/relay/op/_tensor_grad.py       | 54 +++++++++++++++++++++++++++----
 tests/python/relay/test_op_grad_level1.py |  8 +++++
 tests/python/relay/test_op_grad_level3.py |  7 ++++
 tests/python/relay/test_op_grad_level4.py | 37 ++++++++++++++++++++-
 4 files changed, 99 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 9c84411..c9a20a3 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -357,16 +357,24 @@ def global_avg_pool2d_grad(orig, grad):
     return [pool_grad]
 
 
-# not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
+    """
+    Returns the gradient of concatenate, which is just the downstream gradient
+    split across the inputs.
+    """
     assert len(orig.args) == 1
     t = orig.args[0]
-    x = TupleGetItem(t, 0)
-    y = TupleGetItem(t, 1)
-    # Assume only two element in tuple rn.
-    # In the real implementation, concatenate_grad probably need to be implemented by an operator.
-    return [Tuple([zeros_like(x), zeros_like(y)])]
+
+    # calculate split indices. TODO(@altanh): support Any?
+    axis_dims = [ty.shape[orig.attrs.axis] for ty in t.checked_type.fields]
+    splits, cumsum = [], 0
+    for dim in axis_dims[:-1]:
+        cumsum += dim
+        splits.append(cumsum)
+
+    grads = split(grad, tuple(splits), axis=orig.attrs.axis).tuple_value
+    return [grads]
 
 
 @register_gradient("nn.conv2d")
@@ -808,5 +816,39 @@ def arange_grad(orig, grad):
 
 @register_gradient("gather_nd")
 def gather_nd_grad(orig, grad):
+    """
+    Returns the gradient of gather_nd, which is simply scatter_nd.
+    """
     data, indices = orig.args
     return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
+
+
+@register_gradient("reshape_like")
+def reshape_like_grad(orig, grad):
+    """
+    Returns the gradient of reshape_like.
+    """
+    data, shape_like = orig.args
+    return [reshape_like(grad, data), zeros_like(shape_like)]
+
+
+@register_gradient("where")
+def where_grad(orig, grad):
+    """
+    Returns the gradient of where.
+    """
+    cond, x, y = orig.args
+    g_zeros = zeros_like(grad)
+
+    grad_x = collapse_sum_like(where(cond, grad, g_zeros), x)
+    grad_y = collapse_sum_like(where(cond, g_zeros, grad), y)
+
+    return [zeros_like(cond), grad_x, grad_y]
+
+
+@register_gradient("less_equal")
+def less_equal_grad(orig, grad):
+    """
+    Returns the gradient of less_equal.
+    """
+    return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
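
A minimal numpy sketch of the split-index arithmetic used by the new
concatenate_grad above, assuming inputs whose axis-1 extents are 2, 1 and 4
(the shapes exercised by the new level1 test below); this is an illustration,
not code from the commit:

    import numpy as np

    # Upstream gradient of concatenate([x, y, z], axis=1), where the inputs
    # have axis-1 extents [2, 1, 4] and the output therefore has extent 7.
    grad = np.random.randn(2, 7, 5).astype("float32")
    axis_dims = [2, 1, 4]

    # Running sums over all but the last extent give the split points,
    # exactly as computed in concatenate_grad: [2, 3].
    splits, cumsum = [], 0
    for dim in axis_dims[:-1]:
        cumsum += dim
        splits.append(cumsum)

    # Splitting the upstream gradient at those points recovers one gradient
    # slice per input, each with the same shape as that input.
    gx, gy, gz = np.split(grad, splits, axis=1)
    assert gx.shape == (2, 2, 5) and gy.shape == (2, 1, 5) and gz.shape == (2, 4, 5)
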
diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py
index cac07c4..a79be86 100644
--- a/tests/python/relay/test_op_grad_level1.py
+++ b/tests/python/relay/test_op_grad_level1.py
@@ -150,5 +150,13 @@ def test_expand_dims_grad():
     check_grad(fwd_func)
 
 
+def test_concatenate_grad():
+    x = relay.var("x", shape=(2, 2, 5))
+    y = relay.var("y", shape=(2, 1, 5))
+    z = relay.var("z", shape=(2, 4, 5))
+    fwd_func = relay.Function([x, y, z], relay.concatenate([x, y, z], axis=1))
+    check_grad(fwd_func)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 98ff62e..0c89aa7 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -126,5 +126,12 @@ def test_gather_nd_grad():
     check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[data_np])
 
 
+def test_reshape_like_grad():
+    data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+    fwd_func = relay.Function([data, shape_like], relay.reshape_like(data, shape_like))
+    check_grad(fwd_func)
+
+
 if __name__ == "__main__":
     pytest.main()
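
The reshape_like gradient added above is the identity up to a reshape: the
upstream gradient is reshaped back to the data's original shape, while
shape_like (which contributes only its shape, never its values) receives zeros.
A small numpy analogue using the shapes from the new test; illustrative only,
not part of the commit:

    import numpy as np

    data = np.random.randn(2, 3, 4).astype("float32")
    # Gradient w.r.t. the output of reshape_like(data, shape_like), shape (6, 2, 2).
    grad_out = np.ones((6, 2, 2), dtype="float32")

    # Analogue of reshape_like(grad, data): route the gradient back in data's shape.
    grad_data = grad_out.reshape(data.shape)
    assert grad_data.shape == data.shape
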
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index d479221..0f73e89 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -15,8 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import numpy as np
 from tvm import relay
-from tvm.relay.testing import check_grad
+from tvm.relay.testing import check_grad, _np_randn_from_type
 
 
 def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=False):
@@ -51,5 +52,39 @@ def test_max_grad():
     verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)
 
 
+def test_where_grad():
+    cond_type = relay.TensorType((2, 3, 4), "int32")
+    lhs_type = relay.TensorType((1, 3, 4), "float32")
+    rhs_type = relay.TensorType((2, 1, 4), "float32")
+    inputs = [
+        np.random.randint(2, size=cond_type.concrete_shape, dtype=cond_type.dtype),
+        _np_randn_from_type(lhs_type, scale=1e-5),
+        _np_randn_from_type(rhs_type, scale=1e-5),
+    ]
+
+    cond = relay.var("cond", type_annotation=cond_type)
+    lhs = relay.var("lhs", type_annotation=lhs_type)
+    rhs = relay.var("rhs", type_annotation=rhs_type)
+    fwd_func = relay.Function([cond, lhs, rhs], relay.where(cond, lhs, rhs))
+    check_grad(fwd_func, inputs=inputs, test_inputs=inputs[1:])
+
+
+def test_less_equal_grad():
+    x_type = relay.TensorType((2, 3, 4), "float32")
+    y_type = relay.TensorType((3, 1), "float32")
+    # We need to generate inputs far apart to get correct numerical gradients
+    # (otherwise adding epsilon may change comparison result). The gradient
+    # should always be zero for both inputs.
+    inputs = [
+        np.random.choice([-1, 1], size=x_type.concrete_shape).astype(x_type.dtype),
+        np.random.choice([-2, 2], size=y_type.concrete_shape).astype(y_type.dtype),
+    ]
+
+    x = relay.var("x", type_annotation=x_type)
+    y = relay.var("y", type_annotation=y_type)
+    fwd_func = relay.Function([x, y], relay.less_equal(x, y))
+    check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)
+
+
 if __name__ == "__main__":
     pytest.main()
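
The where_grad added in _tensor_grad.py masks the upstream gradient with the
condition and then uses collapse_sum_like to fold the result back to each
input's shape. With the broadcast shapes used in test_where_grad, that collapse
amounts to summing over the broadcast axes, as the numpy sketch below
illustrates (illustrative only, not part of the commit):

    import numpy as np

    cond = np.random.randint(2, size=(2, 3, 4)).astype(bool)
    grad = np.random.randn(2, 3, 4).astype("float32")

    # lhs has shape (1, 3, 4): the masked gradient is summed over broadcast axis 0.
    grad_lhs = np.where(cond, grad, 0.0).sum(axis=0, keepdims=True)
    # rhs has shape (2, 1, 4): the complementary mask is summed over broadcast axis 1.
    grad_rhs = np.where(cond, 0.0, grad).sum(axis=1, keepdims=True)

    assert grad_lhs.shape == (1, 3, 4) and grad_rhs.shape == (2, 1, 4)

less_equal, by contrast, has no useful derivative, so its registered gradient
simply returns zeros for both inputs; the test above only checks that those
zeros agree with the numerical gradient when the inputs are kept far enough
apart that the finite-difference epsilon cannot flip a comparison.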