You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/05/17 08:10:09 UTC

[GitHub] [tvm] SebastianBoblestETAS commented on a diff in pull request #11334: [Runtime][PipelineExecutor] Add graph manually splitting logic into the unit test.

SebastianBoblestETAS commented on code in PR #11334:
URL: https://github.com/apache/tvm/pull/11334#discussion_r874504485


##########
tests/python/relay/test_pipeline_executor.py:
##########
@@ -22,12 +22,195 @@
 import tvm
 import tvm.testing
 from tvm import relay
-from tvm.relay import transform
+from tvm.relay import transform, build_module
+from tvm.relay.testing import run_opt_pass
 from tvm.contrib import graph_executor, pipeline_executor, pipeline_executor_build
 from tvm._ffi import get_global_func
 from tvm.contrib import cc as _cc
 
 
+"""Split graph into a list of subgraph"""
+
+
+def graph_split(expr, split_conf, params=None):
+    def get_dep_var(sub_var_dep):
+        return [var for var, _ in sub_var_dep[len(sub_var_dep) - 1]["ref_nodes"].items()]
+
+    def parse_dependency(value, snode_dep, new_input_idx):
+        new_args = []
+        need_update = False
+        for var in value.args:
+            is_free_var = False
+            for i in range(0, len(snode_dep) - 1):
+                dep = snode_dep[i]
+                if var in dep["nodes"]:
+                    # Mark the previous subgraph node as a dependency of this subgraph node
+                    dep["nodes"][var] = dep["nodes"][var] + 1
+                    dep["ref_nodes"][var] = dep["nodes"][var]
+                    # The var of this call is a free_var
+                    is_free_var = True
+            # if the var of this call is free_var, recreate it and give it a fixed input name.
+            if is_free_var:
+                need_update = True
+                new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type))
+                new_input_idx = new_input_idx + 1
+            else:
+                new_args.append(var)
+        # if the call have a free_var recreate it
+        if need_update:
+            value = tvm.relay.expr.Call(
+                value.op, new_args, value.attrs, value.type_args, value.span
+            )
+        return value, snode_dep, new_input_idx
+
+    def merge_constant_expr(constant_expr, expr):
+        # merge constant express with a express
+        # If body not let, then reached end of the express
+        if not isinstance(constant_expr.body, tvm.relay.expr.Let):
+            return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr)
+
+        return tvm.relay.expr.Let(
+            constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr)
+        )
+
+    def _recursion(anf, pipeline_mods, split_conf, constant_expr):
+        # Enumrate all operator of compute graph then split the compute graph into a group subgraph.

Review Comment:
   
   ```suggestion
           # Enumerate all operators of compute graph then split the compute graph into a group subgraph.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org