You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/15 06:51:58 UTC

[GitHub] [tvm] zxybazh commented on a diff in pull request #10986: [MetaSchedule][Refactor] Introduce TuneConfig

zxybazh commented on code in PR #10986:
URL: https://github.com/apache/tvm/pull/10986#discussion_r851095355


##########
python/tvm/meta_schedule/tune.py:
##########
@@ -449,95 +336,190 @@ def _mutator_probs(
         # pylint: enable=protected-access
         raise ValueError(f"Unsupported target: {target}")
 
-    @staticmethod
-    def _tune_context(
-        tune_context: Optional[TuneContext],
-        mod: IRModule,
-        target: Target,
-        config: SearchStrategyConfig,
-        task_name: str,
-        space_generator: Optional[FnSpaceGenerator],
-        sch_rules: Optional[FnScheduleRule],
-        postprocs: Optional[FnPostproc],
-        mutator_probs: Optional[FnMutatorProb],
-        num_threads: Optional[int],
-    ) -> TuneContext:
-        if tune_context is None:
-            return TuneContext(
-                mod=mod,
-                target=target,
-                # pylint: disable=protected-access
-                space_generator=Parse._space_generator(space_generator),
-                search_strategy=config.create_strategy(),
-                sch_rules=Parse._sch_rules(sch_rules, target),
-                postprocs=Parse._postproc(postprocs, target),
-                mutator_probs=Parse._mutator_probs(mutator_probs, target),
-                # pylint: enable=protected-access
-                task_name=task_name,
-                rand_state=-1,
-                num_threads=num_threads,
-            )
-        if not isinstance(tune_context, TuneContext):
-            raise TypeError(f"Expected `tune_context` to be TuneContext, but gets: {tune_context}")
-        return tune_context
 
-    @staticmethod
-    def _task_scheduler(
-        task_scheduler: Union[None, TaskScheduler, FnTaskScheduler],
-        tasks: List[TuneContext],
-        task_weights: List[float],
-        builder: Builder,
-        runner: Runner,
-        database: Database,
-        max_trials: int,
-        cost_model: CostModel,
-        measure_callbacks: List[MeasureCallback],
-    ):
-        if task_scheduler is None:
-            return GradientBased(
-                tasks=tasks,
-                task_weights=task_weights,
-                builder=builder,
-                runner=runner,
-                database=database,
-                max_trials=max_trials,
-                cost_model=cost_model,
-                measure_callbacks=measure_callbacks,
+class TuneConfig(NamedTuple):
+    """Configuration for tuning
+
+    Parameters
+    ----------
+    max_trials_global: int
+        Maximum number of trials to run.
+    num_trials_per_iter: int
+        Number of trials to run per iteration.
+    max_trials_per_task: Optional[int]
+        Maximum number of trials to run per task. Defaults to max_trials_global if None.
+    task_scheduler: str
+        Task scheduler to use.
+        Valid options are: round_robin, gradient.
+    search_strategy: str
+        Search strategy to use.
+        Valid options are: evolutionary, replay_func, replay_trace.
+    task_scheduler_config: Dict[str, Any]
+        Configuration for task scheduler.
+    search_strategy_config: Dict[str, Any]
+        Configuration for search strategy.
+    """
+
+    max_trials_global: int
+    num_trials_per_iter: int
+    max_trials_per_task: Optional[int] = None
+    task_scheduler: str = "gradient"
+    strategy: str = "evolutionary"
+    task_scheduler_config: Dict[str, Any] = {}
+    search_strategy_config: Dict[str, Any] = {}
+
+    def create_strategy(self, **kwargs):
+        """Create search strategy from configuration"""
+        cls_tbl = {
+            "evolutionary": EvolutionarySearch,
+            "replay_func": ReplayFunc,
+            "replay_trace": ReplayTrace,
+        }
+        if self.strategy not in cls_tbl:
+            raise ValueError(
+                f"Invalid search strategy: {self.strategy}. "
+                "Valid options are: {}".format(", ".join(cls_tbl.keys()))
             )
-        if callable(task_scheduler):
-            return task_scheduler(
-                tasks,
-                task_weights,
-                builder,
-                runner,
-                database,
-                cost_model,
-                measure_callbacks,
+        max_trials_per_task = self.max_trials_per_task
+        if max_trials_per_task is None:
+            max_trials_per_task = self.max_trials_global
+        return cls_tbl[self.strategy](
+            num_trials_per_iter=self.num_trials_per_iter,
+            max_trials_per_task=max_trials_per_task,
+            **kwargs,
+            **self.search_strategy_config,
+        )
+
+    def create_task_scheduler(self, **kwargs):
+        """Create task scheduler from configuration"""
+        cls_tbl = {
+            "round_robin": RoundRobin,
+            "gradient": GradientBased,
+        }
+        if self.task_scheduler not in cls_tbl:
+            raise ValueError(
+                f"Invalid task scheduler: {self.task_scheduler}. "
+                "Valid options are: {}".format(", ".join(cls_tbl.keys()))
             )
-        if not isinstance(task_scheduler, TaskScheduler):
-            raise TypeError(
-                f"Expected `task_scheduler` to be TaskScheduler, but gets: {task_scheduler}"
+        return cls_tbl[self.task_scheduler](
+            max_trials=self.max_trials_global,
+            **kwargs,
+            **self.task_scheduler_config,
+        )
+
+
+def tune_extracted_tasks(
+    extracted_tasks: List[ExtractedTask],
+    config: TuneConfig,
+    work_dir: str,
+    *,
+    builder: Optional[Builder] = None,
+    runner: Optional[Runner] = None,
+    database: Optional[Database] = None,
+    cost_model: Optional[CostModel] = None,
+    measure_callbacks: Optional[List[MeasureCallback]] = None,
+    space: Optional[FnSpaceGenerator] = None,
+    sch_rules: Optional[FnScheduleRule] = None,
+    postprocs: Optional[FnPostproc] = None,
+    mutator_probs: Optional[FnMutatorProb] = None,
+    num_threads: Optional[int] = None,
+) -> Database:
+    """Tune a list of extracted tasks, each on its own target.
+
+    Parameters
+    ----------
+    extracted_tasks : List[ExtractedTask]
+        The list of extracted tasks.
+    config : TuneConfig
+        The tuning configuration (task scheduler and search strategy).
+    work_dir : str
+        The working directory to save intermediate results.
+    builder : Optional[Builder]
+        The builder to use.
+    runner : Optional[Runner]
+        The runner to use.
+    database : Optional[Database]
+        The database to use.
+    cost_model : Optional[CostModel]
+        The cost model to use.
+    measure_callbacks : Optional[List[MeasureCallback]]
+        The callbacks used during tuning.
+    task_scheduler : Optional[TaskScheduler]
+        The task scheduler to use.
+    space : Optional[FnSpaceGenerator]
+        The space generator to use.
+    sch_rules : Optional[FnScheduleRule]
+        The search rules to use.
+    postprocs : Optional[FnPostproc]
+        The postprocessors to use.
+    mutator_probs : Optional[FnMutatorProb]
+        The probability distribution to use different mutators.
+    num_threads : Optional[int]
+        The number of threads to use.
+
+    Returns
+    -------
+    database : Database
+        The database containing all the tuning results.
+
+    """
+    logger.info("Working directory: %s", work_dir)
+    # pylint: disable=protected-access
+    database = Parse._database(database, work_dir)
+    builder = Parse._builder(builder)
+    runner = Parse._runner(runner)
+    cost_model = Parse._cost_model(cost_model)
+    measure_callbacks = Parse._callbacks(measure_callbacks)
+    # parse the tuning contexts
+    tune_contexts = []
+    for task in extracted_tasks:
+        assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
+        tune_contexts.append(
+            TuneContext(
+                mod=Parse._mod(task.dispatched[0]),
+                target=task.target,
+                space_generator=Parse._space_generator(space),
+                search_strategy=config.create_strategy(),
+                sch_rules=Parse._sch_rules(sch_rules, task.target),
+                postprocs=Parse._postproc(postprocs, task.target),
+                mutator_probs=Parse._mutator_probs(mutator_probs, task.target),
+                task_name=task.name,
+                rand_state=-1,  # TODO: random seed

Review Comment:
   Shall we add random seed as part of the user input?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org