Posted to commits@tvm.apache.org by ma...@apache.org on 2020/09/22 02:59:34 UTC

[incubator-tvm] branch master updated: [Torch] Clean up usage of try ... infer_value() ... except (#6504)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 0448858  [Torch] Clean up usage of try ... infer_value() ... except (#6504)
0448858 is described below

commit 044885860842e6c9936b45ab4edbd23f0e3c727b
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Sep 22 11:59:00 2020 +0900

    [Torch] Clean up usage of try ... infer_value() ... except (#6504)
    
    * clean up infer value usage
    
    * try silence pylint
    
    * remove unused variable
    
    * make on_failure optional
    
    * make on_success optional True
    
    Co-authored-by: masa <ma...@pop-os.localdomain>
---
 python/tvm/relay/frontend/common.py  | 17 ++++++++++
 python/tvm/relay/frontend/pytorch.py | 62 ++++++++++++++++--------------------
 2 files changed, 44 insertions(+), 35 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index e4d605a..027d6bd 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -563,6 +563,23 @@ def infer_value_simulated(input_val, params):
     return output_value
 
 
+def try_infer_value(val, on_success=None, on_failure=None):
+    """Try running infer_value on the input val, and if successful, return the inferred value or
+    pass it to on_success callback if provided. Otherwise, run on_failure callback if it is
+    provided, or return the input val as output. In each case, the second return value
+    indicates whether infer_value has succeeded or not.
+    """
+    try:
+        ret = infer_value(val, {}).asnumpy()
+        if on_success:
+            return on_success(ret), True
+        return ret, True
+    except Exception:
+        if on_failure:
+            return on_failure(), False
+        return val, False
+
+
 def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"):
     return _expr.var(name_hint, type_annotation, shape, dtype)
 
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9ceb9fc..c667b04 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
 # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
-# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
+# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
 """PT: PyTorch frontend."""
 import itertools
 import logging
@@ -36,6 +36,7 @@ from .. import transform
 from .common import AttrCvt, get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
+from .common import try_infer_value
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import infer_type as _infer_type
 from ..prelude import Prelude, StaticTensorArrayOps
@@ -185,11 +186,8 @@ def _arange():
         def _get_value(val, dtype):
             # dtype is a tvm dtype
             if isinstance(val, _expr.Expr):
-                try:
-                    ret = _infer_value(_op.cast(val, dtype), {}).asnumpy()
-                    ret = _expr.const(ret, dtype)
-                except Exception:
-                    ret = _op.cast(val, dtype)
+                inp = _op.cast(val, dtype)
+                ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype))
             else:
                 ret = _create_typed_const(val, dtype)
             return ret
@@ -305,10 +303,7 @@ def _slice():
         dim = int(inputs[1])
         stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            try:
-                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
-            except Exception:
-                begin[dim] = inputs[2]
+            begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
         else:
             begin[dim] = int(inputs[2])
 
@@ -329,10 +324,9 @@ def _slice():
             target_end = int(inputs[3])
         else:
             if isinstance(inputs[3], _expr.Expr):
-                try:
-                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
-                except Exception:
-                    target_end = inputs[3]
+                target_end, _ = try_infer_value(
+                    inputs[3], lambda ret: np.asscalar(ret.astype(np.int))
+                )
             else:
                 target_end = inputs[3]
 
@@ -457,10 +451,7 @@ def _topk():
         sort = bool(inputs[4])
 
         if isinstance(inputs[1], _expr.Expr):
-            try:
-                k = _infer_value(inputs[1], {}).asnumpy().tolist()
-            except Exception:
-                k = inputs[1]
+            k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist())
         else:
             k = inputs[1]
 
@@ -546,15 +537,15 @@ def _full_impl(data, fill_value, dtype):
                     size.append(dim)
                 new_shape.append(dim)
             else:
-                try:
-                    dim = int(_infer_value(dim, {}).asnumpy())
+                dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0)
+                new_shape.append(dim)
+
+                if success:
                     if isinstance(size, list):
                         size.append(dim)
-                    new_shape.append(dim)
-                except Exception:
+                else:
                     size = None
                     need_reshape = True
-                    new_shape.append(0)
         else:
             if isinstance(size, list):
                 size.append(dim)
@@ -1346,12 +1337,11 @@ def _reshape():
             if isinstance(s, _expr.Constant):
                 tmp_shape.append(int(s.data.asnumpy()))
             elif isinstance(s, _expr.Expr):
-                try:
-                    dim = int(_infer_value(s, {}).asnumpy())
-                    tmp_shape.append(dim)
-                except Exception:
+                dim, success = try_infer_value(s, lambda ret: int(ret))
+                tmp_shape.append(dim)
+
+                if not success:
                     is_dyn = True
-                    tmp_shape.append(s)
             else:
                 tmp_shape.append(s)
 
@@ -2312,13 +2302,15 @@ def _interpolate():
         if isinstance(inputs[1], _expr.Expr):
             out_size = inputs[1]
         elif isinstance(inputs[1], list):
-            try:
-                infer_res = [_infer_value(size, {}) for size in inputs[1]]
-                out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
-            except Exception:
-                h = _op.expand_dims(inputs[1][0], axis=0)
-                w = _op.expand_dims(inputs[1][1], axis=0)
-                out_size = _op.concatenate([h, w], axis=0)
+            out_size = []
+            for i in [0, 1]:
+                size, _ = try_infer_value(
+                    inputs[1][i],
+                    lambda ret: ret.astype(np.int),
+                    lambda: _op.expand_dims(inputs[1][i], axis=0),
+                )
+                out_size.append(size)
+            out_size = _op.concatenate(out_size, axis=0)
 
         data = inputs[0]
         align_corners = inputs[4]
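
A rough sanity check of the helper's two return paths (illustrative only; it assumes a TVM build that includes this commit and the NDArray.asnumpy() API of that era):

    import numpy as np
    from tvm import relay
    from tvm.relay.frontend.common import try_infer_value

    # A constant expression can be evaluated, so the on_success result is
    # returned together with True.
    shape_expr = relay.const(np.array([2, 3], dtype="int64"))
    value, ok = try_infer_value(shape_expr, lambda arr: arr.tolist())
    # value == [2, 3], ok is True

    # A free variable cannot be evaluated; without an on_failure callback the
    # original expression comes back unchanged, together with False.
    dyn = relay.var("n", shape=(), dtype="int64")
    value, ok = try_infer_value(dyn)
    # value is dyn, ok is False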