You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/30 00:43:52 UTC

[tvm] branch main updated: support any shape and axis for log softmax (#11951)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 898946fec6 support any shape and axis for log softmax (#11951)
898946fec6 is described below

commit 898946fec60898b8fa753d6f0cdf8ebc86c9827a
Author: Altan Haan <31...@users.noreply.github.com>
AuthorDate: Wed Jun 29 17:43:48 2022 -0700

    support any shape and axis for log softmax (#11951)
---
 python/tvm/topi/nn/softmax.py                 | 42 +++++++++++----
 python/tvm/topi/testing/softmax_python.py     | 28 +++++-----
 python/tvm/topi/x86/nn.py                     |  2 +-
 tests/python/relay/test_op_level1.py          | 74 ++++++++++++++-------------
 tests/python/topi/python/test_topi_softmax.py |  2 +-
 5 files changed, 84 insertions(+), 64 deletions(-)

diff --git a/python/tvm/topi/nn/softmax.py b/python/tvm/topi/nn/softmax.py
index cb6d5b321e..2d6921b26d 100644
--- a/python/tvm/topi/nn/softmax.py
+++ b/python/tvm/topi/nn/softmax.py
@@ -136,16 +136,38 @@ def log_softmax(x, axis=-1):
     output : tvm.te.Tensor
         2-D output with same shape
     """
-    assert len(x.shape) == 2, "only support 2-dim log softmax"
-    # pylint: disable=R1714
-    assert axis == -1 or axis == len(x.shape) - 1, "only support last axis log softmax"
-    m, n = x.shape
-    k = te.reduce_axis((0, n), name="k")
-    max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
-    k = te.reduce_axis((0, n), name="k")
-    expsum = te.compute((m,), lambda i: te.sum(te.exp(x[i, k] - max_elem[i]), axis=k))
+    shape = x.shape
+    if axis < 0:
+        axis = len(shape) + axis
+    if axis >= len(shape):
+        ValueError("axis parameter should be less than input dim")
+
+    k1 = te.reduce_axis((0, shape[axis]), name="k")
+    k2 = te.reduce_axis((0, shape[axis]), name="k")
+
+    def insert_reduce_index(indices, reduce_index):
+        return indices[:axis] + (reduce_index,) + indices[axis:]
+
+    def get_non_reduce_indices(indices):
+        return tuple([var for (i, var) in enumerate(indices) if i != axis])
+
+    def _compute_max(*indices):
+        eval_range = insert_reduce_index(indices, k1)
+        return tvm.te.max(x[eval_range], axis=k1)
+
+    def _compute_expsum(max_elem, *indices):
+        eval_range = insert_reduce_index(indices, k2)
+        return te.sum(te.exp(x[eval_range] - max_elem[indices]), axis=k2)
+
+    def _normalize(max_elem, expsum, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return x[indices] - max_elem[non_reduce_indices] - te.log(expsum[non_reduce_indices])
+
+    reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
+    max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")
+    expsum = te.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices))
     return te.compute(
-        x.shape,
-        lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]),
+        shape,
+        lambda *indices: _normalize(max_elem, expsum, *indices),
         attrs={"axis": axis},
     )
diff --git a/python/tvm/topi/testing/softmax_python.py b/python/tvm/topi/testing/softmax_python.py
index da2893d1fa..6be5d48a67 100644
--- a/python/tvm/topi/testing/softmax_python.py
+++ b/python/tvm/topi/testing/softmax_python.py
@@ -19,43 +19,39 @@
 import numpy as np
 
 
-def softmax_python(a_np):
+def softmax_python(a_np, axis=1):
     """Softmax operator.
     Parameters
     ----------
     a_np : numpy.ndarray
-        2-D input data
+        N-D input data
 
     Returns
     -------
     output_np : numpy.ndarray
-        2-D output with same shape
+        N-D output with same shape
     """
-    assert len(a_np.shape) == 2, "only support 2-dim softmax"
-    max_elem = np.amax(a_np, axis=1)
-    max_elem = max_elem.reshape(max_elem.shape[0], 1)
+    max_elem = np.amax(a_np, axis=axis, keepdims=True)
     e = np.exp(a_np - max_elem)
-    expsum = np.sum(e, axis=1)
-    out_np = e / expsum[:, None]
+    expsum = np.sum(e, axis=axis, keepdims=True)
+    out_np = e / expsum
     return out_np
 
 
-def log_softmax_python(a_np):
+def log_softmax_python(a_np, axis=1):
     """Log_softmax operator.
     Parameters
     ----------
     a_np : numpy.ndarray
-        2-D input data
+        N-D input data
 
     Returns
     -------
     output_np : numpy.ndarray
-        2-D output with same shape
+        N-D output with same shape
     """
-    assert len(a_np.shape) == 2, "only support 2-dim log_softmax"
-    max_elem = np.amax(a_np, axis=1)
-    max_elem = max_elem.reshape(max_elem.shape[0], 1)
+    max_elem = np.amax(a_np, axis=axis, keepdims=True)
     e = np.exp(a_np - max_elem)
-    expsum = np.sum(e, axis=1)
-    out_np = a_np - max_elem - np.log(expsum[:, None])
+    expsum = np.sum(e, axis=axis, keepdims=True)
+    out_np = a_np - max_elem - np.log(expsum)
     return out_np
diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 9b6754c5e8..5475fc772e 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -39,7 +39,7 @@ def _schedule_softmax(softmax_op, s, outs):
         delta = None
         max_elem = softmax_op.input_tensors[1]
         expsum = softmax_op.input_tensors[2]
-        axis = 1
+        axis = int(softmax_op.attrs["axis"])
     else:
         raise ValueError(
             "Tag is expected to be softmax_output or log_softmax_output. \
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 44df40d3b0..4ce422ae88 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -249,46 +249,48 @@ def test_expand_dims_infer_type():
 
 @tvm.testing.uses_gpu
 def test_softmax():
-    for dtype in ["float16", "float32"]:
-        # Softmax accuracy for float16 is poor
-        if dtype == "float16":
-            return
-        shape = (10, 4)
-        x = relay.var("x", shape=shape, dtype=dtype)
-        y = relay.nn.softmax(x, axis=1)
-        assert "nn.softmax" in y.astext()
-        yy = run_infer_type(y)
-        assert yy.checked_type == relay.TensorType(shape, dtype)
-        func = relay.Function([x], y)
-        x_data = np.random.uniform(size=shape).astype(dtype)
-        ref_res = tvm.topi.testing.softmax_python(x_data)
-        for target, dev in tvm.testing.enabled_targets():
-            op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
-                x_data
-            )
-            np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
+    for shape in [(10, 4), (10, 5, 4)]:
+        for dtype in ["float16", "float32"]:
+            # Softmax accuracy for float16 is poor
+            if dtype == "float16":
+                continue
+            x = relay.var("x", shape=shape, dtype=dtype)
+            y = relay.nn.softmax(x, axis=1)
+            assert "nn.softmax" in y.astext()
+            yy = run_infer_type(y)
+            assert yy.checked_type == relay.TensorType(shape, dtype)
+            func = relay.Function([x], y)
+            x_data = np.random.uniform(size=shape).astype(dtype)
+            ref_res = tvm.topi.testing.softmax_python(x_data, axis=1)
+            for target, dev in tvm.testing.enabled_targets():
+                op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
+                    x_data
+                )
+                np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_log_softmax():
-    for dtype in ["float16", "float32"]:
-        # Softmax accuracy for float16 is poor
-        if dtype == "float16":
-            return
-        shape = (10, 4)
-        x = relay.var("x", shape=shape, dtype=dtype)
-        y = relay.nn.log_softmax(x, axis=1)
-        assert "nn.log_softmax" in y.astext()
-        yy = run_infer_type(y)
-        assert yy.checked_type == relay.TensorType(shape, dtype)
-        func = relay.Function([x], y)
-        x_data = np.random.uniform(size=shape).astype(dtype)
-        ref_res = tvm.topi.testing.log_softmax_python(x_data)
-        for target, dev in tvm.testing.enabled_targets():
-            op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
-                x_data
-            )
-            np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
+    for shape in [(10, 4), (10, 5, 4)]:
+        for dtype in ["float16", "float32"]:
+            # Softmax accuracy for float16 is poor
+            if dtype == "float16":
+                continue
+            x = relay.var("x", shape=shape, dtype=dtype)
+            y = relay.nn.log_softmax(x, axis=1)
+            assert "nn.log_softmax" in y.astext()
+            yy = run_infer_type(y)
+            assert yy.checked_type == relay.TensorType(shape, dtype)
+            func = relay.Function([x], y)
+            x_data = np.random.uniform(size=shape).astype(dtype)
+            ref_res = tvm.topi.testing.log_softmax_python(x_data, axis=1)
+            for target, dev in tvm.testing.enabled_targets():
+                if target == "nvptx":
+                    continue
+                op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
+                    x_data
+                )
+                np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
 
 @tvm.testing.uses_gpu
diff --git a/tests/python/topi/python/test_topi_softmax.py b/tests/python/topi/python/test_topi_softmax.py
index 8243211a86..8e5e039b14 100644
--- a/tests/python/topi/python/test_topi_softmax.py
+++ b/tests/python/topi/python/test_topi_softmax.py
@@ -50,7 +50,7 @@ configs = {
     "log_softmax": {
         "topi": topi.nn.log_softmax,
         "ref": tvm.topi.testing.log_softmax_python,
-        "dimensions": [2],
+        "dimensions": [2, 3],
         "axis": [1],
     },
 }