Posted to commits@tvm.apache.org by ju...@apache.org on 2021/09/30 23:35:17 UTC

[tvm] branch main updated: [Meta Schedule][M3a] TaskScheduler (#9154)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b4b3d0  [Meta Schedule][M3a] TaskScheduler (#9154)
4b4b3d0 is described below

commit 4b4b3d0e22a7cef4b36bfc124b59c3a2569889c2
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Thu Sep 30 16:35:01 2021 -0700

    [Meta Schedule][M3a] TaskScheduler (#9154)
    
    * Add docs.
    
    * Add TaskScheduler.
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    * Retrigger CI after hotfix.
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
 include/tvm/meta_schedule/runner.h                 |   6 +-
 include/tvm/meta_schedule/task_scheduler.h         | 220 +++++++++++++++++++++
 include/tvm/meta_schedule/tune_context.h           |  15 ++
 python/tvm/meta_schedule/database/__init__.py      |   2 +-
 .../tvm/meta_schedule/search_strategy/__init__.py  |   6 +-
 .../{database => task_scheduler}/__init__.py       |  10 +-
 .../meta_schedule/task_scheduler/round_robin.py    |  64 ++++++
 .../meta_schedule/task_scheduler/task_scheduler.py | 122 ++++++++++++
 python/tvm/meta_schedule/tune_context.py           |  11 ++
 python/tvm/meta_schedule/utils.py                  |  24 ++-
 src/meta_schedule/task_scheduler/round_robin.cc    |  71 +++++++
 src/meta_schedule/task_scheduler/task_scheduler.cc | 219 ++++++++++++++++++++
 src/meta_schedule/tune_context.cc                  |   9 +-
 src/meta_schedule/utils.h                          |   1 +
 .../unittest/test_meta_schedule_task_scheduler.py  | 218 ++++++++++++++++++++
 15 files changed, 987 insertions(+), 11 deletions(-)

diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h
index a45a489..c1451ae 100644
--- a/include/tvm/meta_schedule/runner.h
+++ b/include/tvm/meta_schedule/runner.h
@@ -25,7 +25,7 @@
 namespace tvm {
 namespace meta_schedule {
 
-/*! \brief The runner's input. */
+/*! \brief The runner's input, containing the artifact path, device type and argument info. */
 class RunnerInputNode : public runtime::Object {
  public:
   /*! \brief The path to the built artifact. */
@@ -61,7 +61,7 @@ class RunnerInput : public runtime::ObjectRef {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
 };
 
-/*! \brief The runner's output. */
+/*! \brief The runner's output: the measurement result of a MeasureCandidate, or an error message. */
 class RunnerResultNode : public runtime::Object {
  public:
   /*! \brief The run time in seconds.*/
@@ -96,7 +96,7 @@ class RunnerResult : public runtime::ObjectRef {
 /*!
  * \brief A class to asynchronously fetch runner's output.
  * \note The API design is consistent with python's concurrent.futures.Future:
- * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
+ *  https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
  */
 class RunnerFutureNode : public runtime::Object {
  public:
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
new file mode 100644
index 0000000..a2db24e
--- /dev/null
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
+#define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
+
+#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/database.h>
+#include <tvm/meta_schedule/runner.h>
+#include <tvm/meta_schedule/tune_context.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief The abstract interface of task schedulers.
+ * \note The relationship between TaskScheduler and other classes is as follows:
+      ┌──────────────────────────────────────────────────────────────┐
+   ┌──┴───────────────────────────────────────────────────────────┐  │
+┌──┴────────────────── Tune Context ───────────────────────────┐  │  │
+│                ┌─────────────────────┐                       │  │  │
+│                │                     │   Generate            │  │  │
+│                │   Space Generator   ├──────────────┐        │  │  │
+│                │                     │              │        │  │  │
+│                └─────────────────────┘              ▼        │  │  │
+│                                                Design Space  │  │  │
+│                ┌─────────────────────┐              │        │  │  │
+│      Generate  │                     │   Pretuning  │        │  │  │
+│    ┌───────────┤   Search Strategy   │◄─────────────┘        │  │  │
+│    │           │                     │                       │  ├──┘
+│    │           └─────────────────────┘                       ├──┘
+└────┼─────────────────────────────────────────────────────────┘
+     │
+     │
+┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
+│    │                                 ┌───────────┐                  │
+│    │                      Send to    │           │  Send to         │
+│    ▼                  ┌─────────────►│  Builder  ├──────────┐       │
+│ Measure Candidate     │   Builder    │           │  Runner  │       │
+│    │                  │              └───────────┘          │       │
+│    │     ┌────────────┴────────┐                            │       │
+│    │     │                     │     ┌───────────┐          │       │
+│    └────►│   Task Scheduler    │     │           │          │       │
+│          │                     │     │  Runner   │◄─────────┘       │
+│          └─────────────────────┘     │           │                  │
+│                   ▲                  └─────┬─────┘                  │
+│                   │                        │                        │
+│                   └───  Runner Future ◄────┘                        │
+└─────────────────────────────────────────────────────────────────────┘
+*/
+class TaskSchedulerNode : public runtime::Object {
+ public:
+  /*! \brief The tasks to be tuned. */
+  Array<TuneContext> tasks;
+  /*! \brief The builder of the scheduler. */
+  Builder builder{nullptr};
+  /*! \brief The runner of the scheduler. */
+  Runner runner{nullptr};
+  /*! \brief The database of the scheduler. */
+  Database database{nullptr};
+
+  /*! \brief The default destructor. */
+  virtual ~TaskSchedulerNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("tasks", &tasks);
+    v->Visit("builder", &builder);
+    v->Visit("runner", &runner);
+    v->Visit("database", &database);
+  }
+
+  /*! \brief Auto-tuning. */
+  virtual void Tune();
+
+  /*!
+   * \brief Set specific task to be stopped.
+   * \param task_id The task id to be stopped.
+   */
+  virtual void SetTaskStopped(int task_id);
+
+  /*!
+   * \brief Check whether the task is running.
+   * \param task_id The task id to be checked.
+   * \return Whether the task is running.
+   */
+  virtual bool IsTaskRunning(int task_id);
+
+  /*!
+   * \brief Wait until the task is finished.
+   * \param task_id The task id to be joined.
+   */
+  virtual void JoinRunningTask(int task_id);
+
+  /*!
+   * \brief Fetch the next task id.
+   * \return The next task id.
+   */
+  virtual int NextTaskId() = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
+  TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
+};
+
+/*! \brief The task scheduler with customized methods on the Python side. */
+class PyTaskSchedulerNode : public TaskSchedulerNode {
+ public:
+  /*! \brief The function type of `Tune` method. */
+  using FTune = runtime::TypedPackedFunc<void()>;
+
+  /*!
+   * \brief The function type of `SetTaskStopped` method.
+   * \param task_id The task id to be stopped.
+   */
+  using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;
+
+  /*!
+   * \brief The function type of `IsTaskRunning` method.
+   * \param task_id The task id to be checked.
+   * \return Whether the task is running.
+   */
+  using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;
+
+  /*!
+   * \brief The function type of `JoinRunningTask` method.
+   * \param task_id The task id to be joined.
+   */
+  using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;
+
+  /*!
+   * \brief The function type of `NextTaskId` method.
+   * \return The next task id.
+   */
+  using FNextTaskId = runtime::TypedPackedFunc<int()>;
+
+  /*! \brief The packed function to the `Tune` function. */
+  FTune f_tune;
+  /*! \brief The packed function to the `SetTaskStopped` function. */
+  FSetTaskStopped f_set_task_stopped;
+  /*! \brief The packed function to the `IsTaskRunning` function. */
+  FIsTaskRunning f_is_task_running;
+  /*! \brief The packed function to the `JoinRunningTask` function. */
+  FJoinRunningTask f_join_running_task;
+  /*! \brief The packed function to the `NextTaskId` function. */
+  FNextTaskId f_next_task_id;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_tune` is not visited
+    // `f_set_task_stopped` is not visited
+    // `f_is_task_running` is not visited
+    // `f_join_running_task` is not visited
+    // `f_next_task_id` is not visited
+  }
+
+  void Tune() final {  //
+    f_tune();
+  }
+
+  void SetTaskStopped(int task_id) final {  //
+    f_set_task_stopped(task_id);
+  }
+
+  bool IsTaskRunning(int task_id) final {  //
+    return f_is_task_running(task_id);
+  }
+
+  void JoinRunningTask(int task_id) final {  //
+    f_join_running_task(task_id);
+  }
+
+  int NextTaskId() final {  //
+    return f_next_task_id();
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode);
+};
+
+/*!
+ * \brief Managed reference to TaskSchedulerNode.
+ * \sa TaskSchedulerNode
+ */
+class TaskScheduler : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Create a task scheduler that fetches tasks in a round-robin fashion.
+   * \param tasks The tasks to be tuned.
+   * \param builder The builder of the scheduler.
+   * \param runner The runner of the scheduler.
+   * \param database The database of the scheduler.
+   */
+  TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
+                                          Database database);
+  TVM_DLL static TaskScheduler PyTaskScheduler(
+      PyTaskSchedulerNode::FTune f_tune,                          //
+      PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
+      PyTaskSchedulerNode::FIsTaskRunning f_is_task_running,      //
+      PyTaskSchedulerNode::FJoinRunningTask f_join_running_task,  //
+      PyTaskSchedulerNode::FNextTaskId f_next_task_id);
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
+};
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
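
The diagram above is easiest to read as code. The following is a minimal
pure-Python sketch (hypothetical names, not the TVM API) of how a scheduler
drives each task through its search strategy, the builder, and the runner:

    def tune(tasks, builder, runner):
        """Hypothetical rendering of the scheduler's main loop."""
        running = set(range(len(tasks)))
        while running:
            for task_id in sorted(running):
                # Ask the task's search strategy for measure candidates.
                candidates = tasks[task_id].generate_measure_candidates()
                if candidates is None:
                    running.discard(task_id)  # the task has finished tuning
                    continue
                artifacts = builder.build(candidates)    # "Send to Builder"
                futures = runner.run(artifacts)          # "Send to Runner"
                results = [f.result() for f in futures]  # "Runner Future"
                tasks[task_id].notify_runner_results(results)
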
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 87a3a49..db72328 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -36,6 +36,8 @@ class TuneContextNode : public runtime::Object {
   Optional<Target> target;
   /*! \brief The design space generator. */
   Optional<SpaceGenerator> space_generator;
+  /*! \brief The search strategy. */
+  Optional<SearchStrategy> search_strategy;
   /*! \brief The name of the tuning task. */
   Optional<String> task_name;
   /*! \brief The random state. */
@@ -43,13 +45,24 @@ class TuneContextNode : public runtime::Object {
   /*! \brief The number of threads to be used. */
   int num_threads;
 
+  /*! \brief Whether the tuning task has been stopped or finished. */
+  bool is_stopped;
+  /*! \brief Packed functions to fetch the runner results asynchronously. */
+  Optional<Array<RunnerFuture>> runner_futures;
+  /*! \brief The measure candidates. */
+  Optional<Array<MeasureCandidate>> measure_candidates;
+
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("mod", &mod);
     v->Visit("target", &target);
     v->Visit("space_generator", &space_generator);
+    v->Visit("search_strategy", &search_strategy);
     v->Visit("task_name", &task_name);
     v->Visit("rand_state", &rand_state);
     v->Visit("num_threads", &num_threads);
+    v->Visit("is_stopped", &is_stopped);
+    v->Visit("runner_futures", &runner_futures);
+    v->Visit("measure_candidates", &measure_candidates);
   }
 
   static constexpr const char* _type_key = "meta_schedule.TuneContext";
@@ -67,6 +80,7 @@ class TuneContext : public runtime::ObjectRef {
    * \param mod The workload to be tuned.
    * \param target The target to be tuned for.
    * \param space_generator The design space generator.
+   * \param search_strategy The search strategy.
    * \param task_name The name of the tuning task.
    * \param rand_state The random state.
    * \param num_threads The number of threads to be used.
@@ -74,6 +88,7 @@ class TuneContext : public runtime::ObjectRef {
   TVM_DLL explicit TuneContext(Optional<IRModule> mod,                                    //
                                Optional<Target> target,                                   //
                                Optional<SpaceGenerator> space_generator,                  //
+                               Optional<SearchStrategy> search_strategy,                  //
                                Optional<String> task_name,                                //
                                support::LinearCongruentialEngine::TRandState rand_state,  //
                                int num_threads);
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py
index dcd430d..320647b 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/__init__.py
@@ -18,5 +18,5 @@
 The tvm.meta_schedule.database package.
 The database that stores serialized tuning records and workloads
 """
-from .database import Database, PyDatabase, TuningRecord
+from .database import Database, PyDatabase, TuningRecord, Workload
 from .json_database import JSONDatabase
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py
index 40f21da..609baa2 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/search_strategy/__init__.py
@@ -14,7 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Search Strategy"""
+"""
+The tvm.meta_schedule.search_strategy package.
+Meta Schedule search strategy utilizes the given design spaces
+to generate measure candidates.
+"""
 
 from .search_strategy import SearchStrategy, PySearchStrategy
 from .replay_trace import ReplayTrace
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/task_scheduler/__init__.py
similarity index 73%
copy from python/tvm/meta_schedule/database/__init__.py
copy to python/tvm/meta_schedule/task_scheduler/__init__.py
index dcd430d..dbfe962 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/task_scheduler/__init__.py
@@ -15,8 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-The tvm.meta_schedule.database package.
-The database that stores serialized tuning records and workloads
+The tvm.meta_schedule.task_scheduler package.
+Meta Schedule task scheduler that schedules tasks for measure
+candidate generation and measurement, then saves the records
+to the database.
 """
-from .database import Database, PyDatabase, TuningRecord
-from .json_database import JSONDatabase
+from .task_scheduler import TaskScheduler, PyTaskScheduler
+from .round_robin import RoundRobin
diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py
new file mode 100644
index 0000000..391011b
--- /dev/null
+++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Round Robin Task Scheduler"""
+
+from typing import List, TYPE_CHECKING
+
+from tvm._ffi import register_object
+
+from ..builder import Builder
+from ..runner import Runner
+from ..database import Database
+from .task_scheduler import TaskScheduler
+
+from .. import _ffi_api
+
+if TYPE_CHECKING:
+    from ..tune_context import TuneContext
+
+
+@register_object("meta_schedule.RoundRobin")
+class RoundRobin(TaskScheduler):
+    """Round Robin Task Scheduler"""
+
+    def __init__(
+        self,
+        tasks: List["TuneContext"],
+        builder: Builder,
+        runner: Runner,
+        database: Database,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        tasks : List[TuneContext]
+            List of tasks to schedule.
+        builder : Builder
+            The builder.
+        runner : Runner
+            The runner.
+        database : Database
+            The database.
+        """
+        self.__init_handle_by_constructor__(
+            _ffi_api.TaskSchedulerRoundRobin,  # type: ignore # pylint: disable=no-member
+            tasks,
+            builder,
+            runner,
+            database,
+        )
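
A minimal usage sketch (assuming a TVM build that includes this commit, and
pre-constructed `tasks`, `builder`, `runner` and `database` objects, as in the
unit test added at the end of this commit):

    from tvm.meta_schedule.task_scheduler import RoundRobin

    scheduler = RoundRobin(tasks, builder, runner, database)
    scheduler.tune()  # round-robins over `tasks`, committing records to `database`
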
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
new file mode 100644
index 0000000..b8dcfd9
--- /dev/null
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Auto-tuning Task Scheduler"""
+from tvm._ffi import register_object
+from tvm.runtime import Object
+
+from .. import _ffi_api
+
+
+@register_object("meta_schedule.TaskScheduler")
+class TaskScheduler(Object):
+    """The abstract task scheduler interface."""
+
+    def tune(self) -> None:
+        """Auto-tuning."""
+        _ffi_api.TaskSchedulerTune(self)  # pylint: disable=no-member
+
+    def _set_task_stopped(self, task_id: int) -> None:
+        """Set specific task to be stopped.
+
+        Parameters
+        ----------
+        task_id : int
+            The task id to be stopped.
+        """
+        _ffi_api.TaskSchedulerSetTaskStopped(self, task_id)  # pylint: disable=no-member
+
+    def _is_task_running(self, task_id: int) -> bool:
+        """Check whether the task is running.
+
+        Parameters
+        ----------
+        task_id : int
+            The task id to be checked.
+
+        Returns
+        -------
+        bool
+            Whether the task is running.
+        """
+        return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id)  # pylint: disable=no-member
+
+    def _join_running_task(self, task_id: int) -> None:
+        """Wait until the task is finished.
+
+        Parameters
+        ----------
+        task_id : int
+            The task id to be joined.
+        """
+        _ffi_api.TaskSchedulerJoinRunningTask(self, task_id)  # pylint: disable=no-member
+
+    def _next_task_id(self) -> int:
+        """Fetch the next task id.
+
+        Returns
+        -------
+        int
+            The next task id.
+        """
+        return _ffi_api.TaskSchedulerNextTaskId(self)  # pylint: disable=no-member
+
+
+@register_object("meta_schedule.PyTaskScheduler")
+class PyTaskScheduler(TaskScheduler):
+    """An abstract task scheduler with customized methods on the python-side."""
+
+    def __init__(self):
+        """Constructor."""
+
+        def f_tune() -> None:
+            self.tune()
+
+        def f_set_task_stopped(task_id: int) -> None:
+            self._set_task_stopped(task_id)
+
+        def f_is_task_running(task_id: int) -> bool:
+            return self._is_task_running(task_id)
+
+        def f_join_running_task(task_id: int) -> None:
+            self._join_running_task(task_id)
+
+        def f_next_task_id() -> int:
+            return self._next_task_id()
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.TaskSchedulerPyTaskScheduler,  # pylint: disable=no-member
+            f_tune,
+            f_set_task_stopped,
+            f_is_task_running,
+            f_join_running_task,
+            f_next_task_id,
+        )
+
+    def tune(self) -> None:
+        raise NotImplementedError()
+
+    def _set_task_stopped(self, task_id: int) -> None:
+        _ffi_api.TaskSchedulerSetTaskStopped(self, task_id)  # pylint: disable=no-member
+
+    def _is_task_running(self, task_id: int) -> bool:
+        return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id)  # pylint: disable=no-member
+
+    def _join_running_task(self, task_id: int) -> None:
+        _ffi_api.TaskSchedulerJoinRunningTask(self, task_id)  # pylint: disable=no-member
+
+    def _next_task_id(self) -> int:
+        return _ffi_api.TaskSchedulerNextTaskId(self)  # pylint: disable=no-member
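
A subclass overrides `tune` and, typically, `_next_task_id`. A minimal sketch
(hypothetical class name, assuming a TVM build with this commit) of a scheduler
that treats every task as already finished:

    from tvm.meta_schedule.task_scheduler import PyTaskScheduler

    class NopScheduler(PyTaskScheduler):
        """Considers every task already finished."""

        def tune(self) -> None:
            # Nothing to run: `_next_task_id` below never yields a task.
            pass

        def _next_task_id(self) -> int:
            return -1  # -1 signals that no task is runnable
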
diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py
index 9c41b4d..0f3cfac 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -28,6 +28,7 @@ from . import _ffi_api
 
 if TYPE_CHECKING:
     from .space_generator import SpaceGenerator
+    from .search_strategy import SearchStrategy
 
 
 @register_object("meta_schedule.TuneContext")
@@ -45,6 +46,10 @@ class TuneContext(Object):
         The workload to be optimized.
     target : Optional[Target] = None
         The target to be optimized for.
+    space_generator : Optional[SpaceGenerator] = None
+        The design space generator.
+    search_strategy : Optional[SearchStrategy] = None
+        The search strategy.
     task_name : Optional[str] = None
         The name of the tuning task.
     rand_state : int = -1
@@ -63,6 +68,8 @@ class TuneContext(Object):
 
     mod: Optional[IRModule]
     target: Optional[Target]
+    space_generator: "SpaceGenerator"
+    search_strategy: "SearchStrategy"
     task_name: Optional[str]
     rand_state: int
     num_threads: int
@@ -72,6 +79,7 @@ class TuneContext(Object):
         mod: Optional[IRModule] = None,
         target: Optional[Target] = None,
         space_generator: Optional["SpaceGenerator"] = None,
+        search_strategy: Optional["SearchStrategy"] = None,
         task_name: Optional[str] = None,
         rand_state: int = -1,
         num_threads: Optional[int] = None,
@@ -86,6 +94,8 @@ class TuneContext(Object):
             The target to be optimized for.
         space_generator : Optional[SpaceGenerator] = None
             The design space generator.
+        search_strategy : Optional[SearchStrategy] = None
+            The search strategy.
         task_name : Optional[str] = None
             The name of the tuning task.
         rand_state : int = -1
@@ -102,6 +112,7 @@ class TuneContext(Object):
             mod,
             target,
             space_generator,
+            search_strategy,
             task_name,
             rand_state,
             num_threads,
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index 5f53699..bf2ef17 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -21,9 +21,10 @@ import shutil
 from typing import Any, Callable, List, Optional, Union
 
 import psutil
+import tvm
 from tvm._ffi import get_global_func, register_func
 from tvm.error import TVMError
-from tvm.ir import Array, Map
+from tvm.ir import Array, Map, IRModule
 from tvm.rpc import RPCSession
 from tvm.runtime import PackedFunc, String
 from tvm.tir import FloatImm, IntImm
@@ -183,3 +184,24 @@ def batch_json_str2obj(json_strs: List[str]) -> List[Any]:
         for json_str in map(str.strip, json_strs)
         if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//"))
     ]
+
+
+def structural_hash(mod: IRModule) -> str:
+    """Get the structural hash of a module.
+
+    Parameters
+    ----------
+    mod : IRModule
+        The module to be hashed.
+
+    Returns
+    -------
+    result : str
+        The structural hash of the module.
+    """
+    shash = tvm.ir.structural_hash(mod)
+    if shash < 0:
+        # Workaround because `structural_hash` returns a size_t, i.e., unsigned integer
+        # but ffi can't handle unsigned integers properly so it's parsed into a negative number
+        shash += 1 << 64
+    return str(shash)
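
The wrap-around is a plain two's-complement correction. For instance, if the
FFI misreads the unsigned hash 2**64 - 1 as the signed value -1:

    shash = -1                 # size_t that came through the FFI as a signed int
    shash += 1 << 64
    assert shash == 2**64 - 1  # the original unsigned hash is recovered
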
diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc
new file mode 100644
index 0000000..a529f23
--- /dev/null
+++ b/src/meta_schedule/task_scheduler/round_robin.cc
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief The round-robin style task scheduler. */
+class RoundRobinNode final : public TaskSchedulerNode {
+ public:
+  /*! \brief The current task id processed. */
+  int task_id = -1;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TaskSchedulerNode::VisitAttrs(v);
+    v->Visit("task_id", &task_id);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.RoundRobin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RoundRobinNode, TaskSchedulerNode);
+
+ protected:
+  int NextTaskId() final {
+    int n_tasks = this->tasks.size();
+    for (int i = 0; i < n_tasks; ++i) {
+      task_id = (task_id + 1) % n_tasks;
+      TuneContext task = tasks[task_id];
+      if (!task->is_stopped) {
+        if (IsTaskRunning(task_id)) {
+          JoinRunningTask(task_id);
+        }
+        return task_id;
+      }
+    }
+    return -1;
+  }
+};
+
+TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
+                                        Database database) {
+  ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
+  n->tasks = tasks;
+  n->builder = builder;
+  n->runner = runner;
+  n->database = database;
+  n->task_id = -1;
+  return TaskScheduler(n);
+}
+
+TVM_REGISTER_NODE_TYPE(RoundRobinNode);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin")
+    .set_body_typed(TaskScheduler::RoundRobin);
+
+}  // namespace meta_schedule
+}  // namespace tvm
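
The rotation in `NextTaskId` is easiest to see in isolation. A pure-Python
rendering (with a list of `is_stopped` flags standing in for the TuneContext
state, and omitting the join of a still-running task):

    def next_task_id(task_id, is_stopped):
        n = len(is_stopped)
        for _ in range(n):
            task_id = (task_id + 1) % n
            if not is_stopped[task_id]:
                return task_id
        return -1  # every task has stopped

    assert next_task_id(-1, [False, True, False]) == 0
    assert next_task_id(0, [False, True, False]) == 2  # skips stopped task 1
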
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
new file mode 100644
index 0000000..cf0af3d
--- /dev/null
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Send the measure candidates to builder.
+ * \param builder The builder to send the candidates to.
+ * \param context The tuning context.
+ * \param candidates The measure candidates.
+ * \return An array of the builder results.
+ */
+Array<BuilderResult> SendToBuilder(const Builder& builder,  //
+                                   const TuneContext& context,
+                                   const Array<MeasureCandidate>& candidates) {
+  Target target = context->target.value();
+  Array<BuilderInput> inputs;
+  inputs.reserve(candidates.size());
+  for (const MeasureCandidate& candidate : candidates) {
+    inputs.push_back(BuilderInput(candidate->sch->mod(), target));
+  }
+  return builder->Build(inputs);
+}
+
+/*!
+ * \brief Send the built measure candidates to runner.
+ * \param runner The runner to send the candidates to.
+ * \param context The tuning context.
+ * \param candidates The measure candidates.
+ * \param builder_results The builder results.
+ * \return An array of the runner results.
+ */
+Array<RunnerFuture> SendToRunner(const Runner& runner,  //
+                                 const TuneContext& context,
+                                 const Array<MeasureCandidate>& candidates,
+                                 const Array<BuilderResult>& builder_results) {
+  Target target = context->target.value();
+  ICHECK_EQ(candidates.size(), builder_results.size());
+  int n = candidates.size();
+  int n_build_errors = 0;
+  Array<RunnerInput> inputs;
+  inputs.reserve(n);
+  for (int i = 0; i < n; ++i) {
+    const MeasureCandidate& candidate = candidates[i];
+    const BuilderResult& builder_result = builder_results[i];
+    if (builder_result->error_msg.defined()) {
+      ++n_build_errors;
+      continue;
+    }
+    inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(),
+                                 /*device_type=*/target->kind->name,
+                                 /*args_info=*/candidate->args_info));
+  }
+  Array<RunnerFuture> futures = runner->Run(inputs);
+  if (n_build_errors == 0) {
+    return futures;
+  }
+  Array<RunnerFuture> results;
+  results.reserve(n);
+  for (int i = 0, j = 0; i < n; ++i) {
+    const BuilderResult& builder_result = builder_results[i];
+    if (builder_result->error_msg.defined()) {
+      results.push_back(RunnerFuture(
+          /*f_done=*/[]() -> bool { return true; },
+          /*f_result=*/
+          [msg = builder_result->error_msg]() -> RunnerResult {
+            return RunnerResult(NullOpt, msg);
+          }));
+    } else {
+      results.push_back(futures[j++]);
+    }
+  }
+  return results;
+}
+
+void TaskSchedulerNode::Tune() {
+  for (const TuneContext& task : this->tasks) {
+    CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
+    CHECK(task->space_generator.defined())
+        << "ValueError: Require `context.space_generator`, but it is not defined";
+    CHECK(task->search_strategy.defined())
+        << "ValueError: Require `context.search_strategy`, but it is not defined";
+    IRModule mod = task->mod.value();
+    SpaceGenerator space = task->space_generator.value();
+    SearchStrategy strategy = task->search_strategy.value();
+    space->InitializeWithTuneContext(task);
+    strategy->InitializeWithTuneContext(task);
+    strategy->PreTuning(space->GenerateDesignSpace(mod));
+  }
+
+  int running_tasks = tasks.size();
+  while (running_tasks > 0) {
+    for (int task_id; (task_id = NextTaskId()) != -1;) {
+      TuneContext task = tasks[task_id];
+      ICHECK(!task->is_stopped);
+      ICHECK(!task->runner_futures.defined());
+      SearchStrategy strategy = task->search_strategy.value();
+      if (task->measure_candidates = strategy->GenerateMeasureCandidates()) {
+        Array<BuilderResult> builder_results =
+            SendToBuilder(this->builder, task, task->measure_candidates.value());
+        task->runner_futures =
+            SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results);
+      } else {
+        SetTaskStopped(task_id);
+        --running_tasks;
+      }
+    }
+    int n_tasks = this->tasks.size();
+    for (int task_id = 0; task_id < n_tasks; ++task_id)
+      if (IsTaskRunning(task_id)) {
+        TuneContext task = tasks[task_id];
+        this->JoinRunningTask(task_id);
+        task->search_strategy.value()->PostTuning();
+      }
+  }
+}
+
+void TaskSchedulerNode::SetTaskStopped(int task_id) {
+  TuneContext task = tasks[task_id];
+  ICHECK(!task->is_stopped);
+  task->is_stopped = true;
+}
+
+bool TaskSchedulerNode::IsTaskRunning(int task_id) {
+  TuneContext task = tasks[task_id];
+  if (task->is_stopped || !task->runner_futures.defined()) {
+    return false;
+  }
+  for (const RunnerFuture future : task->runner_futures.value()) {
+    if (!future->Done()) {
+      return true;
+    }
+  }
+  this->JoinRunningTask(task_id);
+  return false;
+}
+
+void TaskSchedulerNode::JoinRunningTask(int task_id) {
+  TuneContext task = tasks[task_id];
+  ICHECK(task->runner_futures.defined());
+  Array<RunnerFuture> futures = task->runner_futures.value();
+  int n = futures.size();
+  Array<RunnerResult> results;
+  results.reserve(n);
+  for (const RunnerFuture future : task->runner_futures.value()) {
+    results.push_back(future->Result());
+  }
+  task->search_strategy.value()->NotifyRunnerResults(results);
+  task->runner_futures = NullOpt;
+  // Add to database
+  ICHECK(task->measure_candidates.defined());
+  ICHECK(results.size() == task->measure_candidates.value().size());
+  int index = 0;
+  for (const RunnerResult& result : results) {
+    if (!result->error_msg.defined() && result->run_secs.defined()) {
+      Optional<tir::Trace> trace = task->measure_candidates.value()[index]->sch->trace();
+      ICHECK(trace.defined());
+      this->database->CommitTuningRecord(TuningRecord(
+          /*trace=*/trace.value(),
+          /*run_secs=*/result->run_secs.value(),
+          /*workload=*/this->database->CommitWorkload(task->mod.value()),
+          /*target=*/task->target.value(),
+          /*args_info=*/task->measure_candidates.value()[index]->args_info));
+    }
+    index++;
+  }
+}
+
+TaskScheduler TaskScheduler::PyTaskScheduler(
+    PyTaskSchedulerNode::FTune f_tune,                          //
+    PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
+    PyTaskSchedulerNode::FIsTaskRunning f_is_task_running,      //
+    PyTaskSchedulerNode::FJoinRunningTask f_join_running_task,  //
+    PyTaskSchedulerNode::FNextTaskId f_next_task_id) {
+  ObjectPtr<PyTaskSchedulerNode> n = make_object<PyTaskSchedulerNode>();
+  n->f_tune = f_tune;
+  n->f_set_task_stopped = f_set_task_stopped;
+  n->f_is_task_running = f_is_task_running;
+  n->f_join_running_task = f_join_running_task;
+  n->f_next_task_id = f_next_task_id;
+  return TaskScheduler(n);
+}
+
+TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode);
+TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode);
+TVM_REGISTER_GLOBAL("tvm.task.TaskSchedulerPyTaskScheduler")
+    .set_body_typed(TaskScheduler::PyTaskScheduler);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerSetTaskStopped")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::SetTaskStopped);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerIsTaskRunning")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::IsTaskRunning);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::Tune);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::JoinRunningTask);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::NextTaskId);
+
+}  // namespace meta_schedule
+}  // namespace tvm
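
`SendToRunner` keeps its return value aligned index-by-index with the candidate
array: a candidate that failed to build receives an immediately-done future
carrying the build error, while surviving candidates receive the real runner
futures in order. A pure-Python rendering of that merge (hypothetical names,
errors as plain strings):

    def merge(build_errors, run_futures):
        """Restore candidate order: one result per candidate."""
        results, j = [], 0
        for err in build_errors:                 # one entry per candidate
            if err is not None:
                results.append(("error", err))   # already-done failed future
            else:
                results.append(run_futures[j])   # real runner future
                j += 1
        return results

    assert merge([None, "compile failed", None], ["f0", "f1"]) == [
        "f0", ("error", "compile failed"), "f1"]
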
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index ad82b6f..9fc9272 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -37,6 +37,7 @@ namespace meta_schedule {
 TuneContext::TuneContext(Optional<IRModule> mod,                                    //
                          Optional<Target> target,                                   //
                          Optional<SpaceGenerator> space_generator,                  //
+                         Optional<SearchStrategy> search_strategy,                  //
                          Optional<String> task_name,                                //
                          support::LinearCongruentialEngine::TRandState rand_state,  //
                          int num_threads) {
@@ -44,12 +45,16 @@ TuneContext::TuneContext(Optional<IRModule> mod,
   n->mod = mod;
   n->target = target;
   n->space_generator = space_generator;
+  n->search_strategy = search_strategy;
   n->task_name = task_name;
   if (rand_state == -1) {
     rand_state = std::random_device()();
   }
   support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state);
   n->num_threads = num_threads;
+  n->is_stopped = false;
+  n->runner_futures = NullOpt;
+  n->measure_candidates = NullOpt;
   data_ = std::move(n);
 }
 
@@ -59,10 +64,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
     .set_body_typed([](Optional<IRModule> mod,                                    //
                        Optional<Target> target,                                   //
                        Optional<SpaceGenerator> space_generator,                  //
+                       Optional<SearchStrategy> search_strategy,                  //
                        Optional<String> task_name,                                //
                        support::LinearCongruentialEngine::TRandState rand_state,  //
                        int num_threads) -> TuneContext {
-      return TuneContext(mod, target, space_generator, task_name, rand_state, num_threads);
+      return TuneContext(mod, target, space_generator, search_strategy, task_name, rand_state,
+                         num_threads);
     });
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 30294b8..83e65a5 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -26,6 +26,7 @@
 #include <tvm/meta_schedule/runner.h>
 #include <tvm/meta_schedule/search_strategy.h>
 #include <tvm/meta_schedule/space_generator.h>
+#include <tvm/meta_schedule/task_scheduler.h>
 #include <tvm/meta_schedule/tune_context.h>
 #include <tvm/node/node.h>
 #include <tvm/node/serialization.h>
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
new file mode 100644
index 0000000..bdd504c
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test Meta Schedule Task Scheduler """
+
+from typing import List
+
+import sys
+import random
+
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.ir import IRModule
+from tvm.tir import Schedule
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.space_generator import ScheduleFn
+from tvm.meta_schedule.search_strategy import ReplayTrace
+from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
+from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
+from tvm.meta_schedule.task_scheduler import RoundRobin
+from tvm.meta_schedule.utils import structural_hash
+
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
+
+
+@tvm.script.tir
+class MatmulModule:
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, (1024, 1024), "float32")
+        B = tir.match_buffer(b, (1024, 1024), "float32")
+        C = tir.match_buffer(c, (1024, 1024), "float32")
+        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
+            with tir.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.script.tir
+class MatmulReluModule:
+    def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, (1024, 1024), "float32")
+        B = tir.match_buffer(b, (1024, 1024), "float32")
+        D = tir.match_buffer(d, (1024, 1024), "float32")
+        C = tir.alloc_buffer((1024, 1024), "float32")
+        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
+            with tir.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        with tir.block([1024, 1024], "relu") as [vi, vj]:
+            D[vi, vj] = tir.max(C[vi, vj], 0.0)
+
+
+@tvm.script.tir
+class BatchMatmulModule:
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, [16, 128, 128])
+        B = tir.match_buffer(b, [16, 128, 128])
+        C = tir.match_buffer(c, [16, 128, 128])
+        with tir.block([16, 128, 128, tir.reduce_axis(0, 128)], "matmul") as [vn, vi, vj, vk]:
+            with tir.init():
+                C[vn, vi, vj] = 0.0
+            C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
+
+
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks
+
+
+def _schedule_matmul(sch: Schedule):
+    block = sch.get_block("matmul")
+    i, j, k = sch.get_loops(block=block)
+    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
+    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
+    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
+    k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
+
+
+def _schedule_batch_matmul(sch: Schedule):
+    block = sch.get_block("matmul")
+    i, j, k, t = sch.get_loops(block=block)
+    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
+    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 2, 2, 2])
+    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[2, 4, 64, 2])
+    k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+    t_0, t_1 = sch.split(loop=t, factors=[2, 512])
+    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, t_0, t_1)
+
+
+class DummyRunnerFuture(RunnerFuture):
+    def done(self) -> bool:
+        return True
+
+    def result(self) -> RunnerResult:
+        return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None)
+
+
+class DummyBuilder(PyBuilder):
+    def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
+        return [BuilderResult("test_path", None) for _ in build_inputs]
+
+
+class DummyRunner(PyRunner):
+    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+        return [DummyRunnerFuture() for _ in runner_inputs]
+
+
+class DummyDatabase(PyDatabase):
+    def __init__(self):
+        super().__init__()
+        self.records = []
+        self.workload_reg = []
+
+    def commit_tuning_record(self, record: TuningRecord) -> None:
+        self.records.append(record)
+
+    def commit_workload(self, mod: IRModule) -> Workload:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return workload
+        workload = Workload(mod)
+        self.workload_reg.append(workload)
+        return workload
+
+    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+        return list(
+            filter(
+                lambda x: x.workload == workload,
+                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+            )
+        )[: int(top_k)]
+
+    def __len__(self) -> int:
+        return len(self.records)
+
+    def print_results(self) -> None:
+        print("\n".join([str(r) for r in self.records]))
+
+
+def test_meta_schedule_task_scheduler_single():
+    num_trials_per_iter = 3
+    num_trials_total = 10
+    sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
+    replay = ReplayTrace(num_trials_per_iter, num_trials_total)
+    task = TuneContext(
+        MatmulModule(),
+        target=tvm.target.Target("llvm"),
+        space_generator=sch_fn,
+        search_strategy=replay,
+        task_name="Test",
+        rand_state=42,
+    )
+    database = DummyDatabase()
+    round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database)
+    round_robin.tune()
+    assert len(database) == num_trials_total
+
+
+def test_meta_schedule_task_scheduler_multiple():
+    num_trials_per_iter = 6
+    num_trials_total = 101
+    tasks = [
+        TuneContext(
+            MatmulModule(),
+            target=tvm.target.Target("llvm"),
+            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+            task_name="Matmul",
+            rand_state=42,
+        ),
+        TuneContext(
+            MatmulReluModule(),
+            target=tvm.target.Target("llvm"),
+            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+            task_name="MatmulRelu",
+            rand_state=0xDEADBEEF,
+        ),
+        TuneContext(
+            BatchMatmulModule(),
+            target=tvm.target.Target("llvm"),
+            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
+            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+            task_name="BatchMatmul",
+            rand_state=0x114514,
+        ),
+    ]
+    database = DummyDatabase()
+    round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database)
+    round_robin.tune()
+    assert len(database) == num_trials_total * len(tasks)
+    print(database.workload_reg)
+    for task in tasks:
+        assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))