You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/06/24 12:47:30 UTC

[tvm] branch main updated: [Relay][Training] Additional gradients (#8307)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b9d2899  [Relay][Training] Additional gradients (#8307)
b9d2899 is described below

commit b9d2899ae8adeb88bd95d633e9d1d8193f9c9560
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Thu Jun 24 05:46:56 2021 -0700

    [Relay][Training] Additional gradients (#8307)
---
 python/tvm/relay/op/_tensor_grad.py        | 62 +++++++++++++++++++++++++++---
 tests/python/relay/test_op_grad_level10.py | 24 ++++++++++++
 tests/python/relay/test_op_grad_level3.py  |  7 ++++
 tests/python/relay/test_op_grad_level4.py  | 33 ++++++++++++++++
 4 files changed, 121 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index d5b8910..09b1435 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name, unused-argument
-"""Backend compiler related feature registration"""
+"""Gradient definitions for Relay operators"""
 from tvm.topi.nn.utils import get_pad_tuple
 from tvm.topi.utils import get_const_tuple
 from tvm.error import OpError
@@ -527,10 +527,7 @@ def softmax_grad(orig, grad):
 @register_gradient("nn.log_softmax")
 def log_softmax_grad(orig, grad):
     """Gradient of log_softmax"""
-    x = orig.args[0]
-    sm = _nn.softmax(x, axis=orig.attrs.axis)
-    grad = grad / sm
-    return softmax_grad(sm, grad)
+    return [grad - _sum(grad, axis=orig.attrs.axis, keepdims=True) * exp(orig)]
 
 
 @register_gradient("nn.bias_add")
@@ -596,6 +593,12 @@ def cast_grad(orig, grad):
     return [cast_like(grad, x)]
 
 
+@register_gradient("cast_like")
+def cast_like_grad(orig, grad):
+    x, like = orig.args
+    return [cast_like(grad, x), zeros_like(like)]
+
+
 @register_gradient("nn.batch_flatten")
 def batch_flatten_grad(orig, grad):
     """Returns grad reshaped to data dims"""
@@ -873,3 +876,52 @@ def less_equal_grad(orig, grad):
     Returns the gradient of less_equal.
     """
     return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
+
+
+@register_gradient("not_equal")
+def not_equal_grad(orig, grad):
+    """
+    Returns the gradient of not_equal (just zeros).
+    """
+    return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
+
+
+@register_gradient("strided_slice")
+def strided_slice_grad(orig, grad):
+    """
+    Returns the gradient of strided_slice, which is equal to grad where the
+    input was sliced and zero elsewhere.
+    """
+    assert orig.attrs.axes is None, "grad for strided_slice with axes is not yet supported"
+    x = orig.args[0]
+    begin = get_const_tuple(orig.attrs.begin)
+    end = get_const_tuple(orig.attrs.end)
+    strides = get_const_tuple(orig.attrs.strides)
+    if orig.attrs.slice_mode == "size":
+        # convert sizes to ending indices and ignore strides
+        end = list(end)
+        for i, (start, size) in enumerate(zip(begin, end)):
+            if size == -1:
+                end[i] = int(x.checked_type.shape[i])
+            else:
+                end[i] = start + size
+        strides = None
+    else:
+        assert orig.attrs.slice_mode == "end"
+    return [strided_set(zeros_like(x), grad, begin, end, strides)]
+
+
+@register_gradient("one_hot")
+def one_hot_grad(orig, grad):
+    """
+    Returns the gradient of one_hot, which is the sum of grad at on and off
+    indices for on_value and off_value respectively.
+    """
+    indices, on_value, off_value = orig.args
+
+    g_zeros = zeros_like(grad)
+    on_mask = equal(orig, on_value)
+    grad_on = _sum(where(on_mask, grad, g_zeros))
+    grad_off = _sum(where(on_mask, g_zeros, grad))
+
+    return [zeros_like(indices), cast_like(grad_on, on_value), cast_like(grad_off, off_value)]
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index 4a6ffb9..e2145f7 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import numpy as np
 
 from tvm import relay
 from tvm.relay.testing import check_grad
@@ -72,5 +73,28 @@ def test_reverse_reshape_grad():
     check_grad(relay.Function([x], relay.op.reverse_reshape(x, (-1, 0))))
 
 
+def test_one_hot_grad():
+    indices_shape = (3, 4)
+    depth = 5
+    axis = -1
+
+    for indices_dtype in ["int32", "int64"]:
+        for val_dtype in ["float32", "float64"]:
+            inputs = [
+                np.random.randint(depth, size=indices_shape, dtype=indices_dtype),
+                np.array(np.random.randn() * 1e-5).astype(val_dtype),
+                np.array(np.random.randn() * 1e-5).astype(val_dtype),
+            ]
+            test_inputs = inputs[1:]
+
+            indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype)
+            on_val = relay.var("on_val", shape=tuple(), dtype=val_dtype)
+            off_val = relay.var("off_val", shape=tuple(), dtype=val_dtype)
+            y = relay.one_hot(indices, on_val, off_val, depth, axis, val_dtype)
+            f = relay.Function([indices, on_val, off_val], y)
+
+            check_grad(f, inputs=inputs, test_inputs=test_inputs)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 821e10f..ae3fc26 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -69,6 +69,13 @@ def test_cast_grad():
     check_grad(fwd_func)
 
 
+def test_cast_like_grad():
+    data = relay.var("data", shape=(10, 4), dtype="float32")
+    like = relay.var("like", shape=(1,), dtype="float64")
+    fwd_func = relay.Function([data, like], relay.cast_like(data, like))
+    check_grad(fwd_func)
+
+
 def test_copy_grad():
     data = relay.var("data", relay.TensorType((10, 4), "float64"))
     fwd_func = relay.Function([data], relay.copy(data))
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index 0f73e89..17d30ca 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -86,5 +86,38 @@ def test_less_equal_grad():
     check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)
 
 
+def test_not_equal_grad():
+    x_type = relay.TensorType((2, 3, 4), "float32")
+    y_type = relay.TensorType((3, 1), "float32")
+    # We need to generate inputs far apart to get correct numerical gradients
+    # (otherwise adding epsilon may change comparison result). The gradient
+    # should always be zero for both inputs.
+    inputs = [
+        np.random.choice([-1, 1], size=x_type.concrete_shape).astype(x_type.dtype),
+        np.random.choice([-2, 2], size=y_type.concrete_shape).astype(y_type.dtype),
+    ]
+
+    x = relay.var("x", type_annotation=x_type)
+    y = relay.var("y", type_annotation=y_type)
+    fwd_func = relay.Function([x, y], relay.not_equal(x, y))
+    check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)
+
+
+def test_strided_slice_grad():
+    def check(sh, dtype, begin, end, strides, slice_mode):
+        x = relay.var("x", shape=sh, dtype=dtype)
+        f = relay.Function(
+            [x],
+            relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode),
+        )
+        check_grad(f)
+
+    check((2, 3, 4), "float32", (0, 1, 0), (-1, -1, 1), (1, 1, 1), "size")
+    check((2, 3, 4), "float32", (0, 1, 0), (2, 3, 1), (1, 1, 1), "end")
+    # check that strides are properly ignored when using "size" mode
+    check((2, 3, 4), "float32", (0, 0, 0), (-1, -1, -1), (1, 1, 2), "size")
+    check((2, 3, 4), "float32", (0, 0, 0), (2, 3, 4), (1, 1, 2), "end")
+
+
 if __name__ == "__main__":
     pytest.main()