You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/25 21:11:06 UTC
[tvm] branch main updated: [Relay][Training] Add more gradients (#7323)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6f75cff [Relay][Training] Add more gradients (#7323)
6f75cff is described below
commit 6f75cffb64f20e72a2fad425ce58d0fd32c0d4c8
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Mon Jan 25 13:10:48 2021 -0800
[Relay][Training] Add more gradients (#7323)
* add more gradients
* add documentation
---
python/tvm/relay/op/_tensor_grad.py | 54 +++++++++++++++++++++++++++----
tests/python/relay/test_op_grad_level1.py | 8 +++++
tests/python/relay/test_op_grad_level3.py | 7 ++++
tests/python/relay/test_op_grad_level4.py | 37 ++++++++++++++++++++-
4 files changed, 99 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 9c84411..c9a20a3 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -357,16 +357,24 @@ def global_avg_pool2d_grad(orig, grad):
return [pool_grad]
-# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
+ """
+ Returns the gradient of concatenate, which is just the downstream gradient
+ split across the inputs.
+ """
assert len(orig.args) == 1
t = orig.args[0]
- x = TupleGetItem(t, 0)
- y = TupleGetItem(t, 1)
- # Assume only two element in tuple rn.
- # In the real implementation, concatenate_grad probably need to be implemented by an operator.
- return [Tuple([zeros_like(x), zeros_like(y)])]
+
+ # calculate split indices. TODO(@altanh): support Any?
+ axis_dims = [ty.shape[orig.attrs.axis] for ty in t.checked_type.fields]
+ splits, cumsum = [], 0
+ for dim in axis_dims[:-1]:
+ cumsum += dim
+ splits.append(cumsum)
+
+ grads = split(grad, tuple(splits), axis=orig.attrs.axis).tuple_value
+ return [grads]
@register_gradient("nn.conv2d")
@@ -808,5 +816,39 @@ def arange_grad(orig, grad):
@register_gradient("gather_nd")
def gather_nd_grad(orig, grad):
+ """
+ Returns the gradient of gather_nd, which is simply scatter_nd.
+ """
data, indices = orig.args
return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
+
+
+@register_gradient("reshape_like")
+def reshape_like_grad(orig, grad):
+ """
+ Returns the gradient of reshape_like.
+ """
+ data, shape_like = orig.args
+ return [reshape_like(grad, data), zeros_like(shape_like)]
+
+
+@register_gradient("where")
+def where_grad(orig, grad):
+ """
+ Returns the gradient of where.
+ """
+ cond, x, y = orig.args
+ g_zeros = zeros_like(grad)
+
+ grad_x = collapse_sum_like(where(cond, grad, g_zeros), x)
+ grad_y = collapse_sum_like(where(cond, g_zeros, grad), y)
+
+ return [zeros_like(cond), grad_x, grad_y]
+
+
+@register_gradient("less_equal")
+def less_equal_grad(orig, grad):
+ """
+ Returns the gradient of less_equal.
+ """
+ return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py
index cac07c4..a79be86 100644
--- a/tests/python/relay/test_op_grad_level1.py
+++ b/tests/python/relay/test_op_grad_level1.py
@@ -150,5 +150,13 @@ def test_expand_dims_grad():
check_grad(fwd_func)
+def test_concatenate_grad():
+ x = relay.var("x", shape=(2, 2, 5))
+ y = relay.var("y", shape=(2, 1, 5))
+ z = relay.var("z", shape=(2, 4, 5))
+ fwd_func = relay.Function([x, y, z], relay.concatenate([x, y, z], axis=1))
+ check_grad(fwd_func)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 98ff62e..0c89aa7 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -126,5 +126,12 @@ def test_gather_nd_grad():
check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[data_np])
+def test_reshape_like_grad():
+ data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+ shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+ fwd_func = relay.Function([data, shape_like], relay.reshape_like(data, shape_like))
+ check_grad(fwd_func)
+
+
if __name__ == "__main__":
pytest.main()
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index d479221..0f73e89 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import pytest
+import numpy as np
from tvm import relay
-from tvm.relay.testing import check_grad
+from tvm.relay.testing import check_grad, _np_randn_from_type
def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=False):
@@ -51,5 +52,39 @@ def test_max_grad():
verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)
+def test_where_grad():
+ cond_type = relay.TensorType((2, 3, 4), "int32")
+ lhs_type = relay.TensorType((1, 3, 4), "float32")
+ rhs_type = relay.TensorType((2, 1, 4), "float32")
+ inputs = [
+ np.random.randint(2, size=cond_type.concrete_shape, dtype=cond_type.dtype),
+ _np_randn_from_type(lhs_type, scale=1e-5),
+ _np_randn_from_type(rhs_type, scale=1e-5),
+ ]
+
+ cond = relay.var("cond", type_annotation=cond_type)
+ lhs = relay.var("lhs", type_annotation=lhs_type)
+ rhs = relay.var("rhs", type_annotation=rhs_type)
+ fwd_func = relay.Function([cond, lhs, rhs], relay.where(cond, lhs, rhs))
+ check_grad(fwd_func, inputs=inputs, test_inputs=inputs[1:])
+
+
+def test_less_equal_grad():
+ x_type = relay.TensorType((2, 3, 4), "float32")
+ y_type = relay.TensorType((3, 1), "float32")
+ # We need to generate inputs far apart to get correct numerical gradients
+ # (otherwise adding epsilon may change comparison result). The gradient
+ # should always be zero for both inputs.
+ inputs = [
+ np.random.choice([-1, 1], size=x_type.concrete_shape).astype(x_type.dtype),
+ np.random.choice([-2, 2], size=y_type.concrete_shape).astype(y_type.dtype),
+ ]
+
+ x = relay.var("x", type_annotation=x_type)
+ y = relay.var("y", type_annotation=y_type)
+ fwd_func = relay.Function([x, y], relay.less_equal(x, y))
+ check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)
+
+
if __name__ == "__main__":
pytest.main()