Posted to commits@tvm.apache.org by tq...@apache.org on 2021/06/24 12:47:30 UTC
[tvm] branch main updated: [Relay][Training] Additional gradients (#8307)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b9d2899 [Relay][Training] Additional gradients (#8307)
b9d2899 is described below
commit b9d2899ae8adeb88bd95d633e9d1d8193f9c9560
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Thu Jun 24 05:46:56 2021 -0700
[Relay][Training] Additional gradients (#8307)
---
python/tvm/relay/op/_tensor_grad.py | 62 +++++++++++++++++++++++++++---
tests/python/relay/test_op_grad_level10.py | 24 ++++++++++++
tests/python/relay/test_op_grad_level3.py | 7 ++++
tests/python/relay/test_op_grad_level4.py | 33 ++++++++++++++++
4 files changed, 121 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index d5b8910..09b1435 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
-"""Backend compiler related feature registration"""
+"""Gradient definitions for Relay operators"""
from tvm.topi.nn.utils import get_pad_tuple
from tvm.topi.utils import get_const_tuple
from tvm.error import OpError
@@ -527,10 +527,7 @@ def softmax_grad(orig, grad):
@register_gradient("nn.log_softmax")
def log_softmax_grad(orig, grad):
"""Gradient of log_softmax"""
- x = orig.args[0]
- sm = _nn.softmax(x, axis=orig.attrs.axis)
- grad = grad / sm
- return softmax_grad(sm, grad)
+ return [grad - _sum(grad, axis=orig.attrs.axis, keepdims=True) * exp(orig)]
@register_gradient("nn.bias_add")
@@ -596,6 +593,12 @@ def cast_grad(orig, grad):
return [cast_like(grad, x)]
+@register_gradient("cast_like")
+def cast_like_grad(orig, grad):
+ x, like = orig.args
+ return [cast_like(grad, x), zeros_like(like)]
+
+
@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
"""Returns grad reshaped to data dims"""
@@ -873,3 +876,52 @@ def less_equal_grad(orig, grad):
Returns the gradient of less_equal.
"""
return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
+
+
+@register_gradient("not_equal")
+def not_equal_grad(orig, grad):
+ """
+ Returns the gradient of not_equal (just zeros).
+ """
+ return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
+
+
+@register_gradient("strided_slice")
+def strided_slice_grad(orig, grad):
+ """
+ Returns the gradient of strided_slice, which is equal to grad where the
+ input was sliced and zero elsewhere.
+ """
+ assert orig.attrs.axes is None, "grad for strided_slice with axes is not yet supported"
+ x = orig.args[0]
+ begin = get_const_tuple(orig.attrs.begin)
+ end = get_const_tuple(orig.attrs.end)
+ strides = get_const_tuple(orig.attrs.strides)
+ if orig.attrs.slice_mode == "size":
+ # convert sizes to ending indices and ignore strides
+ end = list(end)
+ for i, (start, size) in enumerate(zip(begin, end)):
+ if size == -1:
+ end[i] = int(x.checked_type.shape[i])
+ else:
+ end[i] = start + size
+ strides = None
+ else:
+ assert orig.attrs.slice_mode == "end"
+ return [strided_set(zeros_like(x), grad, begin, end, strides)]
+
+
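
strided_slice_grad scatters the upstream gradient back into a zero tensor over the same begin/end/strides window the forward slice read, which is exactly what strided_set computes; in "size" mode the sizes are first folded into end indices (begin + size, or the full extent when size is -1) and strides are ignored, matching the forward semantics. A NumPy analogue of the "end"-mode case (illustrative only):

    import numpy as np

    def strided_slice_grad(x, g, begin, end, strides):
        # grad where the input was sliced, zero elsewhere
        dx = np.zeros_like(x)
        window = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
        dx[window] = g
        return dx

    x = np.random.randn(2, 3, 4).astype("float32")
    g = np.ones_like(x[0:2, 1:3, 0:1])
    dx = strided_slice_grad(x, g, (0, 1, 0), (2, 3, 1), (1, 1, 1))
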
+@register_gradient("one_hot")
+def one_hot_grad(orig, grad):
+ """
+ Returns the gradient of one_hot, which is the sum of grad at on and off
+ indices for on_value and off_value respectively.
+ """
+ indices, on_value, off_value = orig.args
+
+ g_zeros = zeros_like(grad)
+ on_mask = equal(orig, on_value)
+ grad_on = _sum(where(on_mask, grad, g_zeros))
+ grad_off = _sum(where(on_mask, g_zeros, grad))
+
+ return [zeros_like(indices), cast_like(grad_on, on_value), cast_like(grad_off, off_value)]
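
Because one_hot emits on_value at the hot positions and off_value everywhere else, the two scalar values receive the sum of the upstream gradient over their respective positions, and the integer indices get zeros. The registered rule finds the on-positions by comparing the forward output against on_value; an equivalent NumPy formulation that recomputes the mask from the indices (axis=-1 case, a sketch only):

    import numpy as np

    def one_hot_grad(indices, on_value, off_value, depth, g):
        # hot-position mask for the axis=-1 layout
        on_mask = np.eye(depth, dtype=bool)[indices]
        grad_on = g[on_mask].sum()
        grad_off = g[~on_mask].sum()
        # integer indices are not differentiable
        return np.zeros_like(indices), grad_on, grad_off
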
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index 4a6ffb9..e2145f7 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pytest
+import numpy as np
from tvm import relay
from tvm.relay.testing import check_grad
@@ -72,5 +73,28 @@ def test_reverse_reshape_grad():
check_grad(relay.Function([x], relay.op.reverse_reshape(x, (-1, 0))))
+def test_one_hot_grad():
+ indices_shape = (3, 4)
+ depth = 5
+ axis = -1
+
+ for indices_dtype in ["int32", "int64"]:
+ for val_dtype in ["float32", "float64"]:
+ inputs = [
+ np.random.randint(depth, size=indices_shape, dtype=indices_dtype),
+ np.array(np.random.randn() * 1e-5).astype(val_dtype),
+ np.array(np.random.randn() * 1e-5).astype(val_dtype),
+ ]
+ test_inputs = inputs[1:]
+
+ indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype)
+ on_val = relay.var("on_val", shape=tuple(), dtype=val_dtype)
+ off_val = relay.var("off_val", shape=tuple(), dtype=val_dtype)
+ y = relay.one_hot(indices, on_val, off_val, depth, axis, val_dtype)
+ f = relay.Function([indices, on_val, off_val], y)
+
+ check_grad(f, inputs=inputs, test_inputs=test_inputs)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 821e10f..ae3fc26 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -69,6 +69,13 @@ def test_cast_grad():
check_grad(fwd_func)
+def test_cast_like_grad():
+ data = relay.var("data", shape=(10, 4), dtype="float32")
+ like = relay.var("like", shape=(1,), dtype="float64")
+ fwd_func = relay.Function([data, like], relay.cast_like(data, like))
+ check_grad(fwd_func)
+
+
def test_copy_grad():
data = relay.var("data", relay.TensorType((10, 4), "float64"))
fwd_func = relay.Function([data], relay.copy(data))
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index 0f73e89..17d30ca 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -86,5 +86,38 @@ def test_less_equal_grad():
check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)
+def test_not_equal_grad():
+ x_type = relay.TensorType((2, 3, 4), "float32")
+ y_type = relay.TensorType((3, 1), "float32")
+ # We need to generate inputs far apart to get correct numerical gradients
+ # (otherwise adding epsilon may change comparison result). The gradient
+ # should always be zero for both inputs.
+ inputs = [
+ np.random.choice([-1, 1], size=x_type.concrete_shape).astype(x_type.dtype),
+ np.random.choice([-2, 2], size=y_type.concrete_shape).astype(y_type.dtype),
+ ]
+
+ x = relay.var("x", type_annotation=x_type)
+ y = relay.var("y", type_annotation=y_type)
+ fwd_func = relay.Function([x, y], relay.not_equal(x, y))
+ check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)
+
+
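
The comment in test_not_equal_grad is worth spelling out: comparison ops are piecewise constant, so their true gradient is zero wherever it exists, but a finite-difference probe whose step lands on the discontinuity reports a huge spurious value. A toy illustration (not part of the commit):

    eps = 1e-6
    x = 1.0
    y = x + eps  # y coincides with one probe point of the central difference
    f = lambda a, b: float(a != b)
    # true gradient is 0, but the probe straddles the discontinuity:
    print((f(x + eps, y) - f(x - eps, y)) / (2 * eps))  # -500000.0
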
+def test_strided_slice_grad():
+ def check(sh, dtype, begin, end, strides, slice_mode):
+ x = relay.var("x", shape=sh, dtype=dtype)
+ f = relay.Function(
+ [x],
+ relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode),
+ )
+ check_grad(f)
+
+ check((2, 3, 4), "float32", (0, 1, 0), (-1, -1, 1), (1, 1, 1), "size")
+ check((2, 3, 4), "float32", (0, 1, 0), (2, 3, 1), (1, 1, 1), "end")
+ # check that strides are properly ignored when using "size" mode
+ check((2, 3, 4), "float32", (0, 0, 0), (-1, -1, -1), (1, 1, 2), "size")
+ check((2, 3, 4), "float32", (0, 0, 0), (2, 3, 4), (1, 1, 2), "end")
+
+
if __name__ == "__main__":
pytest.main()