You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/06/30 23:24:47 UTC
[tvm] branch main updated: [MetaSchedule] Fix Task Extraction (#11954)
This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new beea0d2d6a [MetaSchedule] Fix Task Extraction (#11954)
beea0d2d6a is described below
commit beea0d2d6add545bd27130309aa12b8e7a38100f
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Jun 30 16:24:42 2022 -0700
[MetaSchedule] Fix Task Extraction (#11954)
---
python/tvm/meta_schedule/__init__.py | 6 +++++-
python/tvm/meta_schedule/relay_integration.py | 24 +++++++++++++++++++++++-
python/tvm/meta_schedule/tune.py | 5 ++++-
python/tvm/relay/backend/te_compiler.py | 5 +++--
python/tvm/relay/op/strategy/cuda.py | 8 +++-----
src/meta_schedule/database/json_database.cc | 2 +-
src/relay/backend/te_compiler.cc | 1 +
7 files changed, 40 insertions(+), 11 deletions(-)
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index eb40b32e7c..f60d0a5490 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -33,7 +33,11 @@ from . import (
from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .profiler import Profiler
-from .relay_integration import extract_task_from_relay, is_meta_schedule_enabled
+from .relay_integration import (
+ extract_task_from_relay,
+ is_meta_schedule_dispatch_enabled,
+ is_meta_schedule_enabled,
+)
from .search_strategy import MeasureCandidate
from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
from .tune_context import TuneContext
diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
index 707b469aa4..bd12ac350a 100644
--- a/python/tvm/meta_schedule/relay_integration.py
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -70,6 +70,7 @@ def extract_task_from_relay(
The tasks extracted from this network
"""
# pylint: disable=import-outside-toplevel
+ from tvm import autotvm
from tvm.relay import Function as RelayFunc
# pylint: enable=import-outside-toplevel
@@ -102,7 +103,14 @@ def extract_task_from_relay(
config=pass_config,
disabled_pass=disabled_pass,
):
- return list(extract_task_func(mod, target, relay_params, te_filter_func))
+ if target.kind.name != "cuda" and isinstance(
+ autotvm.DispatchContext.current, autotvm.FallbackContext
+ ):
+ tophub_context = autotvm.tophub.context(target)
+ else:
+ tophub_context = autotvm.utils.EmptyContext()
+ with tophub_context:
+ return list(extract_task_func(mod, target, relay_params, te_filter_func))
def is_meta_schedule_enabled() -> bool:
@@ -117,3 +125,17 @@ def is_meta_schedule_enabled() -> bool:
"relay.backend.use_meta_schedule",
False,
)
+
+
+def is_meta_schedule_dispatch_enabled() -> bool:
+ """Return whether the meta-schedule dispatch is enabled.
+
+ Returns
+ -------
+ enabled: bool
+ Whether the meta-schedule dispatch is enabled
+ """
+ return transform.PassContext.current().config.get(
+ "relay.backend.use_meta_schedule_dispatch",
+ False,
+ )
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index cd40429d16..bc2e7096c6 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -592,6 +592,9 @@ def tune_relay(
with target, autotvm_silencer(), ApplyHistoryBest(database):
with PassContext(
opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
+ config={
+ "relay.backend.use_meta_schedule": True,
+ "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda",
+ },
):
return relay_build(mod, target=target, params=params)
diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py
index 3c87f45b8f..a2fbf555e1 100644
--- a/python/tvm/relay/backend/te_compiler.py
+++ b/python/tvm/relay/backend/te_compiler.py
@@ -23,7 +23,8 @@ import logging
import numpy as np
import tvm
from tvm import autotvm, te
-from tvm.ir.transform import PassContext
+from tvm.auto_scheduler import is_auto_scheduler_enabled
+from tvm.meta_schedule import is_meta_schedule_dispatch_enabled
from tvm.runtime import Object
from tvm.support import libinfo
from tvm.target import Target
@@ -180,7 +181,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
# Disable autotvm if auto_scheduler is enabled.
# (i.e., always return the implementation with the highest priority for auto-scheduler).
- if PassContext.current().config.get("relay.backend.use_auto_scheduler", False):
+ if is_auto_scheduler_enabled() or is_meta_schedule_dispatch_enabled():
use_autotvm = False
# If not use autotvm, always return the implementation with the highest priority
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 072b958da2..9c4a896d57 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -252,9 +252,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
)
# register auto-scheduler implementations
- if (
- is_auto_scheduler_enabled() or is_meta_schedule_enabled()
- ) and judge_winograd_auto_scheduler:
+ if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
@@ -545,7 +543,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
)
- if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
+ if is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
naive_schedule, # this implementation should never be picked by autotvm
@@ -823,7 +821,7 @@ def matmul_strategy_cuda(attrs, inputs, out_type, target):
"""Matmul cuda strategy."""
strategy = _op.OpStrategy()
- if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
+ if is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
naive_schedule,
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 23ecb121f4..5e7c9119c9 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -204,7 +204,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record,
LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1)
<< " of file " << path_tuning_record << ". The workload is:\n"
<< (workload.defined() ? tir::AsTVMScript(workload) : "(null)")
- << "\nThe JSONObject of TuningRecrod is:\n"
+ << "\nThe JSONObject of TuningRecord is:\n"
<< json_obj << "\nThe error message is:\n"
<< e.what();
}
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 210f77330a..8ca5a32b7f 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -552,6 +552,7 @@ TECompiler& TECompiler::Global() {
}
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Bool);
TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
return TECompiler::Global();