You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/09 00:37:49 UTC

[tvm] branch main updated: [MetaSchedule] Restore `num_threads` parameter in tuning API (#13561)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b001efcc9 [MetaSchedule] Restore `num_threads` parameter in tuning API  (#13561)
3b001efcc9 is described below

commit 3b001efcc9dd06d6aa69b861738917e86b39874a
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Dec 9 09:37:42 2022 +0900

    [MetaSchedule] Restore `num_threads` parameter in tuning API  (#13561)
    
    * [MetaSchedule] Restore num_threads argument in tune_relay
    
    * pass num_threads to XGBModel
    
    * fix default
    
    * pass num_threads as max_workers to Builder and Runner
    
    * add test
    
    * clean up
    
    * fix kwarg
    
    * rename num_threads -> num_tuning_cores
    
    * fix typo
    
    * rename num_threads -> num_tuning_cores in contrib/torch
    
    * fix typo in documentation
---
 python/tvm/contrib/hexagon/meta_schedule.py        | 21 +++++++----
 python/tvm/contrib/torch/as_torch.py               |  4 +-
 python/tvm/meta_schedule/cost_model/cost_model.py  |  5 +++
 python/tvm/meta_schedule/cost_model/xgb_model.py   |  7 +++-
 python/tvm/meta_schedule/relay_integration.py      | 12 ++++--
 python/tvm/meta_schedule/runner/runner.py          |  2 +
 python/tvm/meta_schedule/tir_integration.py        |  8 ++--
 python/tvm/meta_schedule/tune.py                   | 12 ++++--
 .../metaschedule_e2e/test_resnet50_int8.py         | 43 +++++++++++-----------
 .../contrib/test_hexagon/test_meta_schedule.py     |  7 +++-
 .../test_meta_schedule_relay_integration.py        |  1 +
 11 files changed, 77 insertions(+), 45 deletions(-)

diff --git a/python/tvm/contrib/hexagon/meta_schedule.py b/python/tvm/contrib/hexagon/meta_schedule.py
index dcc7d232d8..6e1541e498 100644
--- a/python/tvm/contrib/hexagon/meta_schedule.py
+++ b/python/tvm/contrib/hexagon/meta_schedule.py
@@ -128,7 +128,9 @@ def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path
     return costs
 
 
-def get_hexagon_local_builder(pass_context: tvm.transform.PassContext = None):
+def get_hexagon_local_builder(
+    pass_context: tvm.transform.PassContext = None, max_workers: Optional[int] = None
+):
     """Return Hexagon-compatible Builder for meta schedule."""
 
     def export_func(mod):
@@ -143,13 +145,19 @@ def get_hexagon_local_builder(pass_context: tvm.transform.PassContext = None):
             return tvm_build(mod, target=target)
 
     if pass_context is not None:
-        return LocalBuilder(f_build=default_build_with_context, f_export=export_func)
+        return LocalBuilder(
+            f_build=default_build_with_context, f_export=export_func, max_workers=max_workers
+        )
     else:
-        return LocalBuilder(f_export=export_func)
+        return LocalBuilder(f_export=export_func, max_workers=max_workers)
 
 
 def get_hexagon_rpc_runner(
-    hexagon_launcher: HexagonLauncherRPC, number=3, repeat=1, min_repeat_ms=100
+    hexagon_launcher: HexagonLauncherRPC,
+    number=3,
+    repeat=1,
+    min_repeat_ms=100,
+    max_workers: Optional[int] = None,
 ):
     """Return Hexagon-compatible RPC Runner for meta schedule.
 
@@ -177,7 +185,4 @@ def get_hexagon_rpc_runner(
         enable_cpu_cache_flush=False,
     )
 
-    return HexagonRPCRunner(
-        hexagon_launcher,
-        evaluator_config,
-    )
+    return HexagonRPCRunner(hexagon_launcher, evaluator_config, max_workers=max_workers)
diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py
index 918ce3ff3b..c4ca88adf7 100644
--- a/python/tvm/contrib/torch/as_torch.py
+++ b/python/tvm/contrib/torch/as_torch.py
@@ -67,7 +67,7 @@ class OperatorModuleWrapper(torch.nn.Module):
         space: ms.SpaceGenerator.SpaceGeneratorType = "post-order-apply",
         strategy: ms.SearchStrategy.SearchStrategyType = "replay-trace",
         task_name: str = "main",
-        num_threads: Union[Literal["physical", "logical"], int] = "physical",
+        num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
         seed: Optional[int] = None,
     ) -> None:
         """
@@ -100,7 +100,7 @@ class OperatorModuleWrapper(torch.nn.Module):
                 space=space,
                 strategy=strategy,
                 task_name=task_name,
-                num_threads=num_threads,
+                num_tuning_cores=num_tuning_cores,
                 seed=seed,
             )
             sch = ms.tir_integration.compile_tir(database, self.ir_module, target)
diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py
index f139fcc4e4..c0f6ea5fb9 100644
--- a/python/tvm/meta_schedule/cost_model/cost_model.py
+++ b/python/tvm/meta_schedule/cost_model/cost_model.py
@@ -126,6 +126,11 @@ class CostModel(Object):
 
         if kind == "xgb":
             return XGBModel(*args, **kwargs)  # type: ignore
+
+        if "num_tuning_cores" in kwargs:
+            # num_tuning_cores is only relevant for XGBModel.
+            kwargs.pop("num_tuning_cores")
+
         if kind == "random":
             return RandomModel(*args, **kwargs)  # type: ignore
         if kind == "mlp":
diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py
index 0a2786c6ab..901e18ce3f 100644
--- a/python/tvm/meta_schedule/cost_model/xgb_model.py
+++ b/python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -333,6 +333,7 @@ class XGBModel(PyCostModel):
         verbose_eval: int = 25,
         average_peak_n: int = 32,
         adaptive_training: bool = True,
+        num_tuning_cores: Optional[int] = None,
     ):
         super().__init__()
         if not isinstance(extractor, FeatureExtractor):
