You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/06/07 23:00:17 UTC
[tvm] 01/01: initial commit
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/build-funcs-inherit-passcontext
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 4801c02afa46eb0fe837a9e5556cc1d80cf3d1b2
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Tue Jun 7 16:00:01 2022 -0700
initial commit
---
python/tvm/auto_scheduler/measure.py | 2 +-
python/tvm/autotvm/measure/measure_methods.py | 27 ++++++++++++++++++++++-----
2 files changed, 23 insertions(+), 6 deletions(-)
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 2a4a03bbe8..75f1116864 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -630,7 +630,7 @@ def _local_build_worker(inp_serialized, build_func, verbose):
filename = os.path.join(dirname, "tmp_func." + build_func.output_format)
try:
- with transform.PassContext():
+ with transform.PassContext().current():
func = build_module.build(sch, args, target=task.target)
func.export_library(filename, build_func)
# pylint: disable=broad-except
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index f582bd1974..7a398eb27d 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -505,10 +505,6 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
if not config.valid():
raise InstantiationError(config.errors)
- opts = build_option or {}
- if check_gpu: # Add verify pass to filter out invalid configs in advance.
- opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
-
# if target is vta, we need to use vta build
if (
hasattr(measure_input.target, "device_name")
@@ -519,7 +515,28 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
func = vta.build(s, args, target_host=task.target_host)
else:
- with tvm.ir.transform.PassContext(config=opts):
+ current_pass_context: tvm.ir.transform.PassContext = (
+ tvm.ir.transform.PassContext.current()
+ )
+ current_config = dict(current_pass_context.config)
+ if build_option is not None:
+ current_config.update(build_option)
+
+ if "tir.add_lower_pass" in current_config:
+ current_add_lower_pass = list(current_config["tir.add_lower_pass"])
+ else:
+ current_add_lower_pass = []
+ if check_gpu:
+ current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
+ current_config["tir.add_lower_pass"] = current_add_lower_pass
+
+ with tvm.ir.transform.PassContext(
+ opt_level=current_pass_context.opt_level,
+ required_pass=current_pass_context.required_pass,
+ disabled_pass=current_pass_context.disabled_pass,
+ instruments=current_pass_context.instruments,
+ config=current_config,
+ ):
func = build(s, args, target_host=task.target_host, runtime=runtime)
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)