You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/09/30 23:35:17 UTC
[tvm] branch main updated: [Meta Schedule][M3a] TaskScheduler
(#9154)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b4b3d0 [Meta Schedule][M3a] TaskScheduler (#9154)
4b4b3d0 is described below
commit 4b4b3d0e22a7cef4b36bfc124b59c3a2569889c2
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Thu Sep 30 16:35:01 2021 -0700
[Meta Schedule][M3a] TaskScheduler (#9154)
* Add docs.
* Add TaskScheduler.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Retrigger CI after hotfix.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
include/tvm/meta_schedule/runner.h | 6 +-
include/tvm/meta_schedule/task_scheduler.h | 220 +++++++++++++++++++++
include/tvm/meta_schedule/tune_context.h | 15 ++
python/tvm/meta_schedule/database/__init__.py | 2 +-
.../tvm/meta_schedule/search_strategy/__init__.py | 6 +-
.../{database => task_scheduler}/__init__.py | 10 +-
.../meta_schedule/task_scheduler/round_robin.py | 64 ++++++
.../meta_schedule/task_scheduler/task_scheduler.py | 122 ++++++++++++
python/tvm/meta_schedule/tune_context.py | 11 ++
python/tvm/meta_schedule/utils.py | 24 ++-
src/meta_schedule/task_scheduler/round_robin.cc | 71 +++++++
src/meta_schedule/task_scheduler/task_scheduler.cc | 219 ++++++++++++++++++++
src/meta_schedule/tune_context.cc | 9 +-
src/meta_schedule/utils.h | 1 +
.../unittest/test_meta_schedule_task_scheduler.py | 218 ++++++++++++++++++++
15 files changed, 987 insertions(+), 11 deletions(-)
diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h
index a45a489..c1451ae 100644
--- a/include/tvm/meta_schedule/runner.h
+++ b/include/tvm/meta_schedule/runner.h
@@ -25,7 +25,7 @@
namespace tvm {
namespace meta_schedule {
-/*! \brief The runner's input. */
+/*! \brief Runner's input containing path of artifact, type of device and argument info. */
class RunnerInputNode : public runtime::Object {
public:
/*! \brief The path to the built artifact. */
@@ -61,7 +61,7 @@ class RunnerInput : public runtime::ObjectRef {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
};
-/*! \brief The runner's output. */
+/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
class RunnerResultNode : public runtime::Object {
public:
/*! \brief The run time in seconds.*/
@@ -96,7 +96,7 @@ class RunnerResult : public runtime::ObjectRef {
/*!
* \brief A class to asynchronously fetch runner's output.
* \note The API design is consistent with python's concurrent.futures.Future:
- * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
+ * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
*/
class RunnerFutureNode : public runtime::Object {
public:
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
new file mode 100644
index 0000000..a2db24e
--- /dev/null
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
+#define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
+
+#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/database.h>
+#include <tvm/meta_schedule/runner.h>
+#include <tvm/meta_schedule/tune_context.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief The abstract interface of task schedulers.
+ * \note The relationship between TaskScheduler and other classes is as follows:
+ ┌──────────────────────────────────────────────────────────────┐
+ ┌──┴───────────────────────────────────────────────────────────┐ │
+┌──┴────────────────── Tune Context ───────────────────────────┐ │ │
+│ ┌─────────────────────┐ │ │ │
+│ │ │ Generate │ │ │
+│ │ Space Generator ├──────────────┐ │ │ │
+│ │ │ │ │ │ │
+│ └─────────────────────┘ ▼ │ │ │
+│ Design Space │ │ │
+│ ┌─────────────────────┐ │ │ │ │
+│ Generate │ │ Pretuning │ │ │ │
+│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │
+│ │ │ │ │ ├──┘
+│ │ └─────────────────────┘ ├──┘
+└────┼─────────────────────────────────────────────────────────┘
+ │
+ │
+┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
+│ │ ┌───────────┐ │
+│ │ Send to │ │ Send to │
+│ ▼ ┌─────────────►│ Builder ├──────────┐ │
+│ Measure Candidate │ Builder │ │ Runner │ │
+│ │ │ └───────────┘ │ │
+│ │ ┌────────────┴────────┐ │ │
+│ │ │ │ ┌───────────┐ │ │
+│ └────►│ Task Scheduler │ │ │ │ │
+│ │ │ │ Runner │◄─────────┘ │
+│ └─────────────────────┘ │ │ │
+│ ▲ └─────┬─────┘ │
+│ │ │ │
+│ └─── Runner Future ◄────┘ │
+└─────────────────────────────────────────────────────────────────────┘
+*/
+class TaskSchedulerNode : public runtime::Object {
+ public:
+ /*! \brief The tasks to be tuned. */
+ Array<TuneContext> tasks;
+ /*! \brief The builder used to compile the measure candidates of every task. */
+ Builder builder{nullptr};
+ /*! \brief The runner used to measure the built candidates. */
+ Runner runner{nullptr};
+ /*! \brief The database into which successful tuning records are committed. */
+ Database database{nullptr};
+
+ /*! \brief The default destructor. */
+ virtual ~TaskSchedulerNode() = default;
+
+ /*! \brief Expose the fields above to the reflection attribute visitor. */
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("tasks", &tasks);
+ v->Visit("builder", &builder);
+ v->Visit("runner", &runner);
+ v->Visit("database", &database);
+ }
+
+ /*!
+ * \brief Auto-tune all tasks in `tasks`.
+ * \note The default implementation initializes the space generator and search strategy of
+ * every task, then repeatedly asks NextTaskId for a task, builds its measure candidates with
+ * `builder` and dispatches them to `runner`.
+ */
+ virtual void Tune();
+
+ /*!
+ * \brief Set specific task to be stopped.
+ * \param task_id The task id to be stopped.
+ * \note The default implementation sets `is_stopped` on the task; stopping a task that is
+ * already stopped fails an ICHECK.
+ */
+ virtual void SetTaskStopped(int task_id);
+
+ /*!
+ * \brief Check whether the task is running.
+ * \param task_id The task id to be checked.
+ * \return Whether the task is running.
+ * \note The default implementation returns false for stopped tasks and for tasks without
+ * pending runner futures. If all runner futures of the task are already done, it joins the
+ * task (via JoinRunningTask) before returning false.
+ */
+ virtual bool IsTaskRunning(int task_id);
+
+ /*!
+ * \brief Wait until the task is finished.
+ * \param task_id The task id to be joined.
+ * \note The default implementation blocks on every runner future of the task, notifies the
+ * search strategy of the results, clears the pending futures and commits successful
+ * measurements to `database`.
+ */
+ virtual void JoinRunningTask(int task_id);
+
+ /*!
+ * \brief Fetch the next task id.
+ * \return The next task id, or -1 when there is no task to schedule (the default Tune stops
+ * fetching tasks once -1 is returned).
+ */
+ virtual int NextTaskId() = 0;
+
+ static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
+ TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
+};
+
+/*!
+ * \brief The task scheduler with customized methods on the python-side.
+ * \note Every override below forwards unconditionally to the corresponding packed function;
+ * there is no null check and no fallback to the TaskSchedulerNode implementation.
+ */
+class PyTaskSchedulerNode : public TaskSchedulerNode {
+ public:
+ /*! \brief The function type of `Tune` method. */
+ using FTune = runtime::TypedPackedFunc<void()>;
+
+ /*!
+ * \brief The function type of `SetTaskStopped` method.
+ * \param task_id The task id to be stopped.
+ */
+ using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;
+
+ /*!
+ * \brief The function type of `IsTaskRunning` method.
+ * \param task_id The task id to be checked.
+ * \return Whether the task is running.
+ */
+ using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;
+
+ /*!
+ * \brief The function type of `JoinRunningTask` method.
+ * \param task_id The task id to be joined.
+ */
+ using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;
+
+ /*!
+ * \brief The function type of `NextTaskId` method.
+ * \return The next task id.
+ */
+ using FNextTaskId = runtime::TypedPackedFunc<int()>;
+
+ /*! \brief The packed function to the `Tune` function. */
+ FTune f_tune;
+ /*! \brief The packed function to the `SetTaskStopped` function. */
+ FSetTaskStopped f_set_task_stopped;
+ /*! \brief The packed function to the `IsTaskRunning` function. */
+ FIsTaskRunning f_is_task_running;
+ /*! \brief The packed function to the `JoinRunningTask` function. */
+ FJoinRunningTask f_join_running_task;
+ /*! \brief The packed function to the `NextTaskId` function. */
+ FNextTaskId f_next_task_id;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_tune` is not visited
+ // `f_set_task_stopped` is not visited
+ // `f_is_task_running` is not visited
+ // `f_join_running_task` is not visited
+ // `f_next_task_id` is not visited
+ // NOTE(review): TaskSchedulerNode::VisitAttrs is not called either, so the inherited
+ // `tasks`/`builder`/`runner`/`database` fields are not visited for this class.
+ }
+
+ void Tune() final { //
+ f_tune();
+ }
+
+ void SetTaskStopped(int task_id) final { //
+ f_set_task_stopped(task_id);
+ }
+
+ bool IsTaskRunning(int task_id) final { //
+ return f_is_task_running(task_id);
+ }
+
+ void JoinRunningTask(int task_id) final { //
+ f_join_running_task(task_id);
+ }
+
+ int NextTaskId() final { //
+ return f_next_task_id();
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode);
+};
+
+/*!
+ * \brief Managed reference to TaskSchedulerNode.
+ * \sa TaskSchedulerNode
+ */
+class TaskScheduler : public runtime::ObjectRef {
+ public:
+ /*!
+ * \brief Create a task scheduler that fetches tasks in a round-robin fashion.
+ * \param tasks The tasks to be tuned.
+ * \param builder The builder of the scheduler.
+ * \param runner The runner of the scheduler.
+ * \param database The database of the scheduler.
+ * \return The round-robin task scheduler.
+ */
+ TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
+ Database database);
+ /*!
+ * \brief Create a task scheduler whose methods are implemented by packed functions
+ * (typically provided from the python side).
+ * \param f_tune The packed function implementing `Tune`.
+ * \param f_set_task_stopped The packed function implementing `SetTaskStopped`.
+ * \param f_is_task_running The packed function implementing `IsTaskRunning`.
+ * \param f_join_running_task The packed function implementing `JoinRunningTask`.
+ * \param f_next_task_id The packed function implementing `NextTaskId`.
+ * \return The task scheduler. Only the packed functions are set; `tasks`, `builder`,
+ * `runner` and `database` keep their default values.
+ */
+ TVM_DLL static TaskScheduler PyTaskScheduler(
+ PyTaskSchedulerNode::FTune f_tune, //
+ PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
+ PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
+ PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
+ PyTaskSchedulerNode::FNextTaskId f_next_task_id);
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 87a3a49..db72328 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -36,6 +36,8 @@ class TuneContextNode : public runtime::Object {
Optional<Target> target;
/*! \brief The design space generator. */
Optional<SpaceGenerator> space_generator;
+ /*! \brief The search strategy. */
+ Optional<SearchStrategy> search_strategy;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
@@ -43,13 +45,24 @@ class TuneContextNode : public runtime::Object {
/*! \brief The number of threads to be used. */
int num_threads;
+ /*! \brief Whether the tuning task has been stopped or finished. */
+ bool is_stopped;
+ /*! \brief Packed functions to fetch the runner results asynchronously. */
+ Optional<Array<RunnerFuture>> runner_futures;
+ /*! \brief The measure candidates. */
+ Optional<Array<MeasureCandidate>> measure_candidates;
+
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("space_generator", &space_generator);
+ v->Visit("search_strategy", &search_strategy);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
+ v->Visit("is_stopped", &is_stopped);
+ v->Visit("runner_futures", &runner_futures);
+ v->Visit("measure_candidates", &measure_candidates);
}
static constexpr const char* _type_key = "meta_schedule.TuneContext";
@@ -67,6 +80,7 @@ class TuneContext : public runtime::ObjectRef {
* \param mod The workload to be tuned.
* \param target The target to be tuned for.
* \param space_generator The design space generator.
+ * \param search_strategy The search strategy.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
@@ -74,6 +88,7 @@ class TuneContext : public runtime::ObjectRef {
TVM_DLL explicit TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
+ Optional<SearchStrategy> search_strategy, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py
index dcd430d..320647b 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/__init__.py
@@ -18,5 +18,5 @@
The tvm.meta_schedule.database package.
The database that stores serialized tuning records and workloads
"""
-from .database import Database, PyDatabase, TuningRecord
+from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py
index 40f21da..609baa2 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/search_strategy/__init__.py
@@ -14,7 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Search Strategy"""
+"""
+The tvm.meta_schedule.search_strategy package.
+Meta Schedule search strategy utilizes the design spaces given
+to generate measure candidates.
+"""
from .search_strategy import SearchStrategy, PySearchStrategy
from .replay_trace import ReplayTrace
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/task_scheduler/__init__.py
similarity index 73%
copy from python/tvm/meta_schedule/database/__init__.py
copy to python/tvm/meta_schedule/task_scheduler/__init__.py
index dcd430d..dbfe962 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/task_scheduler/__init__.py
@@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""
-The tvm.meta_schedule.database package.
-The database that stores serialized tuning records and workloads
+The tvm.meta_schedule.task_scheduler package.
+Meta Schedule task scheduler that manages the task scheduling
+for measure candidates generation and measurement, then save
+records to the database.
"""
-from .database import Database, PyDatabase, TuningRecord
-from .json_database import JSONDatabase
+from .task_scheduler import TaskScheduler, PyTaskScheduler
+from .round_robin import RoundRobin
diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py
new file mode 100644
index 0000000..391011b
--- /dev/null
+++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Round Robin Task Scheduler"""
+
+from typing import List, TYPE_CHECKING
+
+from tvm._ffi import register_object
+
+from ..builder import Builder
+from ..runner import Runner
+from ..database import Database
+from .task_scheduler import TaskScheduler
+
+from .. import _ffi_api
+
+if TYPE_CHECKING:
+ from ..tune_context import TuneContext
+
+
+@register_object("meta_schedule.RoundRobin")
+class RoundRobin(TaskScheduler):
+ """Round Robin Task Scheduler"""
+
+ def __init__(
+ self,
+ tasks: List["TuneContext"],
+ builder: Builder,
+ runner: Runner,
+ database: Database,
+ ) -> None:
+ """Constructor.
+
+ Parameters
+ ----------
+ tasks : List[TuneContext]
+ List of tasks to schedule.
+ builder : Builder
+ The builder.
+ runner : Runner
+ The runner.
+ database : Database
+ The database.
+ """
+ self.__init_handle_by_constructor__(
+ _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member
+ tasks,
+ builder,
+ runner,
+ database,
+ )
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
new file mode 100644
index 0000000..b8dcfd9
--- /dev/null
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Auto-tuning Task Scheduler"""
+from tvm._ffi import register_object
+from tvm.runtime import Object
+
+from .. import _ffi_api
+
+
+@register_object("meta_schedule.TaskScheduler")
+class TaskScheduler(Object):
+ """The abstract task scheduler interface."""
+
+ def tune(self) -> None:
+ """Auto-tuning."""
+ _ffi_api.TaskSchedulerTune(self) # pylint: disable=no-member
+
+ def _set_task_stopped(self, task_id: int) -> None:
+ """Set specific task to be stopped.
+
+ Parameters
+ ----------
+ task_id : int
+ The task id to be stopped.
+ """
+ _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member
+
+ def _is_task_running(self, task_id: int) -> bool:
+ """Check whether the task is running.
+
+ Parameters
+ ----------
+ task_id : int
+ The task id to be checked.
+
+ Returns
+ -------
+ bool
+ Whether the task is running.
+ """
+ return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member
+
+ def _join_running_task(self, task_id: int) -> None:
+ """Wait until the task is finished.
+
+ Parameters
+ ----------
+ task_id : int
+ The task id to be joined.
+ """
+ _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member
+
+ def _next_task_id(self) -> int:
+ """Fetch the next task id.
+
+ Returns
+ -------
+ int
+ The next task id.
+ """
+ return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member
+
+
+@register_object("meta_schedule.PyTaskScheduler")
+class PyTaskScheduler(TaskScheduler):
+ """An abstract task scheduler with customized methods on the python-side.
+
+ The constructor hands five closures to the C++ PyTaskSchedulerNode, which calls them
+ from its Tune/SetTaskStopped/IsTaskRunning/JoinRunningTask/NextTaskId overrides.
+
+ NOTE(review): the default `_set_task_stopped`, `_is_task_running`, `_join_running_task`
+ and `_next_task_id` below call the `TaskScheduler*` FFI functions. Those dispatch
+ virtually to PyTaskSchedulerNode, which calls the closures registered here, i.e. back
+ into these same methods. A subclass that does not override them will recurse without
+ bound when they are called.
+ """
+
+ def __init__(self):
+ """Constructor. Registers closures that forward each C++ hook to the python methods."""
+
+ def f_tune() -> None:
+ self.tune()
+
+ def f_set_task_stopped(task_id: int) -> None:
+ self._set_task_stopped(task_id)
+
+ def f_is_task_running(task_id: int) -> bool:
+ return self._is_task_running(task_id)
+
+ def f_join_running_task(task_id: int) -> None:
+ self._join_running_task(task_id)
+
+ def f_next_task_id() -> int:
+ return self._next_task_id()
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.TaskSchedulerPyTaskScheduler, # pylint: disable=no-member
+ f_tune,
+ f_set_task_stopped,
+ f_is_task_running,
+ f_join_running_task,
+ f_next_task_id,
+ )
+
+ def tune(self) -> None:
+ # Must be implemented by subclasses.
+ raise NotImplementedError()
+
+ # The defaults below re-enter the FFI; see the NOTE(review) in the class docstring.
+ def _set_task_stopped(self, task_id: int) -> None:
+ _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member
+
+ def _is_task_running(self, task_id: int) -> bool:
+ return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member
+
+ def _join_running_task(self, task_id: int) -> None:
+ _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member
+
+ def _next_task_id(self) -> int:
+ return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member
diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py
index 9c41b4d..0f3cfac 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -28,6 +28,7 @@ from . import _ffi_api
if TYPE_CHECKING:
from .space_generator import SpaceGenerator
+ from .search_strategy import SearchStrategy
@register_object("meta_schedule.TuneContext")
@@ -45,6 +46,10 @@ class TuneContext(Object):
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
+ space_generator : Optional[SpaceGenerator] = None
+ The design space generator.
+ search_strategy : Optional[SearchStrategy] = None
+ The search strategy.
task_name : Optional[str] = None
The name of the tuning task.
rand_state : int = -1
@@ -63,6 +68,8 @@ class TuneContext(Object):
mod: Optional[IRModule]
target: Optional[Target]
+ space_generator: "SpaceGenerator"
+ search_strategy: "SearchStrategy"
task_name: Optional[str]
rand_state: int
num_threads: int
@@ -72,6 +79,7 @@ class TuneContext(Object):
mod: Optional[IRModule] = None,
target: Optional[Target] = None,
space_generator: Optional["SpaceGenerator"] = None,
+ search_strategy: Optional["SearchStrategy"] = None,
task_name: Optional[str] = None,
rand_state: int = -1,
num_threads: Optional[int] = None,
@@ -86,6 +94,8 @@ class TuneContext(Object):
The target to be optimized for.
space_generator : Optional[SpaceGenerator] = None
The design space generator.
+ search_strategy : Optional[SearchStrategy] = None
+ The search strategy.
task_name : Optional[str] = None
The name of the tuning task.
rand_state : int = -1
@@ -102,6 +112,7 @@ class TuneContext(Object):
mod,
target,
space_generator,
+ search_strategy,
task_name,
rand_state,
num_threads,
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index 5f53699..bf2ef17 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -21,9 +21,10 @@ import shutil
from typing import Any, Callable, List, Optional, Union
import psutil
+import tvm
from tvm._ffi import get_global_func, register_func
from tvm.error import TVMError
-from tvm.ir import Array, Map
+from tvm.ir import Array, Map, IRModule
from tvm.rpc import RPCSession
from tvm.runtime import PackedFunc, String
from tvm.tir import FloatImm, IntImm
@@ -183,3 +184,24 @@ def batch_json_str2obj(json_strs: List[str]) -> List[Any]:
for json_str in map(str.strip, json_strs)
if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//"))
]
+
+
+def structural_hash(mod: IRModule) -> str:
+ """Get the structural hash of a module.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The module to be hashed.
+
+ Returns
+ -------
+ result : str
+ The structural hash of the module.
+ """
+ shash = tvm.ir.structural_hash(mod)
+ if shash < 0:
+ # Workaround because `structural_hash` returns a size_t, i.e., unsigned integer
+ # but ffi can't handle unsigned integers properly so it's parsed into a negative number
+ shash += 1 << 64
+ return str(shash)
diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc
new file mode 100644
index 0000000..a529f23
--- /dev/null
+++ b/src/meta_schedule/task_scheduler/round_robin.cc
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Task scheduler that cycles through its tasks in a fixed order, skipping stopped ones.
+ */
+class RoundRobinNode final : public TaskSchedulerNode {
+ public:
+  /*! \brief Cursor: the task id most recently handed out; -1 before the first pick. */
+  int task_id = -1;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TaskSchedulerNode::VisitAttrs(v);
+    v->Visit("task_id", &task_id);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.RoundRobin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RoundRobinNode, TaskSchedulerNode);
+
+ protected:
+  /*!
+   * \brief Advance the cursor to the next task that is not stopped, joining it first if it
+   *        still has runner futures in flight.
+   * \return The chosen task id, or -1 after a full cycle finds every task stopped.
+   */
+  int NextTaskId() final {
+    const int num_tasks = this->tasks.size();
+    for (int attempt = 0; attempt < num_tasks; ++attempt) {
+      task_id = (task_id + 1) % num_tasks;
+      if (tasks[task_id]->is_stopped) {
+        continue;
+      }
+      if (IsTaskRunning(task_id)) {
+        JoinRunningTask(task_id);
+      }
+      return task_id;
+    }
+    return -1;
+  }
+};
+
+TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
+                                        Database database) {
+  ObjectPtr<RoundRobinNode> node = make_object<RoundRobinNode>();
+  // Start the cursor before the first task so the first NextTaskId call examines index 0.
+  node->task_id = -1;
+  node->database = database;
+  node->runner = runner;
+  node->builder = builder;
+  node->tasks = tasks;
+  return TaskScheduler(node);
+}
+
+TVM_REGISTER_NODE_TYPE(RoundRobinNode);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin")
+    .set_body_typed(TaskScheduler::RoundRobin);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
new file mode 100644
index 0000000..cf0af3d
--- /dev/null
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Wrap every measure candidate into a builder input and build them as one batch.
+ * \param builder The builder that compiles the candidates.
+ * \param context The tuning context; its target must be defined.
+ * \param candidates The measure candidates to build.
+ * \return The builder results (SendToRunner expects one result per candidate, in order).
+ */
+Array<BuilderResult> SendToBuilder(const Builder& builder,  //
+                                   const TuneContext& context,
+                                   const Array<MeasureCandidate>& candidates) {
+  const Target target = context->target.value();
+  const int num_candidates = candidates.size();
+  Array<BuilderInput> build_inputs;
+  build_inputs.reserve(num_candidates);
+  for (int i = 0; i < num_candidates; ++i) {
+    build_inputs.push_back(BuilderInput(candidates[i]->sch->mod(), target));
+  }
+  return builder->Build(build_inputs);
+}
+
+/*!
+ * \brief Send the successfully built measure candidates to the runner.
+ * \param runner The runner to send the candidates to.
+ * \param context The tuning context; its target must be defined.
+ * \param candidates The measure candidates.
+ * \param builder_results The builder results, one per candidate.
+ * \return One runner future per candidate, in candidate order. Candidates that failed to
+ *         build get an already-completed future carrying the build error message.
+ */
+Array<RunnerFuture> SendToRunner(const Runner& runner,  //
+                                 const TuneContext& context,
+                                 const Array<MeasureCandidate>& candidates,
+                                 const Array<BuilderResult>& builder_results) {
+  Target target = context->target.value();
+  ICHECK_EQ(candidates.size(), builder_results.size());
+  const int num_candidates = candidates.size();
+  // Only candidates that built without error are forwarded to the runner.
+  int num_failed_builds = 0;
+  Array<RunnerInput> run_inputs;
+  run_inputs.reserve(num_candidates);
+  for (int i = 0; i < num_candidates; ++i) {
+    const BuilderResult& built = builder_results[i];
+    if (!built->error_msg.defined()) {
+      run_inputs.push_back(RunnerInput(/*artifact_path=*/built->artifact_path.value(),
+                                       /*device_type=*/target->kind->name,
+                                       /*args_info=*/candidates[i]->args_info));
+    } else {
+      ++num_failed_builds;
+    }
+  }
+  Array<RunnerFuture> run_futures = runner->Run(run_inputs);
+  if (num_failed_builds == 0) {
+    return run_futures;
+  }
+  // Re-align the futures with the candidate list, filling the gaps left by failed builds.
+  Array<RunnerFuture> aligned;
+  aligned.reserve(num_candidates);
+  int next_future = 0;
+  for (int i = 0; i < num_candidates; ++i) {
+    const BuilderResult& built = builder_results[i];
+    if (!built->error_msg.defined()) {
+      aligned.push_back(run_futures[next_future++]);
+      continue;
+    }
+    aligned.push_back(RunnerFuture(
+        /*f_done=*/[]() -> bool { return true; },
+        /*f_result=*/
+        [msg = built->error_msg]() -> RunnerResult { return RunnerResult(NullOpt, msg); }));
+  }
+  return aligned;
+}
+
+/*!
+ * \brief Default tuning loop.
+ *
+ * 1. Initialize the space generator and search strategy of every task and run PreTuning.
+ * 2. Repeatedly fetch a task via NextTaskId, generate its next batch of measure candidates,
+ *    build them and send them to the runner. A task whose strategy yields no more candidates
+ *    is stopped. After each round, join every task that still has measurements in flight.
+ * 3. Once every task is stopped, run PostTuning on every task's search strategy exactly once.
+ */
+void TaskSchedulerNode::Tune() {
+  for (const TuneContext& task : this->tasks) {
+    CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
+    CHECK(task->space_generator.defined())
+        << "ValueError: Require `context.space_generator`, but it is not defined";
+    CHECK(task->search_strategy.defined())
+        << "ValueError: Require `context.search_strategy`, but it is not defined";
+    IRModule mod = task->mod.value();
+    SpaceGenerator space = task->space_generator.value();
+    SearchStrategy strategy = task->search_strategy.value();
+    space->InitializeWithTuneContext(task);
+    strategy->InitializeWithTuneContext(task);
+    strategy->PreTuning(space->GenerateDesignSpace(mod));
+  }
+
+  const int n_tasks = this->tasks.size();
+  int running_tasks = n_tasks;
+  while (running_tasks > 0) {
+    for (int task_id; (task_id = NextTaskId()) != -1;) {
+      TuneContext task = tasks[task_id];
+      ICHECK(!task->is_stopped);
+      ICHECK(!task->runner_futures.defined());
+      SearchStrategy strategy = task->search_strategy.value();
+      // Assign first, then test explicitly: no candidates means the strategy is exhausted.
+      task->measure_candidates = strategy->GenerateMeasureCandidates();
+      if (task->measure_candidates.defined()) {
+        Array<BuilderResult> builder_results =
+            SendToBuilder(this->builder, task, task->measure_candidates.value());
+        task->runner_futures =
+            SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results);
+      } else {
+        SetTaskStopped(task_id);
+        --running_tasks;
+      }
+    }
+    // Drain measurements still in flight so their results reach the strategies and database.
+    for (int task_id = 0; task_id < n_tasks; ++task_id) {
+      if (IsTaskRunning(task_id)) {
+        this->JoinRunningTask(task_id);
+      }
+    }
+  }
+  // PostTuning runs once per task after all tasks are stopped. It cannot live in the drain
+  // loop above: IsTaskRunning is false for stopped tasks, so it would never be reached.
+  for (const TuneContext& task : this->tasks) {
+    task->search_strategy.value()->PostTuning();
+  }
+}
+
+/*! \brief Mark a task as stopped; a task may be stopped only once. */
+void TaskSchedulerNode::SetTaskStopped(int task_id) {
+  TuneContext stopped_task = this->tasks[task_id];
+  ICHECK(!stopped_task->is_stopped);
+  stopped_task->is_stopped = true;
+}
+
+/*!
+ * \brief A task is running while it is not stopped and at least one runner future is pending.
+ *        A task whose futures have all completed is joined eagerly and reported as not running.
+ */
+bool TaskSchedulerNode::IsTaskRunning(int task_id) {
+  TuneContext task = this->tasks[task_id];
+  if (!task->is_stopped && task->runner_futures.defined()) {
+    bool has_pending = false;
+    for (const RunnerFuture& future : task->runner_futures.value()) {
+      if (!future->Done()) {
+        has_pending = true;
+        break;
+      }
+    }
+    if (has_pending) {
+      return true;
+    }
+    this->JoinRunningTask(task_id);
+  }
+  return false;
+}
+
+/*!
+ * \brief Block until every runner future of a task resolves, notify the task's search strategy
+ *        of the results, and commit each successful measurement to the database.
+ */
+void TaskSchedulerNode::JoinRunningTask(int task_id) {
+  TuneContext task = tasks[task_id];
+  ICHECK(task->runner_futures.defined());
+  Array<RunnerFuture> futures = task->runner_futures.value();
+  Array<RunnerResult> results;
+  results.reserve(futures.size());
+  for (const RunnerFuture& future : futures) {
+    results.push_back(future->Result());
+  }
+  task->search_strategy.value()->NotifyRunnerResults(results);
+  task->runner_futures = NullOpt;
+  // Add to database
+  ICHECK(task->measure_candidates.defined());
+  Array<MeasureCandidate> candidates = task->measure_candidates.value();
+  ICHECK_EQ(results.size(), candidates.size());
+  // The workload depends only on the task, so commit it at most once. It is committed lazily
+  // so that nothing is committed when no measurement succeeded.
+  Optional<Workload> workload = NullOpt;
+  const int n = results.size();
+  for (int i = 0; i < n; ++i) {
+    const RunnerResult& result = results[i];
+    if (result->error_msg.defined() || !result->run_secs.defined()) {
+      continue;
+    }
+    const MeasureCandidate& candidate = candidates[i];
+    Optional<tir::Trace> trace = candidate->sch->trace();
+    ICHECK(trace.defined());
+    if (!workload.defined()) {
+      workload = this->database->CommitWorkload(task->mod.value());
+    }
+    this->database->CommitTuningRecord(TuningRecord(
+        /*trace=*/trace.value(),
+        /*run_secs=*/result->run_secs.value(),
+        /*workload=*/workload.value(),
+        /*target=*/task->target.value(),
+        /*args_info=*/candidate->args_info));
+  }
+}
+
+TaskScheduler TaskScheduler::PyTaskScheduler(
+    PyTaskSchedulerNode::FTune f_tune,                          //
+    PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
+    PyTaskSchedulerNode::FIsTaskRunning f_is_task_running,      //
+    PyTaskSchedulerNode::FJoinRunningTask f_join_running_task,  //
+    PyTaskSchedulerNode::FNextTaskId f_next_task_id) {
+  // Only the callbacks are populated; tasks/builder/runner/database keep their defaults.
+  ObjectPtr<PyTaskSchedulerNode> node = make_object<PyTaskSchedulerNode>();
+  node->f_next_task_id = f_next_task_id;
+  node->f_join_running_task = f_join_running_task;
+  node->f_is_task_running = f_is_task_running;
+  node->f_set_task_stopped = f_set_task_stopped;
+  node->f_tune = f_tune;
+  return TaskScheduler(node);
+}
+
+TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode);
+TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode);
+// Every global here uses the `meta_schedule.` prefix, matching the Python `_ffi_api` lookups
+// (e.g. `_ffi_api.TaskSchedulerPyTaskScheduler` in task_scheduler.py).
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler")
+    .set_body_typed(TaskScheduler::PyTaskScheduler);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerSetTaskStopped")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::SetTaskStopped);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerIsTaskRunning")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::IsTaskRunning);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::Tune);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::JoinRunningTask);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::NextTaskId);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index ad82b6f..9fc9272 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -37,6 +37,7 @@ namespace meta_schedule {
TuneContext::TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
+ Optional<SearchStrategy> search_strategy, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) {
@@ -44,12 +45,16 @@ TuneContext::TuneContext(Optional<IRModule> mod,
n->mod = mod;
n->target = target;
n->space_generator = space_generator;
+ n->search_strategy = search_strategy;
n->task_name = task_name;
if (rand_state == -1) {
rand_state = std::random_device()();
}
support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state);
n->num_threads = num_threads;
+ n->is_stopped = false;
+ n->runner_futures = NullOpt;
+ n->measure_candidates = NullOpt;
data_ = std::move(n);
}
@@ -59,10 +64,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
.set_body_typed([](Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
+ Optional<SearchStrategy> search_strategy, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) -> TuneContext {
- return TuneContext(mod, target, space_generator, task_name, rand_state, num_threads);
+ return TuneContext(mod, target, space_generator, search_strategy, task_name, rand_state,
+ num_threads);
});
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 30294b8..83e65a5 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -26,6 +26,7 @@
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
+#include <tvm/meta_schedule/task_scheduler.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/node.h>
#include <tvm/node/serialization.h>
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
new file mode 100644
index 0000000..bdd504c
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test Meta Schedule Task Scheduler """
+
+from typing import List
+
+import sys
+import random
+
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.ir import IRModule
+from tvm.tir import Schedule
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.space_generator import ScheduleFn
+from tvm.meta_schedule.search_strategy import ReplayTrace
+from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
+from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
+from tvm.meta_schedule.task_scheduler import RoundRobin
+from tvm.meta_schedule.utils import structural_hash
+
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
+
+
+@tvm.script.tir
+class MatmulModule:  # C[i, j] = sum_k A[i, k] * B[k, j] over 1024x1024 float32 buffers
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, (1024, 1024), "float32")
+        B = tir.match_buffer(b, (1024, 1024), "float32")
+        C = tir.match_buffer(c, (1024, 1024), "float32")
+        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:  # "matmul" is the block _schedule_matmul looks up; vk is the reduction axis
+            with tir.init():  # C[vi, vj] is zeroed before the first accumulation
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.script.tir
+class MatmulReluModule:  # D = max(A @ B, 0) over 1024x1024 float32 buffers, via the intermediate C
+    def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, (1024, 1024), "float32")
+        B = tir.match_buffer(b, (1024, 1024), "float32")
+        D = tir.match_buffer(d, (1024, 1024), "float32")
+        C = tir.alloc_buffer((1024, 1024), "float32")  # matmul result; not a function parameter
+        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:  # same "matmul" block name, so _schedule_matmul applies here too
+            with tir.init():  # C[vi, vj] is zeroed before the first accumulation
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        with tir.block([1024, 1024], "relu") as [vi, vj]:  # elementwise ReLU on the matmul result
+            D[vi, vj] = tir.max(C[vi, vj], 0.0)
+
+
+@tvm.script.tir
+class BatchMatmulModule:  # C[n, i, j] = sum_k A[n, i, k] * B[n, j, k] (B is read transposed) on 16x128x128 buffers
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, [16, 128, 128])  # dtype left to the match_buffer default
+        B = tir.match_buffer(b, [16, 128, 128])
+        C = tir.match_buffer(c, [16, 128, 128])
+        with tir.block([16, 128, 128, tir.reduce_axis(0, 128)], "matmul") as [vn, vi, vj, vk]:  # vn is the batch axis, vk the reduction axis
+            with tir.init():  # C[vn, vi, vj] is zeroed before the first accumulation
+                C[vn, vi, vj] = 0.0
+            C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
+
+
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks
+
+
+def _schedule_matmul(sch: Schedule):
+ block = sch.get_block("matmul")
+ i, j, k = sch.get_loops(block=block)
+ # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
+ i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
+ j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
+ k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+ sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
+
+
+def _schedule_batch_matmul(sch: Schedule):
+ block = sch.get_block("matmul")
+ i, j, k, t = sch.get_loops(block=block)
+ # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
+ i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 2, 2, 2])
+ j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[2, 4, 64, 2])
+ k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+ t_0, t_1 = sch.split(loop=t, factors=[2, 512])
+ sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, t_0, t_1)
+
+
+class DummyRunnerFuture(RunnerFuture):
+ def done(self) -> bool:
+ return True
+
+ def result(self) -> RunnerResult:
+ return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None)
+
+
+class DummyBuilder(PyBuilder):
+ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
+ return [BuilderResult("test_path", None) for _ in build_inputs]
+
+
+class DummyRunner(PyRunner):
+ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+ return [DummyRunnerFuture() for _ in runner_inputs]
+
+
+class DummyDatabase(PyDatabase):
+ def __init__(self):
+ super().__init__()
+ self.records = []
+ self.workload_reg = []
+
+ def commit_tuning_record(self, record: TuningRecord) -> None:
+ self.records.append(record)
+
+ def commit_workload(self, mod: IRModule) -> Workload:
+ for workload in self.workload_reg:
+ if tvm.ir.structural_equal(workload.mod, mod):
+ return workload
+ workload = Workload(mod)
+ self.workload_reg.append(workload)
+ return workload
+
+ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+ return list(
+ filter(
+ lambda x: x.workload == workload,
+ sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+ )
+ )[: int(top_k)]
+
+ def __len__(self) -> int:
+ return len(self.records)
+
+ def print_results(self) -> None:
+ print("\n".join([str(r) for r in self.records]))
+
+
+def test_meta_schedule_task_scheduler_single():
+ num_trials_per_iter = 3
+ num_trials_total = 10
+ sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
+ replay = ReplayTrace(num_trials_per_iter, num_trials_total)
+ task = TuneContext(
+ MatmulModule(),
+ target=tvm.target.Target("llvm"),
+ space_generator=sch_fn,
+ search_strategy=replay,
+ task_name="Test",
+ rand_state=42,
+ )
+ database = DummyDatabase()
+ round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database)
+ round_robin.tune()
+ assert len(database) == num_trials_total
+
+
+def test_meta_schedule_task_scheduler_multiple():
+ num_trials_per_iter = 6
+ num_trials_total = 101
+ tasks = [
+ TuneContext(
+ MatmulModule(),
+ target=tvm.target.Target("llvm"),
+ space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+ search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+ task_name="Matmul",
+ rand_state=42,
+ ),
+ TuneContext(
+ MatmulReluModule(),
+ target=tvm.target.Target("llvm"),
+ space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+ search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+ task_name="MatmulRelu",
+ rand_state=0xDEADBEEF,
+ ),
+ TuneContext(
+ BatchMatmulModule(),
+ target=tvm.target.Target("llvm"),
+ space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
+ search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+ task_name="BatchMatmul",
+ rand_state=0x114514,
+ ),
+ ]
+ database = DummyDatabase()
+ round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database)
+ round_robin.tune()
+ assert len(database) == num_trials_total * len(tasks)
+ print(database.workload_reg)
+ for task in tasks:
+ assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))