Posted to commits@tvm.apache.org by an...@apache.org on 2022/05/17 21:08:14 UTC

[tvm] branch main updated: logsoftmax reusing the softmax function (#11141)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e2f869eea logsoftmax reusing the softmax function (#11141)
0e2f869eea is described below

commit 0e2f869eeadbb349f849ed2add86a622e97053cd
Author: czh978 <41...@users.noreply.github.com>
AuthorDate: Wed May 18 05:08:08 2022 +0800

    logsoftmax reusing the softmax function (#11141)
    
    Co-authored-by: caizihua <97...@qq.com>
---
 python/tvm/relay/frontend/onnx.py | 25 +++++++------------------
 1 file changed, 7 insertions(+), 18 deletions(-)
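
For context, the diff below replaces LogSoftmax's hand-rolled
computation (x - m - log(sum(exp(x - m)))) with a call into the
existing Softmax converter followed by _op.log, relying on the
identity log_softmax(x) == log(softmax(x)). A minimal numpy sketch of
that identity (illustrative only, not TVM code; the helper names here
are made up for this note):

    import numpy as np

    def log_softmax_direct(x, axes):
        # Mirrors the removed run_calculation: x - m - log(sum(exp(x - m))).
        axes = tuple(axes)
        m = np.max(x, axis=axes, keepdims=True)
        e = np.exp(x - m)
        return x - m - np.log(np.sum(e, axis=axes, keepdims=True))

    def log_softmax_via_softmax(x, axes):
        # The new formulation: a numerically stable softmax, then its log.
        axes = tuple(axes)
        m = np.max(x, axis=axes, keepdims=True)
        e = np.exp(x - m)
        return np.log(e / np.sum(e, axis=axes, keepdims=True))

    x = np.random.randn(2, 5)
    assert np.allclose(log_softmax_direct(x, [-1]),
                       log_softmax_via_softmax(x, [-1]))

One caveat: log(softmax(x)) can underflow for strongly negative logits
(softmax rounding to zero before the log is taken), whereas the fused
form stays in log space throughout; presumably the deduplication was
judged worth that small loss of numerical headroom.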

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 81f12c2d81..e68daca4c4 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2412,30 +2412,18 @@ class LogSoftmax(OnnxOpConverter):
     """Operator converter for Softmax."""
 
     @classmethod
-    def run_calculation(cls, x, axes):
+    def run_calculation(cls, inputs, attr, params, opset):
         """Run the calculation for Log Softmax calculation."""
-        m = _op.max(x, axes, keepdims=True)
-        e = _op.exp(x - m)
-        s = _op.sum(e, axes, keepdims=True)
-        return x - m - _op.log(s)
+        res = Softmax.get_converter(opset)(inputs, attr, params)
+        return _op.log(res)
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        axis = attr.get("axis", 1)
-        ndim = len(infer_shape(inputs[0]))
-        if axis < 0:
-            axis += ndim
-        axes = list(range(axis, ndim))
-        return cls.run_calculation(inputs[0], axes)
+        return cls.run_calculation(inputs, attr, params, opset=1)
 
     @classmethod
     def _impl_v13(cls, inputs, attr, params):
-        axis = attr.get("axis", -1)
-        ndim = len(infer_shape(inputs[0]))
-        if axis < 0:
-            axis += ndim
-        axes = [axis]
-        return cls.run_calculation(inputs[0], axes)
+        return cls.run_calculation(inputs, attr, params, opset=13)
 
 
 class Hardmax(OnnxOpConverter):
@@ -4852,7 +4840,8 @@ class SoftmaxCrossEntropyLoss(OnnxOpConverter):
             weight_tensor = None
 
         get_log_prob = attr["tvm_custom"]["num_outputs"] == 2
-        log_softmax_tensor = LogSoftmax.run_calculation(input_tensor, axes=[1])
+        log_softmax_attr = {"axis": 1}
+        log_softmax_tensor = LogSoftmax.get_converter(13)([input_tensor], log_softmax_attr, None)
 
         loss, weight_total = NegativeLogLikelihoodLoss.run_calculation(
             log_softmax_tensor,
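
The second hunk updates SoftmaxCrossEntropyLoss, which previously
called LogSoftmax.run_calculation directly, to go through the opset-13
LogSoftmax converter with an explicit {"axis": 1} attribute. The
difference between the removed _impl_v1 and _impl_v13 bodies is which
axes they reduced over: opset 1 treats the input as flattened and
normalizes over all axes from `axis` to the end (hence the removed
axes = list(range(axis, ndim))), while opset 13 normalizes over the
single named axis (axes = [axis]). A numpy sketch of that difference
(illustrative only, not TVM code):

    import numpy as np

    def log_softmax_over_axes(x, axes):
        # Reduce over an explicit set of axes, as the removed helper did.
        axes = tuple(axes)
        m = np.max(x, axis=axes, keepdims=True)
        e = np.exp(x - m)
        return x - m - np.log(np.sum(e, axis=axes, keepdims=True))

    x = np.random.randn(2, 3, 4)
    v1_style = log_softmax_over_axes(x, range(1, x.ndim))  # opset 1: axes 1..ndim-1
    v13_style = log_softmax_over_axes(x, [1])              # opset 13: single axis

Since SoftmaxCrossEntropyLoss normalizes over the class dimension
(axis 1) alone, the opset-13 converter with axis=1 reproduces the old
run_calculation(input_tensor, axes=[1]) call.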