Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/11 03:13:51 UTC

[GitHub] [incubator-tvm] zhiics commented on a change in pull request #5699: [Frontend][TensorFlow] Improve Control Flow and TensorArray

zhiics commented on a change in pull request #5699:
URL: https://github.com/apache/incubator-tvm/pull/5699#discussion_r438517544



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2444,6 +2470,27 @@ def _in_while_loop(control_flow_node_map, op_name):
     return op_name in control_flow_node_map and \
             "LoopCond" in control_flow_node_map[op_name]
 
+class RewriteSubgraph(ExprMutator):
+    """
+    A helper class to rewrite expr in while loop function to variable

Review comment:
       Add a blank line

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2657,86 +2636,78 @@ class Loop:
           %6
         }
     """
-    def __init__(self, mod, loop_name, hash2tfnode,
-                 node_map, while_loop_name_set):
-        self.loop_vars = []
+    def __init__(self, mod, loop_name, lvar2expr):
         self.cond = None
         self.body = []
         self._loop = None
         self._mod = mod
         self._loop_name = loop_name
-        self._hash2tfnode = hash2tfnode
-        self._node_map = node_map
-        self._while_loop_name_set = while_loop_name_set
+        self._lvar2expr = lvar2expr
+        self.loop_vars = []
+
         self.aligned = False
 
     def _while_loop(self):
         """An internal API to create a Relay recursive call for a matched TF
         `while_loop` construct.
         """
+        bind_map = {}
         wl = tvm.relay.var('while_loop')
-
         sb = tvm.relay.scope_builder.ScopeBuilder()
 
-        loop_checker = LoopBound(self._loop_name,
-                                 self._hash2tfnode,
-                                 self._while_loop_name_set)
-        for body in self.body:
-            loop_checker.visit(body)
-
-        loop_vars = []
-        bind_map = {}
-        loop_var_hash_set = set()
-        for var in self.loop_vars:
-            loop_var_hash_set.add(s_hash(var))
-
-        extra_nodes = []
-        for extra_loop_var_name in loop_checker.extra_loop_var_names:
-            extra_loop_var_name = extra_loop_var_name.split(':')[0].split("^")[-1]
-            extra_node = self._node_map[extra_loop_var_name]
-            extra_node = extra_node if isinstance(extra_node, _expr.Tuple) else extra_node[0]
-            if s_hash(extra_node) not in loop_var_hash_set:
-                self.loop_vars.append(extra_node)
-                extra_nodes.append(extra_node)
-
-        for i, var in enumerate(self.loop_vars):
-            if not isinstance(var, _expr.Var):
-                var_chk = _infer_type(var, self._mod)
-                var_type = var_chk.checked_type
-            else:
-                var_type = var.type_annotation
-
-            v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type)
-            loop_vars.append(v)
-            bind_map[var] = v
+        lv_list = []
+        expr_list = []
+        extra_vars = []
+
+        for i, lv in enumerate(self.loop_vars):
+            if self._loop_name not in self._lvar2expr:
+                self._lvar2expr[self._loop_name] = {}
+
+            # Handle the case when loop var is not properly lifted.
+            # This can happen when loop var node name is set accidentally
+            # beginning with loop name.
+            if lv not in self._lvar2expr[self._loop_name]:
+                var_name = "{}_loop_var_{}".format(self._loop_name, i)
+                var_type = _infer_type(lv, self._mod).checked_type
+                loop_var = tvm.relay.var(var_name, type_annotation=var_type)
+                self._lvar2expr[self._loop_name][loop_var] = lv
+                bind_map[lv] = loop_var
+                self.loop_vars[i] = loop_var
+                lv = loop_var
+
+            lv_list.append(lv)
+            expr_list.append(self._lvar2expr[self._loop_name][lv])
+
+        if bind_map:
+            self.cond = rewrite_subgraph(self.cond, bind_map)
+            self.body = [rewrite_subgraph(b, bind_map) for b in self.body]
 
+        cond = tvm.relay.op.min(self.cond)
 
-        self.cond = rewrite_subgraph(self.cond, bind_map)
-        self.body = [rewrite_subgraph(b, bind_map) for b in self.body]
-
-        self.body_shape = []
-        for body in self.body:
-            current_node = body
-            shape = _infer_shape(current_node, self._mod)
-            while not isinstance(shape, (tuple, list)):
-                current_node = current_node.args[-1]
-                shape = _infer_shape(current_node, self._mod)
-            self.body_shape.append(shape)
+        for lv, exp in self._lvar2expr[self._loop_name].items():
+            if lv not in self.loop_vars:
+                var_checker = VarChecker(lv)
+                used = False
+                for bd in self.body + [cond]:
+                    var_checker.visit(bd)
+                    if var_checker.used:
+                        used = True

Review comment:
       can we just move the 3 statements under `if used:` into this loop and remove the `used` flag?
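
   A minimal sketch of that refactor, assuming the three statements currently guarded by `if used:` (not shown in this hunk) only need to run once the variable is known to be used:

   ```python
   # Hypothetical sketch: act as soon as the variable is seen in the body or
   # condition, so the separate `used` flag is no longer needed.
   for lv, exp in self._lvar2expr[self._loop_name].items():
       if lv not in self.loop_vars:
           var_checker = VarChecker(lv)
           for bd in self.body + [cond]:
               var_checker.visit(bd)
               if var_checker.used:
                   # ... the three statements previously under `if used:` go here
                   # (e.g. appending lv/exp to lv_list/expr_list/extra_vars) ...
                   break
   ```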

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1444,15 +1420,50 @@ def _impl(inputs, attr, params, mod):
         begin = _get_list_param(params, inputs[1])
         end = _get_list_param(params, inputs[2])
         stride = _get_list_param(params, inputs[3])
+
         begin_mask = int(attr.get('begin_mask', 0))
         end_mask = int(attr.get('end_mask', 0))
         ellipsis_mask = int(attr.get('ellipsis_mask', 0))
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
-        data_shape = _infer_shape(inputs[0], mod)
+        in_type = _infer_type(inputs[0], mod)
+        data_shape = get_const_tuple(in_type.checked_type.shape)
         data_dim = len(data_shape)
         stride_dim = len(stride)
 
+        # This is a special routine to handle strided_slice after shape_of.
+        # We need this since in some cases we want to do strided_slice on
+        # a partial symbolic shape, such as (1, ?), and get a static shape
+        # (1,). Directly slice on shape_of will result in fully dynamic shape.
+        # TODO(kevinthesun): Can we generalize this process with partial eval?
+        if isinstance(inputs[0], _expr.Call) and "shape_of" in str(inputs[0].op):
+            bg = begin[0]
+            ed = end[0]
+            st = stride[0]
+
+            if ed <= 0 < st:
+                ed += data_shape[0]
+
+            in_shape = _infer_shape(inputs[0].args[0], mod)
+            dtype = in_type.checked_type.dtype
+            out_data = []
+            idx = bg
+            is_constant = True

Review comment:
       It looks like `is_constant` is not needed.
   
   We can just do `if idx < ed:` with a comment.
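
   A rough sketch of that simplification, assuming the remainder of this routine (not shown) walks the requested indices and collects constant dimensions; the names and the fallback path here are illustrative only:

   ```python
   # Hypothetical sketch: stop at the first symbolic dimension; afterwards
   # `idx < ed` already tells us whether the walk finished, so no extra
   # `is_constant` flag is required.
   out_data = []
   idx = bg
   while idx < ed:
       if not isinstance(in_shape[idx], int):
           break  # symbolic dimension, cannot constant-fold this slice
       out_data.append(in_shape[idx])
       idx += st

   if idx < ed:
       # The walk stopped early on a symbolic dimension: fall back to the
       # generic strided_slice lowering.
       pass
   else:
       ret = _expr.const(out_data, dtype)
   ```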

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1444,15 +1420,50 @@ def _impl(inputs, attr, params, mod):
         begin = _get_list_param(params, inputs[1])
         end = _get_list_param(params, inputs[2])
         stride = _get_list_param(params, inputs[3])
+
         begin_mask = int(attr.get('begin_mask', 0))
         end_mask = int(attr.get('end_mask', 0))
         ellipsis_mask = int(attr.get('ellipsis_mask', 0))
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
-        data_shape = _infer_shape(inputs[0], mod)
+        in_type = _infer_type(inputs[0], mod)
+        data_shape = get_const_tuple(in_type.checked_type.shape)
         data_dim = len(data_shape)
         stride_dim = len(stride)
 
+        # This is a special routine to handle strided_slice after shape_of.
+        # We need this since in some cases we want to do strided_slice on
+        # a partial symbolic shape, such as (1, ?), and get a static shape
+        # (1,). Directly slice on shape_of will result in fully dynamic shape.
+        # TODO(kevinthesun): Can we generalize this process with partial eval?
+        if isinstance(inputs[0], _expr.Call) and "shape_of" in str(inputs[0].op):

Review comment:
       Could this use `inputs[0].op == op.get("shape_of")` instead of matching on the string?
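
   A small sketch of that check, assuming `_op` here refers to `tvm.relay.op`:

   ```python
   from tvm.relay import op as _op

   # Hypothetical: compare the call's operator against the registered Relay op
   # object instead of substring-matching on str(inputs[0].op).
   if isinstance(inputs[0], _expr.Call) and inputs[0].op == _op.get("shape_of"):
       ...
   ```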

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2395,29 +2410,40 @@ def _get_abs_layer_name(node):
 # 1.x.
 _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
 
-# A map to record tensor array with fixed rank shape
-_static_tensor_array_map = {}
-
-class RewriteSubgraph(ExprMutator):
-    """
-    A helper class to rewrite expr in while loop function to variable
-
-    Parameters
-    ----------
-    rewrite_map : Dict[expr, expr]
-        A dictionay contains a set of expr to var mapping.
-    """
-    def __init__(self, rewrite_map):
-        ExprMutator.__init__(self)
-        self.rewrite_map = rewrite_map
-
-    def visit(self, expr):
-        if expr in self.rewrite_map:
-            return self.rewrite_map[expr]
-        return super().visit(expr)
+# A map to record tensor array write ops and input ta/tensor indices
+# Value is (index of tensor array, index of written node)
+_tensor_array_write_ops = {
+    "TensorArrayWrite"   : (3, 2),
+    "TensorArrayScatter" : (0, 2),
+    "TensorArraySplit"   : (0, 1),
+}
 
-def rewrite_subgraph(expr, rewrites):
-    return RewriteSubgraph(rewrites).visit(expr)
+def is_tensor_array_constuctor(tf_node):
+    """Check whether is tensor array constructor node."""
+    is_ta = False
+    ta_start = "TensorArrayV"
+    if tf_node.op.startswith(ta_start):
+        try:
+            int(tf_node.op[len(ta_start)])

Review comment:
       the value is not used?
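
   If the `int()` call is only there to verify that the op name ends in a version digit, a sketch of a more explicit check (illustrative, not the PR's code):

   ```python
   # Hypothetical alternative: isdigit() expresses "followed by a version
   # number" without computing an integer that is then discarded.
   ta_start = "TensorArrayV"
   is_ta = tf_node.op.startswith(ta_start) and tf_node.op[len(ta_start):].isdigit()
   ```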




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org