@@ -342,7 +343,11 @@ class XGBModel(PyCostModel):
         # model-related
         if config.nthread is None:
             # use physical core number
-            config = config._replace(nthread=cpu_count(logical=False))
+            if num_tuning_cores is None:
+                config = config._replace(nthread=cpu_count(logical=False))
+            else:
+                config = config._replace(nthread=num_tuning_cores)
+
         self.config = config
         # behavior of randomness
         self.num_warmup_samples = num_warmup_samples
diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
index df76684d2d..0b8705aafe 100644
--- a/python/tvm/meta_schedule/relay_integration.py
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -180,7 +180,7 @@ def extracted_tasks_to_tune_contexts(
     work_dir: str,
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
-    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
     seed: Optional[int] = None,
 ) -> Tuple[List[TuneContext], List[float]]:
     """Convert ExtractedTask to TuneContext.
@@ -195,8 +195,8 @@ def extracted_tasks_to_tune_contexts(
         The space generator to use.
     strategy : SearchStrategy.SearchStrategyType
         The search strategy to use.
-    num_threads : Union[Literal["physical", "logical"], int]
-        The number of threads to use in multi-threaded search algorithm.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
     seed : Optional[int]
         The random seed to use.
 
@@ -223,7 +223,7 @@ def extracted_tasks_to_tune_contexts(
                 task_name=task.task_name,
                 logger=logger,
                 rand_state=rand_state,
-                num_threads=num_threads,
+                num_threads=num_tuning_cores,
             ).clone()
         )
         task_weights.append(task.weight)
@@ -249,6 +249,7 @@ def tune_relay(
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     seed: Optional[int] = None,
     module_equality: str = "structural",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
 ) -> Database:
     """Tune a Relay program.
 
