You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2020/12/25 03:27:11 UTC
[tvm] branch main updated: [Relay] Add fast_softmax (#7163)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6ffd740 [Relay] Add fast_softmax (#7163)
6ffd740 is described below
commit 6ffd740ade7fe3d06c108d6c0167f93666ce03ed
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Thu Dec 24 19:26:52 2020 -0800
[Relay] Add fast_softmax (#7163)
* [Relay] Add fast_softmax
* fix
* fix
---
python/tvm/relay/op/nn/_nn.py | 5 ++++
python/tvm/relay/op/nn/nn.py | 23 +++++++++++++++
python/tvm/relay/op/strategy/generic.py | 14 ++++++++++
python/tvm/topi/nn/softmax.py | 45 ++++++++++++++++++++++++++++--
src/relay/op/nn/nn.cc | 27 ++++++++++++++++++
src/relay/op/tensor/reduce.cc | 3 +-
tests/python/topi/python/test_topi_math.py | 10 ++++++-
7 files changed, 122 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index ee1b9e2..c5af5d8 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -42,6 +42,11 @@ reg.register_strategy("nn.softmax", strategy.softmax_strategy)
reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
+# fast softmax
+reg.register_strategy("nn.fast_softmax", strategy.fast_softmax_strategy)
+reg.register_pattern("nn.fast_softmax", OpPattern.OPAQUE)
+
+
# log_softmax
reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index a8a0835..fef82e7 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -698,6 +698,29 @@ def softmax(data, axis=-1):
return _make.softmax(data, axis)
+def fast_softmax(data, axis=-1):
+ r"""Computes softmax.
+ Uses an approximation of the exponential function for faster speed.
+
+ .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
+ .. note::
+ This operator can be optimized away for inference.
+
+ Parameters
+ ----------
+ data: tvm.relay.Expr
+ The input data to the operator.
+ axis: int, optional
+ The axis to sum over when computing softmax
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+ return _make.fast_softmax(data, axis)
+
+
def log_softmax(data, axis=-1):
r"""Computes log softmax.
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 95b5d6a..6864266 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -154,6 +154,20 @@ def softmax_strategy(attrs, inputs, out_type, target):
return strategy
+@override_native_generic_func("fast_softmax_strategy")
+def fast_softmax_strategy(attrs, inputs, out_type, target):
+ """fast softmax generic strategy"""
+ # NOTE: This op does not have an optimized manual schedule,
+ # so it should only be used together with auto-scheduler.
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_softmax(topi.nn.fast_softmax),
+ naive_schedule,
+ name="fast_softmax.generic",
+ )
+ return strategy
+
+
# log_softmax
@generic_func
def schedule_log_softmax(attrs, outs, target):
diff --git a/python/tvm/topi/nn/softmax.py b/python/tvm/topi/nn/softmax.py
index f6f20d7..6d2bb15 100644
--- a/python/tvm/topi/nn/softmax.py
+++ b/python/tvm/topi/nn/softmax.py
@@ -18,12 +18,12 @@
"""TVM operator for softmax and log_softmax compute."""
from __future__ import absolute_import
import tvm
-from tvm import te
+from tvm import te, topi
@tvm.te.tag_scope(tag="softmax_output")
def softmax(x, axis=-1):
- """Perform softmax activation on the data
+ """Perform softmax activation on the data.
Parameters
----------
@@ -38,6 +38,32 @@ def softmax(x, axis=-1):
output : tvm.te.Tensor
output shape is the same as input
"""
+ return softmax_common(x, axis, False)
+
+
+@tvm.te.tag_scope(tag="fast_softmax_output")
+def fast_softmax(x, axis=-1):
+ """Perform softmax activation on the data.
+ Uses an approximation of the exponential function for faster speed.
+
+ Parameters
+ ----------
+ x : tvm.te.Tensor
+ can be any dimension
+
+ axis : int
+ the axis along which softmax is computed; a negative value counts from the last dimension
+
+ Returns
+ -------
+ output : tvm.te.Tensor
+ output shape is the same as input
+ """
+ return softmax_common(x, axis, True)
+
+
+def softmax_common(x, axis, use_fast_exp):
+ """The common part of softmax and fast_softmax"""
shape = x.shape
if axis < 0:
axis = len(shape) + axis
@@ -57,6 +83,10 @@ def softmax(x, axis=-1):
eval_range = insert_reduce_index(indices, k1)
return tvm.te.max(x[eval_range], axis=k1)
+ def _compute_delta(max_elem, *indices):
+ non_reduce_indices = get_non_reduce_indices(indices)
+ return x[indices] - max_elem[non_reduce_indices]
+
def _compute_exp(max_elem, *indices):
non_reduce_indices = get_non_reduce_indices(indices)
return te.exp(x[indices] - max_elem[non_reduce_indices])
@@ -71,7 +101,16 @@ def softmax(x, axis=-1):
reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")
- exp = te.compute(shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp")
+
+ if use_fast_exp:
+ delta = te.compute(
+ shape, lambda *indices: _compute_delta(max_elem, *indices), name="T_softmax_delta"
+ )
+ exp = topi.math.fast_exp(delta)
+ else:
+ exp = te.compute(
+ shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp"
+ )
expsum = te.compute(
reduced_shape, lambda *indices: _compute_expsum(exp, *indices), name="T_softmax_expsum"
)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index fbb6204..ce62242 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -310,6 +310,33 @@ RELAY_REGISTER_OP("nn.softmax")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
+// relay.nn.fast_softmax
+TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.fast_softmax").set_body_typed([](Expr data, int axis) {
+ auto attrs = make_object<SoftmaxAttrs>();
+ attrs->axis = axis;
+ static const Op& op = Op::Get("nn.fast_softmax");
+ return Call(op, {data}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("nn.fast_softmax")
+ .describe(R"code(Softmax layer.
+ Uses an approximation of the exponential function for faster speed.
+
+.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
+
+.. note::
+ This operator can be optimized away for inference.
+
+- **data**: The input data
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<SoftmaxAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(1)
+ .add_type_rel("Identity", IdentityRel);
+
// relay.nn.log_softmax
TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>();
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index afe4557..f611dc2 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -573,7 +573,8 @@ Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& i
count -= 1;
}
std::vector<Integer> expand_shape;
- auto sq_diff = topi::power(topi::subtract(data, mean), 2);
+ auto diff = topi::subtract(data, mean);
+ auto sq_diff = topi::multiply(diff, diff);
if (param->exclude) {
axes = GetExcludeAxes(sq_diff->shape.size(), param->axis);
ICHECK_NE(axes.size(), 0);
diff --git a/tests/python/topi/python/test_topi_math.py b/tests/python/topi/python/test_topi_math.py
index 6e119e7..74575dd 100644
--- a/tests/python/topi/python/test_topi_math.py
+++ b/tests/python/topi/python/test_topi_math.py
@@ -199,7 +199,7 @@ def test_cast():
def test_fastmath():
def test_apply(func, name, f_numpy, low, high, step, dtype="float32"):
- a_np = np.arange(low, high, step).astype(dtype)
+ a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
b_np = f_numpy(a_np)
A = te.placeholder(a_np.shape, dtype=dtype, name="A")
B = func(A)
@@ -224,6 +224,14 @@ def test_fastmath():
test_apply(topi.fast_exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
test_apply(topi.fast_erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
test_apply(topi.fast_tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
+ test_apply(
+ topi.nn.fast_softmax,
+ "fast_softmax",
+ tvm.topi.testing.softmax_python,
+ low=-10,
+ high=10,
+ step=0.01,
+ )
if __name__ == "__main__":