Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/03/01 00:38:31 UTC

[GitHub] [incubator-tvm] masahi commented on a change in pull request #4964: [Torch] Add initial control flow support

masahi commented on a change in pull request #4964: [Torch] Add initial control flow support 
URL: https://github.com/apache/incubator-tvm/pull/4964#discussion_r386066342
 
 

 ##########
 File path: python/tvm/relay/frontend/pytorch.py
 ##########
 @@ -955,7 +1025,100 @@ def parse_params(graph, state_dict):
     return params, param_tensors
 
 
-def parse_operators(operators, outputs, output_index_map, ret_name):
+def convert_block(block, outputs, output_index_map):
+    """ Translate Torch "Block", used for prim::If and prim::Loop """
+    ops = _get_operator_nodes(block.nodes())
+    ret_names = _get_input_names(block.returnNode())
+    return convert_operators(ops, outputs, output_index_map, ret_names)
+
+
+def convert_if(if_node, outputs, output_index_map):
+    """ Translate Torch prim::If to Relay If """
+    cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+    blocks = list(if_node.blocks())
+    true_branch = convert_block(blocks[0], outputs, output_index_map)
+    false_branch = convert_block(blocks[1], outputs, output_index_map)
+    assert len(true_branch) == 1 and len(false_branch) == 1
+    return _expr.If(cond, true_branch[0], false_branch[0])
+
+
+def convert_loop(loop_node, outputs, output_index_map):
+    """ Translate Torch prim::Loop to Relay while_loop """
+    def get_input(index):
+        ivalue = loop_node.inputsAt(index)
+        inode = ivalue.node()
+        if inode.kind() == "prim::Constant":
+            return _expr.const(_get_constant(inode))
+        var_name = ivalue.debugName()
+        assert var_name in output_index_map
+        return _wrap_const(outputs[output_index_map[var_name]])
+
+    # Refer to the spec for prim::Loop below
+    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+    # The first input: %max_trip_count
+    # The second input: %initial_condition
 +    # The rest of the inputs: loop variables
+    max_loop_count = get_input(0)
+    init_cond = get_input(1)
+    num_loop_var = len(list(loop_node.inputs())) - 2
+    init_vals = [get_input(i + 2) for i in range(num_loop_var)]
+
 +    # A for loop (as opposed to a while loop) always has %initial_condition equal to 1
+    is_for_loop = isinstance(init_cond, _expr.Constant)
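
For context, here is a minimal, hypothetical sketch (not part of this diff) of the Relay `while_loop` helper in `tvm.relay.loops`, which the docstring above says `prim::Loop` is translated to. It builds a loop that sums the integers 0 through 9; the variable names and the example itself are made up for illustration:

```python
from tvm import relay
from tvm.relay.loops import while_loop

# Loop variables, analogous to the prim::Loop block inputs.
i = relay.var("i", shape=(), dtype="int32")
acc = relay.var("acc", shape=(), dtype="int32")

def cond(i, acc):
    # Continue while i < 10; this plays the role of the loop condition.
    return relay.less(i, relay.const(10, "int32"))

def body(i, acc):
    # Return the updated loop variables, mirroring the block outputs.
    return [i + relay.const(1, "int32"), acc + i]

# while_loop returns a recursive Relay expression; calling it with the
# initial values corresponds to passing the initial prim::Loop inputs.
loop = while_loop(cond, [i, acc], body)
result = loop(relay.const(0, "int32"), relay.const(0, "int32"))

# The loop evaluates to a tuple of final loop variables; pick out acc.
func = relay.Function([], relay.TupleGetItem(result, 1))
print(func)
```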
 
 Review comment:
   Thanks for pointing this out. Yes, if the condition of the while loop is not input dependent, such as `while i < 10:`, then init_cond seems to be a constant (see also my example IR for the while loop above, which shows this). In that case `is_for_loop` becomes True and the conversion is broken.
   
   For a dynamic condition such as `while i < inp.size(0):`, init_cond is not constant. My test cases only cover such cases. I'll add a while loop test with a constant init_cond.
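
To make the constant vs. dynamic init_cond distinction concrete, here is a small repro sketch (assuming PyTorch is installed; the function names are made up for illustration, not taken from the test suite) that scripts both kinds of while loops and prints their graphs so the `%initial_condition` input of `prim::Loop` can be inspected:

```python
import torch

@torch.jit.script
def constant_cond_while(x):
    # The condition does not depend on the input, like `while i < 10:` above;
    # per the discussion, %initial_condition ends up as a constant here.
    i = 0
    while i < 10:
        x = x + 1.0
        i += 1
    return x

@torch.jit.script
def dynamic_cond_while(x):
    # The condition depends on the input shape, like `while i < inp.size(0):`,
    # so %initial_condition is a computed value rather than a constant.
    i = 0
    while i < x.size(0):
        x = x + 1.0
        i += 1
    return x

# Printing the graphs shows the prim::Loop node and its first two inputs,
# %max_trip_count and %initial_condition, which convert_loop() reads.
print(constant_cond_while.graph)
print(dynamic_cond_while.graph)
```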
