You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/04 00:09:50 UTC

[tvm] branch main updated: [AutoScheduler] Querying and sampling in task extraction (#7571)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f5f84d  [AutoScheduler] Querying and sampling in task extraction (#7571)
3f5f84d is described below

commit 3f5f84d2e27225a188588450fd744516122d9a67
Author: Cody Yu <co...@gmail.com>
AuthorDate: Wed Mar 3 16:09:38 2021 -0800

    [AutoScheduler] Querying and sampling in task extraction (#7571)
    
    * [AutoScheduler] Query in task extraction
    
    * trigger ci
---
 python/tvm/auto_scheduler/relay_integration.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index b39aba2..68f5312 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -283,10 +283,13 @@ def auto_schedule_topi(outs):
     key = register_workload_tensors(dag.workload_key(), io_tensors)
     target = tvm.target.Target.current()
 
+    dispatch_ctx = DispatchContext.current
+    state = dispatch_ctx.query(target, key, has_complex_op, dag)
+    schedule = None
+
     env = TracingEnvironment.current
     if env is None:
         # in the final build mode
-        state = DispatchContext.current.query(target, key, has_complex_op, dag)
         if state is None:
             return None
 
@@ -303,8 +306,6 @@ def auto_schedule_topi(outs):
             LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE
             and has_layout_free
         ):
-            dispatch_ctx = DispatchContext.current
-            state = dispatch_ctx.query(target, key, has_complex_op, dag)
             if state is None:
                 return None
 
@@ -316,7 +317,7 @@ def auto_schedule_topi(outs):
     else:
         raise ValueError("Invalid tracing mode: " + env.tracing_mode)
 
-    return None
+    return schedule
 
 
 def tensor_no_check_call(self, *indices):