You are viewing a plain-text version of this content. The canonical link to it (a hyperlink on the word "here" in the original) was not preserved in this copy.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/04 00:09:50 UTC
[tvm] branch main updated: [AutoScheduler] Querying and sampling in task extraction (#7571)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f5f84d [AutoScheduler] Querying and sampling in task extraction (#7571)
3f5f84d is described below
commit 3f5f84d2e27225a188588450fd744516122d9a67
Author: Cody Yu <co...@gmail.com>
AuthorDate: Wed Mar 3 16:09:38 2021 -0800
[AutoScheduler] Querying and sampling in task extraction (#7571)
* [AutoScheduler] Query in task extraction
* trigger ci
---
python/tvm/auto_scheduler/relay_integration.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index b39aba2..68f5312 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -283,10 +283,13 @@ def auto_schedule_topi(outs):
key = register_workload_tensors(dag.workload_key(), io_tensors)
target = tvm.target.Target.current()
+ dispatch_ctx = DispatchContext.current
+ state = dispatch_ctx.query(target, key, has_complex_op, dag)
+ schedule = None
+
env = TracingEnvironment.current
if env is None:
# in the final build mode
- state = DispatchContext.current.query(target, key, has_complex_op, dag)
if state is None:
return None
@@ -303,8 +306,6 @@ def auto_schedule_topi(outs):
LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE
and has_layout_free
):
- dispatch_ctx = DispatchContext.current
- state = dispatch_ctx.query(target, key, has_complex_op, dag)
if state is None:
return None
@@ -316,7 +317,7 @@ def auto_schedule_topi(outs):
else:
raise ValueError("Invalid tracing mode: " + env.tracing_mode)
- return None
+ return schedule
def tensor_no_check_call(self, *indices):