You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/06/30 23:24:47 UTC

[tvm] branch main updated: [MetaSchedule] Fix Task Extraction (#11954)

This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new beea0d2d6a [MetaSchedule] Fix Task Extraction (#11954)
beea0d2d6a is described below

commit beea0d2d6add545bd27130309aa12b8e7a38100f
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Jun 30 16:24:42 2022 -0700

    [MetaSchedule] Fix Task Extraction (#11954)
---
 python/tvm/meta_schedule/__init__.py          |  6 +++++-
 python/tvm/meta_schedule/relay_integration.py | 24 +++++++++++++++++++++++-
 python/tvm/meta_schedule/tune.py              |  5 ++++-
 python/tvm/relay/backend/te_compiler.py       |  5 +++--
 python/tvm/relay/op/strategy/cuda.py          |  8 +++-----
 src/meta_schedule/database/json_database.cc   |  2 +-
 src/relay/backend/te_compiler.cc              |  1 +
 7 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index eb40b32e7c..f60d0a5490 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -33,7 +33,11 @@ from . import (
 from .apply_history_best import ApplyHistoryBest
 from .extracted_task import ExtractedTask
 from .profiler import Profiler
-from .relay_integration import extract_task_from_relay, is_meta_schedule_enabled
+from .relay_integration import (
+    extract_task_from_relay,
+    is_meta_schedule_dispatch_enabled,
+    is_meta_schedule_enabled,
+)
 from .search_strategy import MeasureCandidate
 from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
 from .tune_context import TuneContext
diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
index 707b469aa4..bd12ac350a 100644
--- a/python/tvm/meta_schedule/relay_integration.py
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -70,6 +70,7 @@ def extract_task_from_relay(
         The tasks extracted from this network
     """
     # pylint: disable=import-outside-toplevel
+    from tvm import autotvm
     from tvm.relay import Function as RelayFunc
 
     # pylint: enable=import-outside-toplevel
@@ -102,7 +103,14 @@ def extract_task_from_relay(
         config=pass_config,
         disabled_pass=disabled_pass,
     ):
-        return list(extract_task_func(mod, target, relay_params, te_filter_func))
+        if target.kind.name != "cuda" and isinstance(
+            autotvm.DispatchContext.current, autotvm.FallbackContext
+        ):
+            tophub_context = autotvm.tophub.context(target)
+        else:
+            tophub_context = autotvm.utils.EmptyContext()
+        with tophub_context:
+            return list(extract_task_func(mod, target, relay_params, te_filter_func))
 
 
 def is_meta_schedule_enabled() -> bool:
@@ -117,3 +125,17 @@ def is_meta_schedule_enabled() -> bool:
         "relay.backend.use_meta_schedule",
         False,
     )
+
+
+def is_meta_schedule_dispatch_enabled() -> bool:
+    """Return whether the meta-schedule dispatch is enabled.
+
+    Returns
+    -------
+    enabled: bool
+        Whether the meta-schedule dispatch is enabled
+    """
+    return transform.PassContext.current().config.get(
+        "relay.backend.use_meta_schedule_dispatch",
+        False,
+    )
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index cd40429d16..bc2e7096c6 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -592,6 +592,9 @@ def tune_relay(
         with target, autotvm_silencer(), ApplyHistoryBest(database):
             with PassContext(
                 opt_level=3,
-                config={"relay.backend.use_meta_schedule": True},
+                config={
+                    "relay.backend.use_meta_schedule": True,
+                    "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda",
+                },
             ):
                 return relay_build(mod, target=target, params=params)
diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py
index 3c87f45b8f..a2fbf555e1 100644
--- a/python/tvm/relay/backend/te_compiler.py
+++ b/python/tvm/relay/backend/te_compiler.py
@@ -23,7 +23,8 @@ import logging
 import numpy as np
 import tvm
 from tvm import autotvm, te
-from tvm.ir.transform import PassContext
+from tvm.auto_scheduler import is_auto_scheduler_enabled
+from tvm.meta_schedule import is_meta_schedule_dispatch_enabled
 from tvm.runtime import Object
 from tvm.support import libinfo
 from tvm.target import Target
@@ -180,7 +181,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
 
     # Disable autotvm if auto_scheduler is enabled.
     # (i.e., always return the implementation with the highest priority for auto-scheduler).
-    if PassContext.current().config.get("relay.backend.use_auto_scheduler", False):
+    if is_auto_scheduler_enabled() or is_meta_schedule_dispatch_enabled():
         use_autotvm = False
 
     # If not use autotvm, always return the implementation with the highest priority
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 072b958da2..9c4a896d57 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -252,9 +252,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 )
 
             # register auto-scheduler implementations
-            if (
-                is_auto_scheduler_enabled() or is_meta_schedule_enabled()
-            ) and judge_winograd_auto_scheduler:
+            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
                     naive_schedule,  # this implementation should never be picked by autotvm
@@ -545,7 +543,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
                 name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
             )
 
-        if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
+        if is_auto_scheduler_enabled():
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
                 naive_schedule,  # this implementation should never be picked by autotvm
@@ -823,7 +821,7 @@ def matmul_strategy_cuda(attrs, inputs, out_type, target):
     """Matmul cuda strategy."""
     strategy = _op.OpStrategy()
 
-    if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
+    if is_auto_scheduler_enabled():
         strategy.add_implementation(
             wrap_compute_matmul(topi.nn.matmul),
             naive_schedule,
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 23ecb121f4..5e7c9119c9 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -204,7 +204,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record,
             LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1)
                        << " of file " << path_tuning_record << ". The workload is:\n"
                        << (workload.defined() ? tir::AsTVMScript(workload) : "(null)")
-                       << "\nThe JSONObject of TuningRecrod is:\n"
+                       << "\nThe JSONObject of TuningRecord is:\n"
                        << json_obj << "\nThe error message is:\n"
                        << e.what();
           }
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 210f77330a..8ca5a32b7f 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -552,6 +552,7 @@ TECompiler& TECompiler::Global() {
 }
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Bool);
 
 TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
   return TECompiler::Global();