You are viewing a plain-text version of this content. The canonical link to it is on the original archive page (the hyperlink was lost in this plain-text rendering).
Posted to commits@tvm.apache.org by ju...@apache.org on 2020/12/25 03:27:11 UTC

[tvm] branch main updated: [Relay] Add fast_softmax (#7163)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ffd740  [Relay] Add fast_softmax (#7163)
6ffd740 is described below

commit 6ffd740ade7fe3d06c108d6c0167f93666ce03ed
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Thu Dec 24 19:26:52 2020 -0800

    [Relay] Add fast_softmax (#7163)
    
    * [Relay] Add fast_softmax
    
    * fix
    
    * fix
---
 python/tvm/relay/op/nn/_nn.py              |  5 ++++
 python/tvm/relay/op/nn/nn.py               | 23 +++++++++++++++
 python/tvm/relay/op/strategy/generic.py    | 14 ++++++++++
 python/tvm/topi/nn/softmax.py              | 45 ++++++++++++++++++++++++++++--
 src/relay/op/nn/nn.cc                      | 27 ++++++++++++++++++
 src/relay/op/tensor/reduce.cc              |  3 +-
 tests/python/topi/python/test_topi_math.py | 10 ++++++-
 7 files changed, 122 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index ee1b9e2..c5af5d8 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -42,6 +42,11 @@ reg.register_strategy("nn.softmax", strategy.softmax_strategy)
 reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
 
 
+# fast softmax
+reg.register_strategy("nn.fast_softmax", strategy.fast_softmax_strategy)
+reg.register_pattern("nn.fast_softmax", OpPattern.OPAQUE)
+
+
 # log_softmax
 reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index a8a0835..fef82e7 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -698,6 +698,29 @@ def softmax(data, axis=-1):
     return _make.softmax(data, axis)
 
 
+def fast_softmax(data, axis=-1):
+    r"""Computes softmax.
+    Use approximation to compute exponent for faster speed.
+
+    .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
+    .. note::
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data: tvm.relay.Expr
+        The input data to the operator.
+    axis: int, optional
+        The axis to sum over when computing softmax
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.fast_softmax(data, axis)
+
+
 def log_softmax(data, axis=-1):
     r"""Computes log softmax.
 
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 95b5d6a..6864266 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -154,6 +154,20 @@ def softmax_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+@override_native_generic_func("fast_softmax_strategy")
+def fast_softmax_strategy(attrs, inputs, out_type, target):
+    """fast softmax generic strategy"""
+    # NOTE: This op does not have an optimized manual schedule,
+    # so it should only be used together with auto-scheduler.
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.fast_softmax),
+        naive_schedule,
+        name="fast_softmax.generic",
+    )
+    return strategy
+
+
 # log_softmax
 @generic_func
 def schedule_log_softmax(attrs, outs, target):
diff --git a/python/tvm/topi/nn/softmax.py b/python/tvm/topi/nn/softmax.py
index f6f20d7..6d2bb15 100644
--- a/python/tvm/topi/nn/softmax.py
+++ b/python/tvm/topi/nn/softmax.py
@@ -18,12 +18,12 @@
 """TVM operator for softmax and log_softmax compute."""
 from __future__ import absolute_import
 import tvm
-from tvm import te
+from tvm import te, topi
 
 
 @tvm.te.tag_scope(tag="softmax_output")
 def softmax(x, axis=-1):
-    """Perform softmax activation on the data
+    """Perform softmax activation on the data.
 
     Parameters
     ----------
@@ -38,6 +38,32 @@ def softmax(x, axis=-1):
     output : tvm.te.Tensor
         output shape is the same as input
     """
+    return softmax_common(x, axis, False)
+
+
+@tvm.te.tag_scope(tag="fast_softmax_output")
+def fast_softmax(x, axis=-1):
+    """Perform softmax activation on the data.
+    Use approximation to compute exponent for faster speed.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        can be any dimension
+
+    axis : int
+        channel axis
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        output shape is the same as input
+    """
+    return softmax_common(x, axis, True)
+
+
+def softmax_common(x, axis, use_fast_exp):
+    """The common part of softmax and fast_softmax"""
     shape = x.shape
     if axis < 0:
         axis = len(shape) + axis
@@ -57,6 +83,10 @@ def softmax(x, axis=-1):
         eval_range = insert_reduce_index(indices, k1)
         return tvm.te.max(x[eval_range], axis=k1)
 
+    def _compute_delta(max_elem, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return x[indices] - max_elem[non_reduce_indices]
+
     def _compute_exp(max_elem, *indices):
         non_reduce_indices = get_non_reduce_indices(indices)
         return te.exp(x[indices] - max_elem[non_reduce_indices])
@@ -71,7 +101,16 @@ def softmax(x, axis=-1):
 
     reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
     max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")
-    exp = te.compute(shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp")
+
+    if use_fast_exp:
+        delta = te.compute(
+            shape, lambda *indices: _compute_delta(max_elem, *indices), name="T_softmax_delta"
+        )
+        exp = topi.math.fast_exp(delta)
+    else:
+        exp = te.compute(
+            shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp"
+        )
     expsum = te.compute(
         reduced_shape, lambda *indices: _compute_expsum(exp, *indices), name="T_softmax_expsum"
     )
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index fbb6204..ce62242 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -310,6 +310,33 @@ RELAY_REGISTER_OP("nn.softmax")
     .set_support_level(1)
     .add_type_rel("Identity", IdentityRel);
 
+// relay.fast_softmax
+TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.fast_softmax").set_body_typed([](Expr data, int axis) {
+  auto attrs = make_object<SoftmaxAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("nn.fast_softmax");
+  return Call(op, {data}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("nn.fast_softmax")
+    .describe(R"code(Softmax layer.
+    Use approximation to compute exponent for faster speed.
+
+.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
+
+.. note::
+    This operator can be optimized away for inference.
+
+- **data**: The input data
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<SoftmaxAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(1)
+    .add_type_rel("Identity", IdentityRel);
+
 // relay.nn.log_softmax
 TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) {
   auto attrs = make_object<SoftmaxAttrs>();
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index afe4557..f611dc2 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -573,7 +573,8 @@ Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& i
     count -= 1;
   }
   std::vector<Integer> expand_shape;
-  auto sq_diff = topi::power(topi::subtract(data, mean), 2);
+  auto diff = topi::subtract(data, mean);
+  auto sq_diff = topi::multiply(diff, diff);
   if (param->exclude) {
     axes = GetExcludeAxes(sq_diff->shape.size(), param->axis);
     ICHECK_NE(axes.size(), 0);
diff --git a/tests/python/topi/python/test_topi_math.py b/tests/python/topi/python/test_topi_math.py
index 6e119e7..74575dd 100644
--- a/tests/python/topi/python/test_topi_math.py
+++ b/tests/python/topi/python/test_topi_math.py
@@ -199,7 +199,7 @@ def test_cast():
 
 def test_fastmath():
     def test_apply(func, name, f_numpy, low, high, step, dtype="float32"):
-        a_np = np.arange(low, high, step).astype(dtype)
+        a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
         b_np = f_numpy(a_np)
         A = te.placeholder(a_np.shape, dtype=dtype, name="A")
         B = func(A)
@@ -224,6 +224,14 @@ def test_fastmath():
     test_apply(topi.fast_exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
     test_apply(topi.fast_erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
     test_apply(topi.fast_tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
+    test_apply(
+        topi.nn.fast_softmax,
+        "fast_softmax",
+        tvm.topi.testing.softmax_python,
+        low=-10,
+        high=10,
+        step=0.01,
+    )
 
 
 if __name__ == "__main__":