You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/06/07 23:00:16 UTC

[tvm] branch aluo/build-funcs-inherit-passcontext created (now 4801c02afa)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a change to branch aluo/build-funcs-inherit-passcontext
in repository https://gitbox.apache.org/repos/asf/tvm.git


      at 4801c02afa initial commit

This branch includes the following new commits:

     new 4801c02afa initial commit

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 01/01: initial commit

Posted by an...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/build-funcs-inherit-passcontext
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 4801c02afa46eb0fe837a9e5556cc1d80cf3d1b2
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Tue Jun 7 16:00:01 2022 -0700

    initial commit
---
 python/tvm/auto_scheduler/measure.py          |  2 +-
 python/tvm/autotvm/measure/measure_methods.py | 27 ++++++++++++++++++++++-----
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 2a4a03bbe8..75f1116864 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -630,7 +630,7 @@ def _local_build_worker(inp_serialized, build_func, verbose):
         filename = os.path.join(dirname, "tmp_func." + build_func.output_format)
 
         try:
-            with transform.PassContext():
+            with transform.PassContext().current():
                 func = build_module.build(sch, args, target=task.target)
             func.export_library(filename, build_func)
         # pylint: disable=broad-except
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index f582bd1974..7a398eb27d 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -505,10 +505,6 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
         if not config.valid():
             raise InstantiationError(config.errors)
 
-        opts = build_option or {}
-        if check_gpu:  # Add verify pass to filter out invalid configs in advance.
-            opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
-
         # if target is vta, we need to use vta build
         if (
             hasattr(measure_input.target, "device_name")
@@ -519,7 +515,28 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
 
             func = vta.build(s, args, target_host=task.target_host)
         else:
-            with tvm.ir.transform.PassContext(config=opts):
+            current_pass_context: tvm.ir.transform.PassContext = (
+                tvm.ir.transform.PassContext.current()
+            )
+            current_config = dict(current_pass_context.config)
+            if build_option is not None:
+                current_config.update(build_option)
+
+            if "tir.add_lower_pass" in current_config:
+                current_add_lower_pass = list(current_config["tir.add_lower_pass"])
+            else:
+                current_add_lower_pass = []
+            if check_gpu:
+                current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
+            current_config["tir.add_lower_pass"] = current_add_lower_pass
+
+            with tvm.ir.transform.PassContext(
+                opt_level=current_pass_context.opt_level,
+                required_pass=current_pass_context.required_pass,
+                disabled_pass=current_pass_context.disabled_pass,
+                instruments=current_pass_context.instruments,
+                config=current_config,
+            ):
                 func = build(s, args, target_host=task.target_host, runtime=runtime)
     return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)