You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/06/07 23:00:16 UTC
[tvm] branch aluo/build-funcs-inherit-passcontext created (now 4801c02afa)
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a change to branch aluo/build-funcs-inherit-passcontext
in repository https://gitbox.apache.org/repos/asf/tvm.git
at 4801c02afa initial commit
This branch includes the following new commits:
new 4801c02afa initial commit
The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email. The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
[tvm] 01/01: initial commit
Posted by an...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/build-funcs-inherit-passcontext
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 4801c02afa46eb0fe837a9e5556cc1d80cf3d1b2
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Tue Jun 7 16:00:01 2022 -0700
initial commit
---
python/tvm/auto_scheduler/measure.py | 2 +-
python/tvm/autotvm/measure/measure_methods.py | 27 ++++++++++++++++++++++-----
2 files changed, 23 insertions(+), 6 deletions(-)
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 2a4a03bbe8..75f1116864 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -630,7 +630,7 @@ def _local_build_worker(inp_serialized, build_func, verbose):
filename = os.path.join(dirname, "tmp_func." + build_func.output_format)
try:
- with transform.PassContext():
+ with transform.PassContext().current():
func = build_module.build(sch, args, target=task.target)
func.export_library(filename, build_func)
# pylint: disable=broad-except
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index f582bd1974..7a398eb27d 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -505,10 +505,6 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
if not config.valid():
raise InstantiationError(config.errors)
- opts = build_option or {}
- if check_gpu: # Add verify pass to filter out invalid configs in advance.
- opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
-
# if target is vta, we need to use vta build
if (
hasattr(measure_input.target, "device_name")
@@ -519,7 +515,28 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
func = vta.build(s, args, target_host=task.target_host)
else:
- with tvm.ir.transform.PassContext(config=opts):
+ current_pass_context: tvm.ir.transform.PassContext = (
+ tvm.ir.transform.PassContext.current()
+ )
+ current_config = dict(current_pass_context.config)
+ if build_option is not None:
+ current_config.update(build_option)
+
+ if "tir.add_lower_pass" in current_config:
+ current_add_lower_pass = list(current_config["tir.add_lower_pass"])
+ else:
+ current_add_lower_pass = []
+ if check_gpu:
+ current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
+ current_config["tir.add_lower_pass"] = current_add_lower_pass
+
+ with tvm.ir.transform.PassContext(
+ opt_level=current_pass_context.opt_level,
+ required_pass=current_pass_context.required_pass,
+ disabled_pass=current_pass_context.disabled_pass,
+ instruments=current_pass_context.instruments,
+ config=current_config,
+ ):
func = build(s, args, target_host=task.target_host, runtime=runtime)
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)