You are viewing a plain-text version of this content. (The canonical link that originally appeared here was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by an...@apache.org on 2022/06/07 23:00:17 UTC

[tvm] 01/01: initial commit

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/build-funcs-inherit-passcontext
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 4801c02afa46eb0fe837a9e5556cc1d80cf3d1b2
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Tue Jun 7 16:00:01 2022 -0700

    initial commit
---
 python/tvm/auto_scheduler/measure.py          |  2 +-
 python/tvm/autotvm/measure/measure_methods.py | 27 ++++++++++++++++++++++-----
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 2a4a03bbe8..75f1116864 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -630,7 +630,7 @@ def _local_build_worker(inp_serialized, build_func, verbose):
         filename = os.path.join(dirname, "tmp_func." + build_func.output_format)
 
         try:
-            with transform.PassContext():
+            with transform.PassContext().current():
                 func = build_module.build(sch, args, target=task.target)
             func.export_library(filename, build_func)
         # pylint: disable=broad-except
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index f582bd1974..7a398eb27d 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -505,10 +505,6 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
         if not config.valid():
             raise InstantiationError(config.errors)
 
-        opts = build_option or {}
-        if check_gpu:  # Add verify pass to filter out invalid configs in advance.
-            opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
-
         # if target is vta, we need to use vta build
         if (
             hasattr(measure_input.target, "device_name")
@@ -519,7 +515,28 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
 
             func = vta.build(s, args, target_host=task.target_host)
         else:
-            with tvm.ir.transform.PassContext(config=opts):
+            current_pass_context: tvm.ir.transform.PassContext = (
+                tvm.ir.transform.PassContext.current()
+            )
+            current_config = dict(current_pass_context.config)
+            if build_option is not None:
+                current_config.update(build_option)
+
+            if "tir.add_lower_pass" in current_config:
+                current_add_lower_pass = list(current_config["tir.add_lower_pass"])
+            else:
+                current_add_lower_pass = []
+            if check_gpu:
+                current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
+            current_config["tir.add_lower_pass"] = current_add_lower_pass
+
+            with tvm.ir.transform.PassContext(
+                opt_level=current_pass_context.opt_level,
+                required_pass=current_pass_context.required_pass,
+                disabled_pass=current_pass_context.disabled_pass,
+                instruments=current_pass_context.instruments,
+                config=current_config,
+            ):
                 func = build(s, args, target_host=task.target_host, runtime=runtime)
     return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)