You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/06 18:03:45 UTC

[GitHub] [tvm] kevinLu1114 commented on a change in pull request #6700: [Relay][Frontend][Onnx] Loop Support

kevinLu1114 commented on a change in pull request #6700:
URL: https://github.com/apache/tvm/pull/6700#discussion_r646167029



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -1995,6 +2022,164 @@ def _impl_v11(cls, inputs, attr, params):
         return result
 
 
+class Loop(OnnxOpConverter):
+    """Operator converter for Loop"""
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        max_loop_count = inputs[0]
+        cond = inputs[1]
+        loop_deps = inputs[2:]
+        num_deps = len(loop_deps)
+        body = attr["body"]
+        iter_dtype = infer_type(max_loop_count).checked_type.dtype
+
+        # Determine what condition mode we're in.
+        assert cond is not None or max_loop_count is not None
+        is_for_loop = max_loop_count is not None and cond is None
+        is_condition_for_loop = cond is not None and max_loop_count is not None
+
+        # Loop inputs will be packed as
+        # [iter_count, max_count, condition, loop_deps, scan_outputs]
+        def cond_fn(*loop_inputs):
+            i = loop_inputs[0]
+            max_count = loop_inputs[1]
+            w = loop_inputs[2]
+
+            if cond is not None:
+                out_while = _op.equal(w, _expr.const(True, "bool"))
+            if max_loop_count is not None:
+                out_loop = _op.less(i, max_count)
+
+            if is_condition_for_loop:
+                return _op.logical_and(out_while, out_loop)
+            if is_for_loop:
+                return out_loop
+            return out_while
+
+        # Get the current graph proto and create a clone for the subgraph
+        graph_scope = GraphProto.current
+        subgraph_scope = GraphProto(graph_scope._shape, graph_scope._dtype)
+        # Load nodes from outer graph into inner graph.
+        subgraph_scope._nodes = graph_scope._nodes.copy()
+
+        # Create a list of variables for each value updated in the loop.
+        def get_var(name, val, scan=False):
+            checked_type = infer_type(val)
+            if hasattr(checked_type, "type_annotation"):
+                checked_type = checked_type.type_annotation
+            shape = get_const_tuple(checked_type.shape)
+            actual_shape = []
+            for dim in shape:
+                if isinstance(dim, int) and dim == 0:
+                    actual_shape.append(_ty.Any())
+                else:
+                    actual_shape.append(dim)
+            if scan:
+                return _expr.var(name, shape=[_ty.Any()] + actual_shape, dtype=checked_type.dtype)
+
+            return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
+
+        loop_vars = [
+            _expr.var(body.input[0].name, shape=(), dtype=iter_dtype),  # iteration count
+            _expr.var("max_count", shape=(), dtype=iter_dtype),  # iteration count
+            get_var(body.input[1].name, cond),  # exit condition
+        ]
+        loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)]
+        loop_var_names = [v.name_hint for v in loop_vars]
+
+        num_scan_outputs = len(body.output) - (1 + num_deps)
+        # TODO (jwfromm) Test with strided slice once type unifier for this case is fixed.

Review comment:
       Hi @jwfromm 
   Did you ever finish this TODO?
   I am getting an error,
   and I'm not sure whether it is caused by this.
   Could you please help me?
   
   The model is an SSD-MobileNetV1, like this one:
   https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.onnx
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org