@@ -296,6 +297,8 @@ def tune_relay(
                             given module. The "ignore-ndarray" varint is used for the extracted
                             blocks or in case no anchor block is found.
                             For the definition of the anchor block, see tir/analysis/analysis.py.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
 
     Returns
     -------
@@ -308,6 +311,7 @@ def tune_relay(
         space=space,
         strategy=strategy,
         seed=seed,
+        num_tuning_cores=num_tuning_cores,
     )
     return tune_tasks(
         tasks=tasks,
diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py
index 1753d8b4ab..1a8f78414e 100644
--- a/python/tvm/meta_schedule/runner/runner.py
+++ b/python/tvm/meta_schedule/runner/runner.py
@@ -194,6 +194,8 @@ class Runner(Object):
         from . import LocalRunner, RPCRunner  # pylint: disable=import-outside-toplevel
 
         if kind == "local":
+            if "max_workers" in kwargs:
+                kwargs.pop("max_workers")
             return LocalRunner(*args, **kwargs)  # type: ignore
         elif kind == "rpc":
             return RPCRunner(*args, **kwargs)  # type: ignore
diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py
index 975987ebcb..f3d505c28b 100644
--- a/python/tvm/meta_schedule/tir_integration.py
+++ b/python/tvm/meta_schedule/tir_integration.py
@@ -54,7 +54,7 @@ def tune_tir(
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     task_name: str = "main",
-    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
     seed: Optional[int] = None,
 ) -> Database:
     """Tune a TIR function.
@@ -89,8 +89,8 @@ def tune_tir(
         The search strategy.
     task_name : str
         The name of the task.
-    num_threads : Union[Literal["physical", "logical"], int]
-        The number of threads to use.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
     seed : Optional[int]
         The seed for the random number generator.
 
@@ -111,7 +111,7 @@ def tune_tir(
                 task_name=task_name,
                 logger=logger,
                 rand_state=seed,
-                num_threads=num_threads,
+                num_threads=num_tuning_cores,
             ).clone()
         ],
         task_weights=[1.0],
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index a69c8f1262..0c4035844c 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -86,22 +86,28 @@ def tune_tasks(
     database : Database
         The database with all tuning records
     """
+    if len(tasks) == 0:
+        raise ValueError("No tasks to tune.")
+
     if len(tasks) != len(task_weights):
         raise ValueError(
             f"Length of tasks ({len(tasks)}) and task_weights ({len(task_weights)}) do not match."
         )
+
+    num_cores = tasks[0].num_threads
+
     if max_trials_per_task is None:
         max_trials_per_task = max_trials_global
     if not isinstance(builder, Builder):
-        builder = Builder.create(builder)
+        builder = Builder.create(builder, max_workers=num_cores)
     if not isinstance(runner, Runner):
-        runner = Runner.create(runner)
+        runner = Runner.create(runner, max_workers=num_cores)
     if database == "json":
         database = Database.create(database, work_dir=work_dir, module_equality=module_equality)
     elif not isinstance(database, Database):
         database = Database.create(database, module_equality=module_equality)
     if not isinstance(cost_model, CostModel):
-        cost_model = CostModel.create(cost_model)
+        cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores)
     if isinstance(measure_callbacks, MeasureCallback):
         measure_callbacks = [measure_callbacks]
     elif measure_callbacks == "default":
diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index e15b0a4e7d..1e01cb28a7 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -33,6 +33,7 @@ from tvm.contrib.hexagon.meta_schedule import (
     get_hexagon_rpc_runner,
 )
 from tvm.meta_schedule import postproc, schedule_rule
+from tvm.meta_schedule.utils import cpu_count
 from tvm.tir.schedule import BlockRV, Schedule
 from tvm.tir.schedule.analysis import has_block
 from tvm.tir.tensor_intrin.hexagon import (
@@ -44,10 +45,24 @@ from tvm.tir.tensor_intrin.hexagon import (
 from ..infrastructure import get_hexagon_target
 
 MODEL_JSON = "resnet50_int8.json"
+MODEL_PARAMS = "resnet50_int8.params"
 EXECUTOR = relay.backend.Executor("graph", {"link-params": True})
 TARGET_LLVM = tvm.target.Target("llvm")
 TARGET_HEXAGON = get_hexagon_target("v68")
-MODEL_PARAMS = "resnet50_int8.params"
+
+
+def load_model():
+    """Load renset50 model."""
+    if not os.path.exists(MODEL_JSON):
+        pytest.skip(msg="Run python export_models.py first.")
+
+    with open(MODEL_JSON, "r") as file:
+        mod = tvm.ir.load_json(file.read())
+
+    with open(MODEL_PARAMS, "rb") as file:
+        params = relay.load_param_dict(file.read())
+
+    return mod, params
 
 
 def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
@@ -110,6 +125,8 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
     # task extraction and relay.build(...).
     mod = mod.with_attr("executor", EXECUTOR)
 
+    num_cores = cpu_count(logical=False)
+
     with tempfile.TemporaryDirectory() as work_dir:
         database = ms.relay_integration.tune_relay(
             mod=mod,
@@ -125,8 +142,8 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
             # num_trials_per_iter=32,
             # max_trials_per_task=128,
             # strategy="evolutionary",
-            builder=get_hexagon_local_builder(),
-            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+            builder=get_hexagon_local_builder(max_workers=num_cores),
+            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20, max_workers=num_cores),
             space=ms.space_generator.PostOrderApply(
                 sch_rules=sch_rules,
                 postprocs=postprocs,
@@ -137,6 +154,7 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
             # It reduces the number of conv2d tuning tasks in the int8 resnet50 model
             # from 36 to 23, with negligible performance difference.
             module_equality="anchor-block",
+            num_tuning_cores=num_cores,
         )
         return ms.relay_integration.compile_relay(
             database=database,
@@ -156,11 +174,8 @@ def test_resnet50(hexagon_launcher):
     if not os.path.exists(MODEL_JSON):
         pytest.skip(msg="Run python export_models.py first.")
 
-    with open(MODEL_JSON, "r") as file:
-        mod = tvm.ir.load_json(file.read())
+    mod, params = load_model()
 
-    with open(MODEL_PARAMS, "rb") as file:
-        params = relay.load_param_dict(file.read())
     inp = np.random.randn(1, 3, 224, 224).astype("float32")
     input_name = "image"
 
@@ -231,20 +246,6 @@ def evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, in
         np.testing.assert_allclose(ref_result, output, atol=1e-4, rtol=1e-5)
 
 
-def load_model():
-    """Load renset50 model."""
-    if not os.path.exists(MODEL_JSON):
-        pytest.skip(msg="Run python export_models.py first.")
-
-    with open(MODEL_JSON, "r") as file:
-        mod = tvm.ir.load_json(file.read())
-
-    with open(MODEL_PARAMS, "rb") as file:
-        params = relay.load_param_dict(file.read())
-
-    return mod, params
-
-
 def _schedule_packed_8x8x32_conv2d():
     """Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc,
     using 8x8x32 packed layout.
diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py
index a83a3b279a..1089f0f035 100644
--- a/tests/python/contrib/test_hexagon/test_meta_schedule.py
+++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -73,8 +73,11 @@ def test_builder_runner(hexagon_launcher):
 
     mod = MatmulModule
 
-    builder = get_hexagon_local_builder()
-    runner = get_hexagon_rpc_runner(hexagon_launcher, number=1, repeat=1, min_repeat_ms=0)
+    max_workers = 4
+    builder = get_hexagon_local_builder(max_workers=max_workers)
+    runner = get_hexagon_rpc_runner(
+        hexagon_launcher, number=1, repeat=1, min_repeat_ms=0, max_workers=max_workers
+    )
 
     (builder_result,) = builder.build([BuilderInput(mod, get_hexagon_target("v68"))])
     assert builder_result.artifact_path is not None
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 021db0f86a..062da0b00c 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -742,6 +742,7 @@ def _test_anchor_tuning(target):
             max_trials_global=4,
             strategy="replay-trace",
             module_equality=module_equality,
+            num_tuning_cores=4,
         )
         lib = ms.relay_integration.compile_relay(database, mod, target, params)