You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/07 21:17:50 UTC
[tvm] branch main updated: [MetaSchedule][Refactor] Clarify Integration Logic (#10927)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5f1f8f3421 [MetaSchedule][Refactor] Clarify Integration Logic (#10927)
5f1f8f3421 is described below
commit 5f1f8f34212b462610881c030bacb0e6ba5802ec
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Apr 7 14:17:41 2022 -0700
[MetaSchedule][Refactor] Clarify Integration Logic (#10927)
---
include/tvm/meta_schedule/apply_history_best.h | 83 +++++++
include/tvm/meta_schedule/extracted_task.h | 68 ++++++
include/tvm/meta_schedule/integration.h | 190 ----------------
python/tvm/meta_schedule/__init__.py | 4 +-
python/tvm/meta_schedule/apply_history_best.py | 100 +++++++++
python/tvm/meta_schedule/extracted_task.py | 66 ++++++
python/tvm/meta_schedule/integration.py | 247 ---------------------
python/tvm/meta_schedule/relay_integration.py | 91 ++++++++
python/tvm/meta_schedule/testing/relay_workload.py | 2 +-
.../testing/tune_relay_meta_schedule.py | 5 +-
python/tvm/meta_schedule/testing/utils.py | 32 +--
python/tvm/meta_schedule/tune.py | 15 +-
.../{integration.cc => apply_history_best.cc} | 90 +++-----
src/meta_schedule/extracted_task.cc | 43 ++++
src/meta_schedule/utils.h | 1 +
src/relay/backend/task_extraction.cc | 7 +-
src/relay/backend/te_compiler_cache.cc | 21 +-
.../unittest/test_meta_schedule_integration.py | 22 +-
.../unittest/test_meta_schedule_multi_anchor.py | 3 +-
.../unittest/test_meta_schedule_tune_relay.py | 24 +-
20 files changed, 544 insertions(+), 570 deletions(-)
diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h
new file mode 100644
index 0000000000..9d6f46dd6c
--- /dev/null
+++ b/include/tvm/meta_schedule/apply_history_best.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_
+#define TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_
+
+#include <tvm/meta_schedule/database.h>
+#include <tvm/target/target.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief An integration context that allows application of historically best records from a
+ * database
+ */
+class ApplyHistoryBestNode : public runtime::Object {
+ public:
+ /*! \brief The database to be queried from */
+ Database database{nullptr};
+
+ void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); }
+ /*! \brief Query the best entry from the database
+ * \param task_name The name of the task to be queried
+ * \param mod The module to be queried
+ * \param target The target to be queried
+ * \param dispatched The IRs after dispatch
+ * \return The IRModule of the best matching entry, or NullOpt if there is no feedback hint
+ */
+ Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
+ Optional<Array<IRModule>> dispatched);
+
+ static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, runtime::Object);
+};
+
+/*!
+ * \brief Managed reference to ApplyHistoryBestNode
+ * \sa ApplyHistoryBestNode
+ */
+class ApplyHistoryBest : public runtime::ObjectRef {
+ public:
+ /*!
+ * \brief Constructor
+ * \param database The database to be queried from
+ */
+ explicit ApplyHistoryBest(Database database);
+ /*!
+ * \brief The current ApplyHistoryBest in the context
+ * \return The ApplyHistoryBest in the current scope, or NullOpt if not inside any such scope.
+ */
+ static Optional<ApplyHistoryBest> Current();
+
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, runtime::ObjectRef,
+ ApplyHistoryBestNode);
+
+ protected:
+ friend class ApplyHistoryBestInternal;  // NOTE(review): presumably lets the FFI layer reach Enter/ExitWithScope; confirm
+ /*! \brief Enter the scope of the context manager */
+ void EnterWithScope();
+ /*! \brief Exit the scope of the context manager */
+ void ExitWithScope();
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_
diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h
new file mode 100644
index 0000000000..c6613427fd
--- /dev/null
+++ b/include/tvm/meta_schedule/extracted_task.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_EXTRACTED_TASK_H_
+#define TVM_META_SCHEDULE_EXTRACTED_TASK_H_
+
+#include <tvm/target/target.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief A tuning task extracted from the high-level IR */
+class ExtractedTaskNode : public runtime::Object {
+ public:
+ /*! \brief The name of the task extracted */
+ String task_name;
+ /*! \brief The high-level IR */
+ IRModule mod;
+ /*! \brief The target the task was extracted for */
+ Target target;
+ /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */
+ Array<IRModule> dispatched;
+ /*! \brief Weight of the task */
+ int weight;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("task_name", &task_name);
+ v->Visit("mod", &mod);
+ v->Visit("target", &target);
+ v->Visit("dispatched", &dispatched);
+ v->Visit("weight", &weight);
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.ExtractedTask";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object);
+};
+
+/*!
+ * \brief Managed reference to ExtractedTaskNode; the constructor takes the node's fields in order
+ * \sa ExtractedTaskNode
+ */
+class ExtractedTask : public runtime::ObjectRef {
+ public:
+ explicit ExtractedTask(String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
+ int weight);
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef,
+ ExtractedTaskNode);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_EXTRACTED_TASK_H_
diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h
deleted file mode 100644
index b231913f2f..0000000000
--- a/include/tvm/meta_schedule/integration.h
+++ /dev/null
@@ -1,190 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_META_SCHEDULE_INTEGRATION_H_
-#define TVM_META_SCHEDULE_INTEGRATION_H_
-
-#include <tvm/meta_schedule/database.h>
-#include <tvm/support/with.h>
-#include <tvm/target/target.h>
-
-#include <unordered_set>
-
-namespace tvm {
-namespace meta_schedule {
-
-/**************** ExtractedTask ****************/
-
-/*!
- * \brief A tuning task extracted from the high-level IR
- */
-class ExtractedTaskNode : public runtime::Object {
- public:
- /*! \brief The name of the task extracted */
- String task_name;
- /*! \brief The high-level IR */
- IRModule mod;
- /*! \brief Target */
- Target target;
- /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */
- Array<IRModule> dispatched;
- /*! \brief Weight of the task */
- int weight;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("task_name", &task_name);
- v->Visit("mod", &mod);
- v->Visit("target", &target);
- v->Visit("dispatched", &dispatched);
- v->Visit("weight", &weight);
- }
-
- static constexpr const char* _type_key = "meta_schedule.ExtractedTask";
- TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object);
-};
-
-/*!
- * \brief Managed reference to ExtractedTaskNode
- * \sa ExtractedTaskNode
- */
-class ExtractedTask : public runtime::ObjectRef {
- public:
- /*!
- * \brief Constructor. The name of the task extracted
- * \brief The high-level IR
- * \brief A list of low-level IRs that the high-level IR could potentially dispatch to
- */
- explicit ExtractedTask(String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
- int weight);
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef,
- ExtractedTaskNode);
-};
-
-/**************** MetaScheduleContext ****************/
-
-/*!
- * \brief A context manager interface for the integration
- */
-class MetaScheduleContextNode : public runtime::Object {
- public:
- /*! \brief Default destructor */
- virtual ~MetaScheduleContextNode() = default;
- /*!
- * \brief The entry point of the integration
- * \param task_name The name of the task
- * \param mod The high-level IR
- * \param target Target info
- * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
- * NullOpt means the dispatch needs to be done in the context.
- * \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
- * under IRModule for more general future use. NullOpt is returned
- * if there is no feedback hint.
- */
- virtual Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
- Optional<Array<IRModule>> dispatched) = 0;
-
- static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
- TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object);
-};
-
-/*!
- * \brief Managed reference to MetaScheduleContextNode
- * \sa MetaScheduleContextNode
- */
-class MetaScheduleContext : public runtime::ObjectRef {
- friend class MetaScheduleContextInternal;
- friend class With<MetaScheduleContext>;
-
- public:
- /*! \brief Default destructor */
- virtual ~MetaScheduleContext() = default;
- /*!
- * \brief The context manager in the current scope
- * \return The MetaScheduleContext in the current scope. NullOpt if it's currently not under any
- * MetaScheduleContext.
- */
- static Optional<MetaScheduleContext> Current();
- /*!
- * \brief The entry point of the integration workflow. The compilation process of the high-level
- * IR should call this method for task extraction and for feedback hints
- * \param task_name The name of the task
- * \param mod The high-level IR
- * \param target Target info
- * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
- * \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
- * under IRModule for more general future use. NullOpt is returned
- * if there is no feedback hint
- */
- static Optional<IRModule> QueryInsideWithScope(runtime::String task_name, IRModule mod,
- Target target,
- Optional<Array<IRModule>> dispatched);
-
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef,
- MetaScheduleContextNode);
-
- protected:
- /*! \brief Default constructor */
- MetaScheduleContext() = default;
- /*! \brief Entering the scope of the context manager */
- void EnterWithScope();
- /*! \brief Exiting the scope of the context manager */
- void ExitWithScope();
-};
-
-/**************** ApplyHistoryBest ****************/
-
-/*!
- * \brief An integration context that allows application of historically best records from a
- * database
- */
-class ApplyHistoryBestNode : public MetaScheduleContextNode {
- public:
- /*! \brief The database to be queried from */
- Database database{nullptr};
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("database", &database); //
- }
-
- // Inherited from base class
- Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
- Optional<Array<IRModule>> dispatched) final;
-
- static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
- TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode);
-};
-
-/*!
- * \brief Managed reference to ApplyHistoryBestNode
- * \sa ApplyHistoryBestNode
- */
-class ApplyHistoryBest : public MetaScheduleContext {
- public:
- /*!
- * \brief Constructor
- * \param database The database to be queried from
- */
- explicit ApplyHistoryBest(Database database);
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, MetaScheduleContext,
- ApplyHistoryBestNode);
-};
-
-} // namespace meta_schedule
-} // namespace tvm
-
-#endif // TVM_META_SCHEDULE_INTEGRATION_H_
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index 3612bb81a6..466c5e3e66 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -21,7 +21,6 @@ from . import (
cost_model,
database,
feature_extractor,
- integration,
mutator,
postproc,
runner,
@@ -29,6 +28,9 @@ from . import (
search_strategy,
space_generator,
)
+from .apply_history_best import ApplyHistoryBest
+from .extracted_task import ExtractedTask
+from .relay_integration import extract_task_from_relay
from .search_strategy import MeasureCandidate
from .tune import (
EvolutionarySearchConfig,
diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py
new file mode 100644
index 0000000000..5e1e40bd15
--- /dev/null
+++ b/python/tvm/meta_schedule/apply_history_best.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A context manager that injects the best tuning record in the database into compilation"""
+from typing import List, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.target import Target
+
+from . import _ffi_api
+from .database import Database
+
+
+@register_object("meta_schedule.ApplyHistoryBest")
+class ApplyHistoryBest(Object):
+ """An integration context that allows application of historically best records from a database
+
+ Parameters
+ ----------
+ database : Database
+ The database to be queried from
+ """
+
+ database: Database
+
+ def __init__(
+ self,
+ database: Database,
+ ) -> None:
+ self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member
+
+ def query(
+ self,
+ task_name: str,
+ mod: IRModule,
+ target: Target,
+ dispatched: Optional[List[IRModule]],
+ ) -> Union[IRModule, None]:
+ """Query the best entry from the database for the given task
+
+ Parameters
+ ----------
+ task_name : str
+ The name of the task extracted
+ mod : IRModule
+ The high-level IR
+ target : Target
+ The target to be queried
+ dispatched : Optional[List[IRModule]]
+ A list of low-level IRs that the high-level IR could potentially dispatch to
+
+ Returns
+ -------
+ result : IRModule or None
+ Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
+ more general future use. None is returned if there is no feedback hint.
+ """
+ return _ffi_api.ApplyHistoryBestQuery( # type: ignore # pylint: disable=no-member
+ self,
+ task_name,
+ mod,
+ target,
+ dispatched,
+ )
+
+ @staticmethod
+ def current() -> Optional["ApplyHistoryBest"]:
+ """The ApplyHistoryBest context manager in the current scope
+
+ Returns
+ -------
+ ctx : Optional[ApplyHistoryBest]
+ The ApplyHistoryBest context manager in the current scope.
+ None if it's currently not under any ApplyHistoryBest context.
+ """
+ return _ffi_api.ApplyHistoryBestCurrent() # type: ignore # pylint: disable=no-member
+
+ def __enter__(self) -> "ApplyHistoryBest":
+ """Enter the scope of the context manager"""
+ _ffi_api.ApplyHistoryBestEnterScope(self) # type: ignore # pylint: disable=no-member
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None:
+ """Exit the scope of the context manager"""
+ _ffi_api.ApplyHistoryBestExitScope(self) # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/meta_schedule/extracted_task.py b/python/tvm/meta_schedule/extracted_task.py
new file mode 100644
index 0000000000..b69a38ef6d
--- /dev/null
+++ b/python/tvm/meta_schedule/extracted_task.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Extracted tasks from high-level IR."""
+from typing import List
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.target import Target
+
+from . import _ffi_api
+
+
+@register_object("meta_schedule.ExtractedTask")
+class ExtractedTask(Object):
+ """A tuning task extracted from the high-level IR
+
+ Parameters
+ ----------
+ task_name : str
+ The name of the task extracted
+ mod : IRModule
+ The high-level IR
+ target : Target
+ Target information
+ dispatched : List[IRModule]
+ A list of low-level IRs that the high-level IR could potentially dispatch to
+ weight : int
+ The weight of the task
+ """
+
+ # Field annotations mirror the node's attributes (and the constructor arguments) in order.
+ task_name: str
+ mod: IRModule
+ target: Target
+ dispatched: List[IRModule]
+ weight: int
+
+ def __init__(
+ self,
+ task_name: str,
+ mod: IRModule,
+ target: Target,
+ dispatched: List[IRModule],
+ weight: int,
+ ) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member
+ task_name,
+ mod,
+ target,
+ dispatched,
+ weight,
+ )
diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py
deleted file mode 100644
index db6771feca..0000000000
--- a/python/tvm/meta_schedule/integration.py
+++ /dev/null
@@ -1,247 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Meta schedule integration with high-level IR"""
-from typing import Dict, List, Optional, Union
-
-import numpy as np # type: ignore
-import tvm.runtime.ndarray as nd
-from tvm._ffi import get_global_func, register_object
-from tvm.ir import IRModule, transform
-from tvm.relay import Any
-from tvm.relay import Function as RelayFunc
-from tvm.runtime import NDArray, Object
-from tvm.target import Target
-
-from . import _ffi_api
-from .database import Database
-from .utils import autotvm_silencer
-
-
-@register_object("meta_schedule.ExtractedTask")
-class ExtractedTask(Object):
- """A tuning task extracted from the high-level IR
-
- Parameters
- ----------
- task_name : str
- The name of the task extracted
- mod : IRModule
- The high-level IR
- target: Target
- Target information
- dispatched : List[IRModule]
- A list of low-level IRs that the high-level IR could potentially dispatch to
- weight : int
- The weight of the task
- """
-
- task_name: str
- mod: IRModule
- dispatched: List[IRModule]
- weight: int
-
- def __init__(
- self,
- task_name: str,
- mod: IRModule,
- target: Target,
- dispatched: List[IRModule],
- weight: int,
- ) -> None:
- self.__init_handle_by_constructor__(
- _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member
- task_name,
- mod,
- target,
- dispatched,
- weight,
- )
-
-
-@register_object("meta_schedule.MetaScheduleContext")
-class MetaScheduleContext(Object):
- """A context manager interface for the integration"""
-
- def query(
- self,
- task_name: str,
- mod: IRModule,
- target: Target,
- dispatched: Optional[List[IRModule]],
- ) -> Union[IRModule, None]:
- """The entry point of the integration
-
- Parameters
- ----------
- task_name : str
- The name of the task extracted
- mod : IRModule
- The high-level IR
- target: Target
- Target Info
- dispatched : Optional[List[IRModule]]
- A list of low-level IRs that the high-level IR could potentially dispatch to
-
- Returns
- -------
- result : IRModule or None
- Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
- more general future use. None is returned if there is no feedback hint.
- """
- return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member
- self,
- task_name,
- mod,
- target,
- dispatched,
- )
-
- @staticmethod
- def current() -> Optional["MetaScheduleContext"]:
- """The context manager in the current scope
-
- Returns
- -------
- ctx : Optional[MetaScheduleContext]
- The MetaScheduleContext in the current scope.
- NullOpt if it's currently not under any MetaScheduleContext.
- """
- return _ffi_api.MetaScheduleContextCurrent() # type: ignore # pylint: disable=no-member
-
- @staticmethod
- def query_inside_with_scope(
- task_name: str,
- mod: IRModule,
- target: Target,
- dispatched: Optional[List[IRModule]],
- ) -> Union[IRModule, None]:
- """The entry point of the integration workflow. The compilation process of the high-level
- IR should call this method for task extraction and for feedback hints
-
- Basically, this method is equivalent to:
-
- .. code-block:: python
-
- def query_inside_with_scope(task_name, mod, dispatched):
- ctx = MetaScheduleContext.current()
- assert ctx is not None
- mod = ctx.query(task_name, mod, target, dispatched)
-
- Parameters
- ----------
- task_name : str
- The name of the task
- mod : IRModule
- The high-level IR
- target: Target
- Target
- dispatched : Optional[List[IRModule]]
- A list of low-level IRs that the high-level IR could potentially dispatch to
-
- Returns
- -------
- result : IRModule or None
- Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
- more general future use. None is returned if there is no feedback hint.
- """
- return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member
- task_name,
- mod,
- target,
- dispatched,
- )
-
- def __enter__(self) -> "MetaScheduleContext":
- """Entering the scope of the context manager"""
- _ffi_api.MetaScheduleContextEnterScope(self) # type: ignore # pylint: disable=no-member
- return self
-
- def __exit__(self, ptype, value, trace) -> None:
- """Exiting the scope of the context manager"""
- _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member
-
-
-@register_object("meta_schedule.ApplyHistoryBest")
-class ApplyHistoryBest(MetaScheduleContext):
- """An integration context that allows application of historically best record from database"""
-
- database: Database
- """ The database to be queried from"""
-
- def __init__(self, database) -> None:
- self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member
-
-
-def extract_task_from_relay(
- mod: Union[IRModule, RelayFunc],
- target: Target,
- params: Optional[Dict[str, NDArray]] = None,
- *,
- opt_level: int = 3,
- pass_config: Optional[Dict[str, Any]] = None,
- disabled_pass: Optional[List[str]] = None,
-) -> List[ExtractedTask]:
- """Extract tuning tasks from a relay program.
-
- Parameters
- ----------
- mod : Union[tvm.IRModule, tvm.relay.Function]
- The module or function to tune
- target : tvm.target.Target
- The compilation target
- params : Optional[Dict[str, tvm.runtime.NDArray]]
- The associated parameters of the program
- opt_level : int
- The optimization level of the compiler
- pass_config : Optional[Dict[str, Any]]
- The pass config of the compiler
- disabled_pass : Optional[List[str]]
- The list of disabled passes of the compiler
-
- Returns
- -------
- tasks: List[ExtractedTask]
- The tasks extracted from this network
- """
-
- extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask")
- assert extract_task_func
-
- target = Target(target) if isinstance(target, str) else target
-
- relay_params = {}
- for name, param in params.items():
- if isinstance(param, np.ndarray):
- param = nd.array(param)
- relay_params[name] = param
-
- if disabled_pass is None:
- disabled_pass = []
- if pass_config is None:
- pass_config = {"relay.backend.use_meta_schedule": True}
-
- if isinstance(mod, RelayFunc):
- mod = IRModule.from_expr(mod)
- if not isinstance(target, Target):
- target = Target(target)
-
- with autotvm_silencer(), target, transform.PassContext(
- opt_level=opt_level,
- config=pass_config,
- disabled_pass=disabled_pass,
- ):
- return list(extract_task_func(mod, target, relay_params))
diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
new file mode 100644
index 0000000000..4478ffc76b
--- /dev/null
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""MetaSchedule-Relay integration"""
+from typing import Any, Dict, List, Optional
+
+import numpy as np # type: ignore
+from tvm import nd
+from tvm._ffi import get_global_func
+from tvm.ir import IRModule, transform
+from tvm.runtime import NDArray
+from tvm.target import Target
+
+from .extracted_task import ExtractedTask
+from .utils import autotvm_silencer
+
+
+def extract_task_from_relay(
+ mod: IRModule,
+ target: Target,
+ params: Optional[Dict[str, NDArray]] = None,
+ *,
+ opt_level: int = 3,
+ pass_config: Optional[Dict[str, Any]] = None,
+ disabled_pass: Optional[List[str]] = None,
+) -> List[ExtractedTask]:
+ """Extract tuning tasks from a relay program.
+
+ Parameters
+ ----------
+ mod : Union[IRModule, tvm.relay.Function]
+ The module or function to tune; a relay Function is wrapped into an IRModule
+ target : tvm.target.Target
+ The compilation target; converted with Target(target) if it is not a Target
+ params : Optional[Dict[str, tvm.runtime.NDArray]]
+ The associated parameters of the program. None means no parameters are bound.
+ opt_level : int
+ The optimization level of the compiler
+ pass_config : Optional[Dict[str, Any]]
+ The pass config of the compiler. Defaults to {"relay.backend.use_meta_schedule": True}.
+ disabled_pass : Optional[List[str]]
+ The list of disabled passes of the compiler. Defaults to an empty list.
+
+ Returns
+ -------
+ tasks : List[ExtractedTask]
+ The tasks extracted from this network
+ """
+ # pylint: disable=import-outside-toplevel
+ from tvm.relay import Function as RelayFunc
+
+ # pylint: enable=import-outside-toplevel
+
+ extract_task_func = get_global_func(
+ "relay.backend.MetaScheduleExtractTask",
+ allow_missing=False,
+ )
+
+ if isinstance(mod, RelayFunc):
+ mod = IRModule.from_expr(mod)
+ if not isinstance(target, Target):
+ target = Target(target)
+ if disabled_pass is None:
+ disabled_pass = []
+ if pass_config is None:
+ pass_config = {"relay.backend.use_meta_schedule": True}
+ relay_params = {}  # `params` defaults to None, so it must not be iterated unconditionally
+ for name, param in (params or {}).items():
+ if isinstance(param, np.ndarray):
+ param = nd.array(param)
+ relay_params[name] = param
+
+ with autotvm_silencer(), target, transform.PassContext(
+ opt_level=opt_level,
+ config=pass_config,
+ disabled_pass=disabled_pass,
+ ):
+ return list(extract_task_func(mod, target, relay_params))
diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index 83a70abb7f..2dbd290a28 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -26,7 +26,7 @@ import tvm
import tvm.relay.testing
from tvm import relay
from tvm.ir import IRModule
-from tvm.meta_schedule.integration import ExtractedTask, extract_task_from_relay
+from tvm.meta_schedule import ExtractedTask, extract_task_from_relay
from tvm.runtime import NDArray, load_param_dict, save_param_dict
from tvm.target import Target
diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
index 5859412ebb..0973c9b91b 100644
--- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
+++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
@@ -24,7 +24,6 @@ import numpy as np # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.ir.transform import PassContext
-from tvm.meta_schedule.integration import extract_task_from_relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.relay import build as relay_build
@@ -107,7 +106,7 @@ def tune_each_task(
work_dir,
params,
):
- extracted_tasks = extract_task_from_relay(mod, target, params)
+ extracted_tasks = ms.extract_task_from_relay(mod, target, params)
database = ms.database.JSONDatabase(
path_workload=os.path.join(work_dir, "default_database_workload.json"),
path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"),
@@ -139,7 +138,7 @@ def tune_each_task(
)
# pylint: enable=protected-access
task_scheduler.tune()
- with target, ms.integration.ApplyHistoryBest(database):
+ with target, ms.ApplyHistoryBest(database):
with PassContext(
opt_level=3,
config={"relay.backend.use_meta_schedule": True},
diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py
index e22677a3b9..a832dfc6bc 100644
--- a/python/tvm/meta_schedule/testing/utils.py
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -14,31 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Testing utilitiy functions in meta schedule"""
+"""Testing utility functions in meta schedule"""
import random
-from typing import List, Optional, Callable, Dict, Union
+from typing import Callable, Dict, List, Optional, Union
import tvm
-from tvm.relay import Function as RelayFunc
-from tvm.tir import Schedule
-from tvm.target import Target
-from tvm.runtime import NDArray
+from tvm.ir import IRModule
from tvm.meta_schedule import TuneContext # pylint: disable=unused-import
-from tvm.meta_schedule.utils import derived_object
+from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
+from tvm.meta_schedule.extracted_task import ExtractedTask
from tvm.meta_schedule.mutator.mutator import PyMutator
-from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
-from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
+from tvm.meta_schedule.relay_integration import extract_task_from_relay
from tvm.meta_schedule.runner import (
+ PyRunner,
+ PyRunnerFuture,
+ RunnerFuture,
RunnerInput,
RunnerResult,
- RunnerFuture,
- PyRunnerFuture,
- PyRunner,
)
-from tvm.meta_schedule.tune import Parse, extract_task_from_relay
-from tvm.meta_schedule.integration import ExtractedTask
-
-from tvm.ir import IRModule
+from tvm.meta_schedule.tune import Parse
+from tvm.meta_schedule.utils import derived_object
+from tvm.relay import Function as RelayFunc
+from tvm.runtime import NDArray
+from tvm.target import Target
+from tvm.tir import Schedule
from tvm.tir.schedule import Trace
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index 86157e0fb3..31130f67af 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -23,18 +23,17 @@ from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from tvm._ffi.registry import register_func
from tvm.ir import IRModule, structural_hash
from tvm.ir.transform import PassContext
-from tvm.relay import Function as RelayFunc
-from tvm.relay import build as relay_build
from tvm.runtime import Module, NDArray
from tvm.target import Target
from tvm.te import Tensor, create_prim_func
from tvm.tir import PrimFunc, Schedule
+from .apply_history_best import ApplyHistoryBest
from .builder import Builder, LocalBuilder
from .cost_model import CostModel, XGBModel
from .database import Database, JSONDatabase, TuningRecord
+from .extracted_task import ExtractedTask
from .feature_extractor import PerStoreFeature
-from .integration import ApplyHistoryBest, ExtractedTask, extract_task_from_relay
from .measure_callback import MeasureCallback
from .mutator import Mutator
from .postproc import Postproc
@@ -822,7 +821,7 @@ def tune_extracted_tasks(
def tune_relay(
- mod: Union[RelayFunc, IRModule],
+ mod: IRModule,
target: Union[str, Target],
config: SearchStrategyConfig,
work_dir: str,
@@ -844,7 +843,7 @@ def tune_relay(
Parameters
----------
- mod : Union[RelayFunc, IRModule]
+ mod : IRModule
The module to tune.
target : Union[str, Target]
The target to tune for.
@@ -874,6 +873,12 @@ def tune_relay(
lib : Module
The built runtime module for the given relay workload.
"""
+ # pylint: disable=import-outside-toplevel
+ from tvm.relay import build as relay_build
+
+ from .relay_integration import extract_task_from_relay
+
+ # pylint: enable=import-outside-toplevel
logger.info("Working directory: %s", work_dir)
# pylint: disable=protected-access
diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/apply_history_best.cc
similarity index 58%
rename from src/meta_schedule/integration.cc
rename to src/meta_schedule/apply_history_best.cc
index 35c3baf237..41714cf7b0 100644
--- a/src/meta_schedule/integration.cc
+++ b/src/meta_schedule/apply_history_best.cc
@@ -16,17 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include <tvm/meta_schedule/integration.h>
-#include <tvm/relay/function.h>
-#include <tvm/tir/function.h>
-
#include "./utils.h"
-#include "tvm/runtime/container/optional.h"
namespace tvm {
namespace meta_schedule {
/**************** Utility functions ****************/
+
template <class FunctionType, class RetType, class Callback>
Optional<RetType> GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) {
if (mod->functions.size() != 1) {
@@ -59,54 +55,36 @@ bool HasOnlyOneFunction(const IRModule& mod) {
return GetOnlyOneFunction<FunctionType>(mod).defined();
}
-/**************** ExtractedTask ****************/
-
-ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
- Array<IRModule> dispatched, int weight) {
- ObjectPtr<ExtractedTaskNode> n = make_object<ExtractedTaskNode>();
- n->task_name = task_name;
- n->mod = mod;
- n->target = target;
- n->dispatched = dispatched;
- n->weight = weight;
- data_ = n;
-}
+/**************** Context Manager ****************/
-/**************** MetaScheduleContext ****************/
+// Static shims over EnterWithScope/ExitWithScope; they are what gets registered as the
+// "meta_schedule.ApplyHistoryBestEnterScope"/"ApplyHistoryBestExitScope" globals.
+class ApplyHistoryBestInternal {
+ public:
+ static void EnterScope(ApplyHistoryBest ctx) { ctx.EnterWithScope(); }
+ static void ExitScope(ApplyHistoryBest ctx) { ctx.ExitWithScope(); }
+};
-struct MetaScheduleContextThreadLocalEntry {
- Optional<MetaScheduleContext> ctx;
+// Per-thread slot holding the currently active ApplyHistoryBest, or NullOpt when none is active.
+struct ApplyHistoryBestThreadLocalEntry {
+ Optional<ApplyHistoryBest> ctx;
};
-using MetaScheduleContextThreadLocalStore =
- dmlc::ThreadLocalStore<MetaScheduleContextThreadLocalEntry>;
+using ApplyHistoryBestThreadLocalStore = dmlc::ThreadLocalStore<ApplyHistoryBestThreadLocalEntry>;
-Optional<MetaScheduleContext> MetaScheduleContext::Current() {
- return MetaScheduleContextThreadLocalStore::Get()->ctx;
+// Returns the ApplyHistoryBest active on the calling thread, or NullOpt if no scope is active.
+Optional<ApplyHistoryBest> ApplyHistoryBest::Current() {
+ return ApplyHistoryBestThreadLocalStore::Get()->ctx;
}
-void MetaScheduleContext::EnterWithScope() {
- Optional<MetaScheduleContext>& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx;
- CHECK(!ctx.defined())
- << "ValueError: Nested MetaScheduleContext context managers are not allowed";
+// Makes this context the active one for the calling thread. Only one scope may be active at a
+// time: entering while another is active fails the CHECK below.
+void ApplyHistoryBest::EnterWithScope() {
+ Optional<ApplyHistoryBest>& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx;
+ CHECK(!ctx.defined()) << "ValueError: Nested ApplyHistoryBest context managers are not allowed";
ctx = *this;
}
-void MetaScheduleContext::ExitWithScope() {
- Optional<MetaScheduleContext>& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx;
+// Clears the calling thread's active context; a scope must currently be active (ICHECK).
+void ApplyHistoryBest::ExitWithScope() {
+ Optional<ApplyHistoryBest>& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx;
ICHECK(ctx.defined());
ctx = NullOpt;
}
-Optional<IRModule> MetaScheduleContext::QueryInsideWithScope(runtime::String task_name,
- IRModule mod, Target target,
- Optional<Array<IRModule>> dispatched) {
- if (Optional<MetaScheduleContext> ctx = MetaScheduleContext::Current()) {
- return ctx.value()->Query(task_name, mod, target, dispatched);
- }
- return NullOpt;
-}
-
/**************** ApplyHistoryBest ****************/
ApplyHistoryBest::ApplyHistoryBest(Database database) {
@@ -149,37 +127,19 @@ Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModu
return NullOpt;
}
-/**************** FFI ****************/
-
-class MetaScheduleContextInternal {
- public:
- static void EnterScope(MetaScheduleContext ctx) { ctx.EnterWithScope(); }
- static void ExitScope(MetaScheduleContext ctx) { ctx.ExitWithScope(); }
-};
-
-TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
-TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode);
TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);
-
-TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
- .set_body_typed([](String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
- int weight) -> ExtractedTask {
- return ExtractedTask(task_name, mod, target, dispatched, weight);
- });
-TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextEnterScope")
- .set_body_typed(MetaScheduleContextInternal::EnterScope);
-TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextExitScope")
- .set_body_typed(MetaScheduleContextInternal::ExitScope);
-TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextCurrent")
- .set_body_typed(MetaScheduleContext::Current);
-TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope")
- .set_body_typed(MetaScheduleContext::QueryInsideWithScope);
-TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery")
- .set_body_method<MetaScheduleContext>(&MetaScheduleContextNode::Query);
+// Constructor: wraps a Database in a new ApplyHistoryBest.
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
.set_body_typed([](Database database) -> ApplyHistoryBest {
return ApplyHistoryBest(database);
});
+// Scope management (enter/exit/current) and query entry points exposed via the global registry.
+// NOTE(review): presumably these back the Python ApplyHistoryBest wrapper (e.g. current()); confirm.
+TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope")
+ .set_body_typed(ApplyHistoryBestInternal::EnterScope);
+TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestExitScope")
+ .set_body_typed(ApplyHistoryBestInternal::ExitScope);
+TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestCurrent")
+ .set_body_typed(ApplyHistoryBest::Current);
+TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestQuery")
+ .set_body_method<ApplyHistoryBest>(&ApplyHistoryBestNode::Query);
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc
new file mode 100644
index 0000000000..b1044fc87d
--- /dev/null
+++ b/src/meta_schedule/extracted_task.cc
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/meta_schedule/extracted_task.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Construct an ExtractedTask by populating a freshly allocated ExtractedTaskNode.
+ * \param task_name Name of the task.
+ * \param mod Module describing the task (presumably the fused Relay function; confirm in
+ *        extracted_task.h).
+ * \param target Target the task was extracted for.
+ * \param dispatched Candidate modules dispatched for the task (presumably lowered TIR; confirm).
+ * \param weight Weight of the task (presumably how often it occurs in the network; confirm).
+ */
+ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
+ Array<IRModule> dispatched, int weight) {
+ ObjectPtr<ExtractedTaskNode> n = make_object<ExtractedTaskNode>();
+ n->task_name = task_name;
+ n->mod = mod;
+ n->target = target;
+ n->dispatched = dispatched;
+ n->weight = weight;
+ data_ = n;
+}
+
+// Register ExtractedTaskNode with the node type registry and expose the constructor above
+// through the global function registry as "meta_schedule.ExtractedTask".
+TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
+ .set_body_typed([](String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
+ int weight) -> ExtractedTask {
+ return ExtractedTask(task_name, mod, target, dispatched, weight);
+ });
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 2ee18a8668..45a04958ad 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -21,6 +21,7 @@
#include <dmlc/memory_io.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/meta_schedule/apply_history_best.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc
index a787f19150..0895fd42a3 100644
--- a/src/relay/backend/task_extraction.cc
+++ b/src/relay/backend/task_extraction.cc
@@ -17,16 +17,15 @@
* under the License.
*/
-#include <tvm/meta_schedule/integration.h>
+#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>
#include <tvm/target/target.h>
#include "../../te/operation/create_primfunc.h"
-#include "te_compiler_cache.h"
-#include "tvm/runtime/ndarray.h"
-#include "utils.h"
+#include "./te_compiler_cache.h"
+#include "./utils.h"
namespace tvm {
namespace relay {
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index e0e7277676..a8edeff862 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -21,7 +21,7 @@
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
-#include <tvm/meta_schedule/integration.h>
+#include <tvm/meta_schedule/apply_history_best.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
@@ -302,7 +302,13 @@ class ScheduleBuilder : public ExprVisitor {
explicit ScheduleBuilder(Target target) : target_(target) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
- use_meta_scheduler_ = backend::IsMetaScheduleEnabled();
+ // If meta-schedule is enabled for this build, capture the ApplyHistoryBest active on this thread
+ // and fail fast when there is none; otherwise meta_schedule_ctx_ keeps its default NullOpt.
+ if (backend::IsMetaScheduleEnabled()) {
+ meta_schedule_ctx_ = meta_schedule::ApplyHistoryBest::Current();
+ CHECK(meta_schedule_ctx_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay "
+ "build, but no ApplyHistoryBest context is provided.";
+ }
}
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
@@ -340,12 +346,11 @@ class ScheduleBuilder : public ExprVisitor {
schedule = Downcast<te::Schedule>(obj);
}
}
- if (use_meta_scheduler_) {
+ if (meta_schedule_ctx_) {
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}});
- Optional<IRModule> scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope(
- prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod});
- if (scheduled_mod) {
+ if (Optional<IRModule> scheduled_mod = meta_schedule_ctx_.value()->Query(
+ prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod})) {
ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1);
prim_func = Downcast<tir::PrimFunc>(scheduled_mod.value()->functions[prim_fn_var]);
}
@@ -381,7 +386,7 @@ class ScheduleBuilder : public ExprVisitor {
}
int op_pattern = fpattern[op];
- if (!use_auto_scheduler_ && !use_meta_scheduler_ && op_pattern >= kCommReduce) {
+ if (!use_auto_scheduler_ && !meta_schedule_ctx_.defined() && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
@@ -399,7 +404,7 @@ class ScheduleBuilder : public ExprVisitor {
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
bool use_auto_scheduler_;
- bool use_meta_scheduler_;
+ Optional<meta_schedule::ApplyHistoryBest> meta_schedule_ctx_;
};
/*!
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index 1bbaf35ad2..b17d6ffc60 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import sys
-from typing import List
import numpy as np
import pytest
@@ -23,18 +22,13 @@ import tvm
import tvm.testing
from tvm import meta_schedule as ms
from tvm import relay
-from tvm.ir.module import IRModule
-from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.integration import (
- ApplyHistoryBest,
- ExtractedTask,
- MetaScheduleContext,
-)
+from tvm.meta_schedule import ApplyHistoryBest
+from tvm.meta_schedule.database import TuningRecord
+from tvm.meta_schedule.relay_integration import extract_task_from_relay
from tvm.meta_schedule.testing import DummyDatabase
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
-from tvm.meta_schedule.tune import Parse, extract_task_from_relay
-from tvm.meta_schedule.utils import derived_object
+from tvm.meta_schedule.tune import Parse
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule
@@ -68,14 +62,14 @@ def _has_torch():
requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
-def test_meta_schedule_integration_no_current():
- assert MetaScheduleContext.current() is None
+def test_meta_schedule_apply_history_best_no_current():
+ """With no ApplyHistoryBest scope entered, current() must return None."""
+ assert ApplyHistoryBest.current() is None
@requires_torch
def test_meta_schedule_integration_extract_from_resnet():
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
- extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
+ extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params)
expected_task_names = [
"fused_" + s
for s in [
@@ -189,7 +183,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
),
}
mod, params, _ = get_network(name="bert_base", input_shape=[1, 64])
- extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
+ extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params)
assert len(extracted_tasks) == len(expected)
for t in extracted_tasks:
prim_func = None
diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py
index 78d0ddeda3..0b8af9c145 100644
--- a/tests/python/unittest/test_meta_schedule_multi_anchor.py
+++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -15,12 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
-
import tvm
import tvm.testing
from tvm import relay
+from tvm.meta_schedule import ApplyHistoryBest
from tvm.meta_schedule.testing import apply_fixed_schedules
-from tvm.meta_schedule.integration import ApplyHistoryBest
def get_dense_dense(data_shape, weight_shape):
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 76cd82920c..af25d2a6f3 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -27,16 +27,12 @@ from tvm import relay, tir
from tvm._ffi import register_func
from tvm.contrib import graph_executor
from tvm.ir import IRModule
-from tvm.meta_schedule import ReplayTraceConfig
+from tvm.meta_schedule import ApplyHistoryBest, ReplayTraceConfig
from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.integration import ApplyHistoryBest
-from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.relay_integration import extract_task_from_relay
from tvm.meta_schedule.testing import apply_fixed_schedules
-from tvm.meta_schedule.tune import (
- extract_task_from_relay,
- tune_extracted_tasks,
- tune_relay,
-)
+from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.tune import tune_extracted_tasks, tune_relay
from tvm.meta_schedule.utils import derived_object
from tvm.script import tir as T
from tvm.target.target import Target
@@ -528,13 +524,13 @@ def manual_tir_common(do_tune=False):
):
"""
The log should say
- meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims
- meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast
- meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1
- meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
+ Warning: Cannot find workload: tvmgen_default_fused_expand_dims
+ Warning: Cannot find workload: tvmgen_default_fused_cast
+ Warning: Cannot find workload: tvmgen_default_fused_cast_1
+ Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
- This means batch matmul and others are scheduled by TE, and dense (the one not warned) is found in the
- meta schedule tuning database during ApplyHistoryBest
+ This means batch matmul and others are scheduled by TE, and dense (the one not warned)
+ is found in the meta schedule tuning database during ApplyHistoryBest
"""
lib = relay.build(relay_mod, target=target, params=params)